Skip to content

Commit 4ad8e17

Browse files
committed
fix(client): validate api connection before running batch methods
1 parent 15574c7 commit 4ad8e17

2 files changed

Lines changed: 183 additions & 8 deletions

File tree

openfga_sdk/client/openfga_client.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,16 @@ def get_authorization_model_id(self):
159159
"""
160160
return self._client_configuration.authorization_model_id
161161

162+
async def _check_valid_api_connection(self, options: dict[str, int | str]):
163+
"""
164+
Checks that a connection with the given configuration can be established
165+
"""
166+
authorization_model_id = self._get_authorization_model_id(options)
167+
if authorization_model_id is not None and authorization_model_id != "":
168+
await self.read_authorization_model(options)
169+
else:
170+
await self.read_latest_authorization_model(options)
171+
162172
#################
163173
# Stores
164174
#################
@@ -426,6 +436,9 @@ async def writes(self, body: ClientWriteRequest, options: dict[str, str]):
426436
return results
427437

428438
options = set_heading_if_not_set(options, CLIENT_BULK_REQUEST_ID_HEADER, str(uuid.uuid4()))
439+
# TODO: this should be run in parallel
440+
await self._check_valid_api_connection(options)
441+
429442
# otherwise, it is not a transaction and it is a batch write requests
430443
writes_response = None
431444
if body.writes:
@@ -520,6 +533,9 @@ async def batch_check(self, body: List[CheckRequestBody], options: dict[str, str
520533
options = set_heading_if_not_set(options, CLIENT_METHOD_HEADER, "BatchCheck")
521534
options = set_heading_if_not_set(options, CLIENT_BULK_REQUEST_ID_HEADER, str(uuid.uuid4()))
522535

536+
# TODO: this should be run in parallel
537+
await self._check_valid_api_connection(options)
538+
523539
max_parallel_requests = 10
524540
if options is not None and "max_parallel_requests" in options:
525541
max_parallel_requests = options["max_parallel_requests"]

test/test_openfga_client.py

Lines changed: 167 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from openfga_sdk.client.read_changes_body import ReadChangesBody
3131
from openfga_sdk.client.single_write_response import SingleWriteResponse
3232
from openfga_sdk.client.write_transaction import WriteTransaction
33-
from openfga_sdk.exceptions import ValidationException, FgaValidationException
33+
from openfga_sdk.exceptions import ValidationException, FgaValidationException, UnauthorizedException
3434
from openfga_sdk.models.assertion import Assertion
3535
from openfga_sdk.models.authorization_model import AuthorizationModel
3636
from openfga_sdk.models.check_response import CheckResponse
@@ -820,6 +820,7 @@ async def test_write_batch(self, mock_request):
820820
mock_response('{}', 200),
821821
mock_response('{}', 200),
822822
mock_response('{}', 200),
823+
mock_response('{}', 200),
823824
]
824825
configuration = self.configuration
825826
configuration.store_id = store_id
@@ -881,7 +882,15 @@ async def test_write_batch(self, mock_request):
881882
error=None)
882883
]
883884
)
884-
self.assertEqual(mock_request.call_count, 3)
885+
self.assertEqual(mock_request.call_count, 4)
886+
mock_request.assert_any_call(
887+
'GET',
888+
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/authorization-models/01G5JAVJ41T49E9TT3SKVS7X1J',
889+
headers=ANY,
890+
query_params=[],
891+
_preload_content=ANY,
892+
_request_timeout=None
893+
)
885894
mock_request.assert_any_call(
886895
'POST',
887896
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/write',
@@ -926,6 +935,7 @@ async def test_write_batch_min_parallel(self, mock_request):
926935
mock_response('{}', 200),
927936
mock_response('{}', 200),
928937
mock_response('{}', 200),
938+
mock_response('{}', 200),
929939
]
930940
configuration = self.configuration
931941
configuration.store_id = store_id
@@ -986,7 +996,15 @@ async def test_write_batch_min_parallel(self, mock_request):
986996
error=None)
987997
]
988998
)
989-
self.assertEqual(mock_request.call_count, 3)
999+
self.assertEqual(mock_request.call_count, 4)
1000+
mock_request.assert_any_call(
1001+
'GET',
1002+
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/authorization-models/01G5JAVJ41T49E9TT3SKVS7X1J',
1003+
headers=ANY,
1004+
query_params=[],
1005+
_preload_content=ANY,
1006+
_request_timeout=None
1007+
)
9901008
mock_request.assert_any_call(
9911009
'POST',
9921010
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/write',
@@ -1030,6 +1048,7 @@ async def test_write_batch_larger_chunk(self, mock_request):
10301048
mock_request.side_effect = [
10311049
mock_response('{}', 200),
10321050
mock_response('{}', 200),
1051+
mock_response('{}', 200),
10331052
]
10341053
configuration = self.configuration
10351054
configuration.store_id = store_id
@@ -1090,7 +1109,15 @@ async def test_write_batch_larger_chunk(self, mock_request):
10901109
error=None)
10911110
]
10921111
)
1093-
self.assertEqual(mock_request.call_count, 2)
1112+
self.assertEqual(mock_request.call_count, 3)
1113+
mock_request.assert_any_call(
1114+
'GET',
1115+
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/authorization-models/01G5JAVJ41T49E9TT3SKVS7X1J',
1116+
headers=ANY,
1117+
query_params=[],
1118+
_preload_content=ANY,
1119+
_request_timeout=None
1120+
)
10941121
mock_request.assert_any_call(
10951122
'POST',
10961123
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/write',
@@ -1128,6 +1155,7 @@ async def test_write_batch_failed(self, mock_request):
11281155
'''
11291156

11301157
mock_request.side_effect = [
1158+
mock_response('{}', 200),
11311159
mock_response('{}', 200),
11321160
ValidationException(http_resp=http_mock_response(response_body, 400)),
11331161
mock_response('{}', 200),
@@ -1193,7 +1221,15 @@ async def test_write_batch_failed(self, mock_request):
11931221
),
11941222
success=True,
11951223
error=None))
1196-
self.assertEqual(mock_request.call_count, 3)
1224+
self.assertEqual(mock_request.call_count, 4)
1225+
mock_request.assert_any_call(
1226+
'GET',
1227+
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/authorization-models/01G5JAVJ41T49E9TT3SKVS7X1J',
1228+
headers=ANY,
1229+
query_params=[],
1230+
_preload_content=ANY,
1231+
_request_timeout=None
1232+
)
11971233
mock_request.assert_any_call(
11981234
'POST',
11991235
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/write',
@@ -1236,6 +1272,7 @@ async def test_delete_batch(self, mock_request):
12361272
"""
12371273
mock_request.side_effect = [
12381274
mock_response('{}', 200),
1275+
mock_response('{}', 200),
12391276
]
12401277
configuration = self.configuration
12411278
configuration.store_id = store_id
@@ -1257,7 +1294,15 @@ async def test_delete_batch(self, mock_request):
12571294
options={"authorization_model_id": "01G5JAVJ41T49E9TT3SKVS7X1J",
12581295
"transaction": transaction}
12591296
)
1260-
mock_request.assert_called_once_with(
1297+
mock_request.assert_any_call(
1298+
'GET',
1299+
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/authorization-models/01G5JAVJ41T49E9TT3SKVS7X1J',
1300+
headers=ANY,
1301+
query_params=[],
1302+
_preload_content=ANY,
1303+
_request_timeout=None
1304+
)
1305+
mock_request.assert_any_call(
12611306
'POST',
12621307
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/write',
12631308
headers=ANY,
@@ -1357,6 +1402,49 @@ async def test_delete_tuples(self, mock_request):
13571402
_request_timeout=None
13581403
)
13591404

1405+
@patch.object(rest.RESTClientObject, 'request')
1406+
async def test_write_batch_unauthorized(self, mock_request):
1407+
"""Test case for write with 401 response
1408+
"""
1409+
1410+
mock_request.side_effect = UnauthorizedException(
1411+
http_resp=http_mock_response('{}', 401)
1412+
)
1413+
configuration = self.configuration
1414+
configuration.store_id = store_id
1415+
async with OpenFgaClient(configuration) as api_client:
1416+
with self.assertRaises(UnauthorizedException) as api_exception:
1417+
body = ClientWriteRequest(
1418+
writes=[
1419+
ClientTuple(
1420+
object="document:2021-budget",
1421+
relation="reader",
1422+
user="user:81684243-9356-4421-8fbf-a4f8d36aa31b",
1423+
)
1424+
],
1425+
)
1426+
transaction = WriteTransaction(
1427+
disabled=True, max_per_chunk=1, max_parallel_requests=10)
1428+
await api_client.writes(
1429+
body,
1430+
options={"authorization_model_id": "01G5JAVJ41T49E9TT3SKVS7X1J",
1431+
"transaction": transaction}
1432+
)
1433+
1434+
self.assertIsInstance(api_exception.exception, UnauthorizedException)
1435+
mock_request.assert_called()
1436+
self.assertEqual(mock_request.call_count, 1)
1437+
1438+
mock_request.assert_called_once_with(
1439+
'GET',
1440+
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/authorization-models/01G5JAVJ41T49E9TT3SKVS7X1J',
1441+
headers=ANY,
1442+
query_params=[],
1443+
_preload_content=ANY,
1444+
_request_timeout=None
1445+
)
1446+
await api_client.close()
1447+
13601448
@patch.object(rest.RESTClientObject, 'request')
13611449
async def test_check(self, mock_request):
13621450
"""Test case for check
@@ -1443,7 +1531,10 @@ async def test_batch_check_single_request(self, mock_request):
14431531

14441532
# First, mock the response
14451533
response_body = '{"allowed": true, "resolution": "1234"}'
1446-
mock_request.return_value = mock_response(response_body, 200)
1534+
mock_request.side_effect = [
1535+
mock_response('{}', 200),
1536+
mock_response(response_body, 200),
1537+
]
14471538
body = CheckRequestBody(
14481539
object="document:2021-budget",
14491540
relation="reader",
@@ -1462,7 +1553,15 @@ async def test_batch_check_single_request(self, mock_request):
14621553
self.assertTrue(api_response[0].allowed)
14631554
self.assertEqual(api_response[0].request, body)
14641555
# Make sure the API was called with the right data
1465-
mock_request.assert_called_once_with(
1556+
mock_request.assert_any_call(
1557+
'GET',
1558+
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/authorization-models/01GXSA8YR785C4FYS3C0RTG7B1',
1559+
headers=ANY,
1560+
query_params=[],
1561+
_preload_content=ANY,
1562+
_request_timeout=None
1563+
)
1564+
mock_request.assert_any_call(
14661565
'POST',
14671566
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/check',
14681567
headers=ANY,
@@ -1484,6 +1583,7 @@ async def test_batch_check_multiple_request(self, mock_request):
14841583

14851584
# First, mock the response
14861585
mock_request.side_effect = [
1586+
mock_response('{}', 200),
14871587
mock_response('{"allowed": true, "resolution": "1234"}', 200),
14881588
mock_response('{"allowed": false, "resolution": "1234"}', 200),
14891589
mock_response('{"allowed": true, "resolution": "1234"}', 200),
@@ -1523,6 +1623,14 @@ async def test_batch_check_multiple_request(self, mock_request):
15231623
self.assertTrue(api_response[2].allowed)
15241624
self.assertEqual(api_response[2].request, body3)
15251625
# Make sure the API was called with the right data
1626+
mock_request.assert_any_call(
1627+
'GET',
1628+
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/authorization-models/01GXSA8YR785C4FYS3C0RTG7B1',
1629+
headers=ANY,
1630+
query_params=[],
1631+
_preload_content=ANY,
1632+
_request_timeout=None
1633+
)
15261634
mock_request.assert_any_call(
15271635
'POST',
15281636
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/check',
@@ -1573,6 +1681,7 @@ async def test_batch_check_multiple_request_fail(self, mock_request):
15731681

15741682
# First, mock the response
15751683
mock_request.side_effect = [
1684+
mock_response('{}', 200),
15761685
mock_response('{"allowed": true, "resolution": "1234"}', 200),
15771686
ValidationException(http_resp=http_mock_response(response_body, 400)),
15781687
mock_response('{"allowed": false, "resolution": "1234"}', 200),
@@ -1614,6 +1723,14 @@ async def test_batch_check_multiple_request_fail(self, mock_request):
16141723
self.assertFalse(api_response[2].allowed)
16151724
self.assertEqual(api_response[2].request, body3)
16161725
# Make sure the API was called with the right data
1726+
mock_request.assert_any_call(
1727+
'GET',
1728+
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/authorization-models/01GXSA8YR785C4FYS3C0RTG7B1',
1729+
headers=ANY,
1730+
query_params=[],
1731+
_preload_content=ANY,
1732+
_request_timeout=None
1733+
)
16171734
mock_request.assert_any_call(
16181735
'POST',
16191736
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/check',
@@ -1785,6 +1902,7 @@ async def test_list_relations(self, mock_request):
17851902

17861903
# First, mock the response
17871904
mock_request.side_effect = [
1905+
mock_response('{}', 200),
17881906
mock_response('{"allowed": true, "resolution": "1234"}', 200),
17891907
mock_response('{"allowed": false, "resolution": "1234"}', 200),
17901908
mock_response('{"allowed": true, "resolution": "1234"}', 200),
@@ -1801,6 +1919,14 @@ async def test_list_relations(self, mock_request):
18011919
self.assertEqual(api_response, ["reader", "viewer"])
18021920

18031921
# Make sure the API was called with the right data
1922+
mock_request.assert_any_call(
1923+
'GET',
1924+
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/authorization-models/01GXSA8YR785C4FYS3C0RTG7B1',
1925+
headers=ANY,
1926+
query_params=[],
1927+
_preload_content=ANY,
1928+
_request_timeout=None
1929+
)
18041930
mock_request.assert_any_call(
18051931
'POST',
18061932
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/check',
@@ -1836,6 +1962,39 @@ async def test_list_relations(self, mock_request):
18361962
)
18371963
await api_client.close()
18381964

1965+
@patch.object(rest.RESTClientObject, 'request')
1966+
async def test_list_relations_unauthorized(self, mock_request):
1967+
"""Test case for list relations with 401 response
1968+
"""
1969+
1970+
mock_request.side_effect = UnauthorizedException(
1971+
http_resp=http_mock_response('{}', 401)
1972+
)
1973+
configuration = self.configuration
1974+
configuration.store_id = store_id
1975+
async with OpenFgaClient(configuration) as api_client:
1976+
with self.assertRaises(UnauthorizedException) as api_exception:
1977+
await api_client.list_relations(
1978+
body=ListRelationsRequestBody(user="user:81684243-9356-4421-8fbf-a4f8d36aa31b",
1979+
relations=["reader", "owner", "viewer"],
1980+
object="document:2021-budget"),
1981+
options={"authorization_model_id": "01GXSA8YR785C4FYS3C0RTG7B1"}
1982+
)
1983+
1984+
self.assertIsInstance(api_exception.exception, UnauthorizedException)
1985+
mock_request.assert_called()
1986+
self.assertEqual(mock_request.call_count, 1)
1987+
1988+
mock_request.assert_called_once_with(
1989+
'GET',
1990+
'http://api.fga.example/stores/01YCP46JKYM8FJCQ37NMBYHE5X/authorization-models/01GXSA8YR785C4FYS3C0RTG7B1',
1991+
headers=ANY,
1992+
query_params=[],
1993+
_preload_content=ANY,
1994+
_request_timeout=None
1995+
)
1996+
await api_client.close()
1997+
18391998
@patch.object(rest.RESTClientObject, 'request')
18401999
async def test_read_assertions(self, mock_request):
18412000
"""Test case for read assertions

0 commit comments

Comments
 (0)