|
17 | 17 | from openfga_sdk.api_client import ApiClient |
18 | 18 | from openfga_sdk.client.configuration import ClientConfiguration |
19 | 19 | from openfga_sdk.client.models.assertion import ClientAssertion |
20 | | -from openfga_sdk.client.models.batch_check_response import BatchCheckResponse |
| 20 | +from openfga_sdk.client.models.batch_check_item import ( |
| 21 | + ClientBatchCheckItem, |
| 22 | + construct_batch_item, |
| 23 | +) |
| 24 | +from openfga_sdk.client.models.batch_check_request import ClientBatchCheckRequest |
| 25 | +from openfga_sdk.client.models.batch_check_response import ( |
| 26 | + ClientBatchCheckResponse, |
| 27 | +) |
| 28 | +from openfga_sdk.client.models.batch_check_single_response import ( |
| 29 | + ClientBatchCheckSingleResponse, |
| 30 | +) |
21 | 31 | from openfga_sdk.client.models.check_request import ( |
22 | 32 | ClientCheckRequest, |
23 | 33 | construct_check_request, |
24 | 34 | ) |
25 | | -from openfga_sdk.client.models.client_batch_check_response import ClientBatchCheckClientResponse |
| 35 | +from openfga_sdk.client.models.client_batch_check_response import ( |
| 36 | + ClientBatchCheckClientResponse, |
| 37 | +) |
26 | 38 | from openfga_sdk.client.models.expand_request import ClientExpandRequest |
27 | 39 | from openfga_sdk.client.models.list_objects_request import ClientListObjectsRequest |
28 | 40 | from openfga_sdk.client.models.list_relations_request import ClientListRelationsRequest |
|
41 | 53 | UnauthorizedException, |
42 | 54 | ) |
43 | 55 | from openfga_sdk.models.assertion import Assertion |
| 56 | +from openfga_sdk.models.batch_check_request import BatchCheckRequest |
44 | 57 | from openfga_sdk.models.check_request import CheckRequest |
45 | 58 | from openfga_sdk.models.contextual_tuple_keys import ContextualTupleKeys |
46 | 59 | from openfga_sdk.models.create_store_request import CreateStoreRequest |
@@ -637,6 +650,118 @@ async def client_batch_check( |
637 | 650 |
|
638 | 651 | return batch_check_response |
639 | 652 |
|
| 653 | + async def _single_batch_check( |
| 654 | + self, |
| 655 | + body: BatchCheckRequest, |
| 656 | + semaphore: asyncio.Semaphore, |
| 657 | + options: dict[str, str] = None, |
| 658 | + ): |
| 659 | + """ |
| 660 | + Run a single BatchCheck request |
| 661 | + :param body - list[ClientCheckRequest] defining check request |
| 662 | + :param authorization_model_id(options) - Overrides the authorization model id in the configuration |
| 663 | + """ |
| 664 | + await semaphore.acquire() |
| 665 | + try: |
| 666 | + kwargs = options_to_kwargs(options) |
| 667 | + api_response = await self._api.batch_check(body, **kwargs) |
| 668 | + return api_response |
| 669 | + except Exception as err: |
| 670 | + raise err |
| 671 | + finally: |
| 672 | + semaphore.release() |
| 673 | + |
| 674 | + async def batch_check(self, body: ClientBatchCheckRequest, options=None): |
| 675 | + """ |
| 676 | + Run a batchcheck request |
| 677 | + :param body - BatchCheck request |
| 678 | + :param authorization_model_id(options) - Overrides the authorization model id in the configuration |
| 679 | + :param max_parallel_requests(options) - Max number of requests to issue in parallel. Defaults to 10 |
| 680 | + :param max_batch_size(options) - Max number of checks to include in a request. Defaults to 50 |
| 681 | + :param header(options) - Custom headers to send alongside the request |
| 682 | + :param retryParams(options) - Override the retry parameters for this request |
| 683 | + :param retryParams.maxRetry(options) - Override the max number of retries on each API request |
| 684 | + :param retryParams.minWaitInMs(options) - Override the minimum wait before a retry is initiated |
| 685 | + """ |
| 686 | + options = set_heading_if_not_set( |
| 687 | + options, CLIENT_BULK_REQUEST_ID_HEADER, str(uuid.uuid4()) |
| 688 | + ) |
| 689 | + |
| 690 | + max_parallel_requests = 10 |
| 691 | + if options is not None and "max_parallel_requests" in options: |
| 692 | + if ( |
| 693 | + isinstance(options["max_parallel_requests"], str) |
| 694 | + and options["max_parallel_requests"].isdigit() |
| 695 | + ): |
| 696 | + max_parallel_requests = int(options["max_parallel_requests"]) |
| 697 | + elif isinstance(options["max_parallel_requests"], int): |
| 698 | + max_parallel_requests = options["max_parallel_requests"] |
| 699 | + |
| 700 | + max_batch_size = 50 |
| 701 | + if options is not None and "max_batch_size" in options: |
| 702 | + if ( |
| 703 | + isinstance(options["max_batch_size"], str) |
| 704 | + and options["max_batch_size"].isdigit() |
| 705 | + ): |
| 706 | + max_batch_size = int(options["max_batch_size"]) |
| 707 | + elif isinstance(options["max_batch_size"], int): |
| 708 | + max_batch_size = options["max_batch_size"] |
| 709 | + |
| 710 | + check_to_id: dict[str, ClientBatchCheckItem] = {} |
| 711 | + |
| 712 | + def track_and_transform(checks): |
| 713 | + transformed = [] |
| 714 | + for check in checks: |
| 715 | + if check.correlation_id is None: |
| 716 | + check.correlation_id = str(uuid.uuid4()) |
| 717 | + |
| 718 | + if check.correlation_id in check_to_id: |
| 719 | + raise FgaValidationException("Duplicate correlation_id provided") |
| 720 | + |
| 721 | + check_to_id[check.correlation_id] = check |
| 722 | + |
| 723 | + transformed.append(construct_batch_item(check)) |
| 724 | + return transformed |
| 725 | + |
| 726 | + checks = [ |
| 727 | + track_and_transform( |
| 728 | + body.checks[i * max_batch_size : (i + 1) * max_batch_size] |
| 729 | + ) |
| 730 | + for i in range((len(body.checks) + max_batch_size - 1) // max_batch_size) |
| 731 | + ] |
| 732 | + |
| 733 | + result = [] |
| 734 | + sem = asyncio.Semaphore(max_parallel_requests) |
| 735 | + |
| 736 | + def map_response(id, result): |
| 737 | + check = check_to_id[id] |
| 738 | + return ClientBatchCheckSingleResponse( |
| 739 | + allowed=result.allowed, |
| 740 | + request=check, |
| 741 | + correlation_id=id, |
| 742 | + error=result.error, |
| 743 | + ) |
| 744 | + |
| 745 | + async def coro(checks): |
| 746 | + res = await self._single_batch_check( |
| 747 | + BatchCheckRequest( |
| 748 | + checks=checks, |
| 749 | + authorization_model_id=self._get_authorization_model_id(options), |
| 750 | + consistency=self._get_consistency(options), |
| 751 | + ), |
| 752 | + sem, |
| 753 | + options, |
| 754 | + ) |
| 755 | + |
| 756 | + result.extend( |
| 757 | + [map_response(c_id, c_result) for c_id, c_result in res.result.items()] |
| 758 | + ) |
| 759 | + |
| 760 | + batch_check_coros = [coro(request) for request in checks] |
| 761 | + await asyncio.gather(*batch_check_coros) |
| 762 | + |
| 763 | + return ClientBatchCheckResponse(result) |
| 764 | + |
640 | 765 | async def expand(self, body: ClientExpandRequest, options: dict[str, str] = None): |
641 | 766 | """ |
642 | 767 | Run expand request |
|
0 commit comments