diff --git a/openfga_sdk/rest.py b/openfga_sdk/rest.py index 4621e27..fe775dc 100644 --- a/openfga_sdk/rest.py +++ b/openfga_sdk/rest.py @@ -1,7 +1,6 @@ import io import json import logging -import re import ssl import urllib @@ -221,7 +220,7 @@ async def build_request( args["url"] = f"{url}?{encoded_qs}" if method in ["POST", "PUT", "PATCH", "OPTIONS", "DELETE"]: - if re.search("json", headers["Content-Type"], re.IGNORECASE): + if "json" in headers["Content-Type"].lower(): if body is not None: body = json.dumps(body) args["data"] = body @@ -397,11 +396,15 @@ async def stream( if isinstance(response, aiohttp.ClientResponse): logger.debug("response body: %s", buffer.decode("utf-8")) - # Handle any HTTP errors that may have occurred - await self.handle_response_exception(response) - - # Release the response object (required!) - response.release() + try: + # Handle any HTTP errors that may have occurred + await self.handle_response_exception(response) + finally: + # Release the response object back to the connection pool. + # This must always run, even if handle_response_exception raises, + # to avoid leaking the connection (preload_content=False means + # the connection is not auto-released). + response.release() async def request( self, diff --git a/openfga_sdk/sync/api_client.py b/openfga_sdk/sync/api_client.py index 7990e25..c688650 100644 --- a/openfga_sdk/sync/api_client.py +++ b/openfga_sdk/sync/api_client.py @@ -109,6 +109,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def close(self): + self.rest_client.close() if self._pool: self._pool.close() self._pool.join() diff --git a/openfga_sdk/sync/rest.py b/openfga_sdk/sync/rest.py index 63d90a4..90fd06a 100644 --- a/openfga_sdk/sync/rest.py +++ b/openfga_sdk/sync/rest.py @@ -1,7 +1,6 @@ import io import json import logging -import re import ssl import urllib @@ -279,7 +278,7 @@ def build_request( # Handle body/post_params for methods that send payloads if method in ["POST", "PUT", "PATCH", "OPTIONS", "DELETE"]: - if re.search("json", headers["Content-Type"], re.IGNORECASE): + if "json" in headers["Content-Type"].lower(): if body is not None: body = json.dumps(body) args["body"] = body @@ -437,11 +436,15 @@ def stream( except json.JSONDecodeError: logger.debug("Incomplete leftover data at end of stream.") - # Handle any HTTP errors that may have occurred - self.handle_response_exception(response) - - # Release the response object (required!) - response.release_conn() + try: + # Handle any HTTP errors that may have occurred + self.handle_response_exception(response) + finally: + # Release the response object back to the connection pool. + # This must always run, even if handle_response_exception raises, + # to avoid leaking the connection (preload_content=False means + # urllib3 does not auto-release). + response.release_conn() def request( self, @@ -494,10 +497,13 @@ def request( # Log the response body logger.debug("response body: %s", wrapped_response.data.decode("utf-8")) - # Handle any errors that may have occurred - self.handle_response_exception(raw_response) - - # Release the connection back to the pool - self.close() + # Handle any errors that may have occurred. If an exception is raised, + # ensure the underlying response is closed so the connection is not + # leaked from the pool. + try: + self.handle_response_exception(raw_response) + except Exception: + raw_response.close() + raise return wrapped_response or raw_response diff --git a/openfga_sdk/telemetry/attributes.py b/openfga_sdk/telemetry/attributes.py index ddda43f..02451c1 100644 --- a/openfga_sdk/telemetry/attributes.py +++ b/openfga_sdk/telemetry/attributes.py @@ -241,8 +241,9 @@ def fromRequest( _attributes[TelemetryAttributes.http_request_method] = http_method if url is not None: - _hostname = urllib.parse.urlparse(url).hostname - _scheme = urllib.parse.urlparse(url).scheme + _parsed_url = urllib.parse.urlparse(url) + _hostname = _parsed_url.hostname + _scheme = _parsed_url.scheme if type(_hostname) is str: _attributes[TelemetryAttributes.http_host] = _hostname diff --git a/openfga_sdk/validation.py b/openfga_sdk/validation.py index c9f9614..2be4b8a 100644 --- a/openfga_sdk/validation.py +++ b/openfga_sdk/validation.py @@ -1,11 +1,13 @@ import re +_ULID_REGEX = re.compile("^[0-7][0-9A-HJKMNP-TV-Z]{25}$") + + def is_well_formed_ulid_string(ulid): - regex = re.compile("^[0-7][0-9A-HJKMNP-TV-Z]{25}$") if not isinstance(ulid, str): return False - match = regex.match(ulid) + match = _ULID_REGEX.match(ulid) if match is None: return False return True diff --git a/test/rest_test.py b/test/rest_test.py index 949828e..e6a712d 100644 --- a/test/rest_test.py +++ b/test/rest_test.py @@ -428,3 +428,98 @@ async def iter_chunks(self): client.handle_response_exception.assert_awaited_once() mock_response.release.assert_called_once() + + +@pytest.mark.asyncio +async def test_stream_releases_conn_on_error_status(): + """Ensure release() is called even when handle_response_exception raises, + so the connection is returned to the pool and not leaked.""" + mock_config = MagicMock() + mock_config.ssl_ca_cert = None + mock_config.cert_file = None + mock_config.key_file = None + mock_config.verify_ssl = True + mock_config.connection_pool_maxsize = 4 + mock_config.proxy = None + mock_config.proxy_headers = None + mock_config.timeout_millisec = 5000 + + client = RESTClientObject(configuration=mock_config) + mock_session = MagicMock() + client.pool_manager = mock_session + + class FakeContent: + async def iter_chunks(self): + yield (b'{"ok":true}\n', None) + + mock_response = MagicMock() + mock_response.status = 500 + mock_response.reason = "Internal Server Error" + mock_response.data = None + mock_response.content = FakeContent() + + mock_context_manager = AsyncMock() + mock_context_manager.__aenter__.return_value = mock_response + mock_context_manager.__aexit__.return_value = None + + mock_session.request.return_value = mock_context_manager + + # Make handle_response_exception raise an exception + client.handle_response_exception = AsyncMock( + side_effect=ServiceException(status=500, reason="Internal Server Error") + ) + client.close = AsyncMock() + + results = [] + with pytest.raises(ServiceException): + async for item in client.stream("GET", "http://example.com"): + results.append(item) + + # The critical assertion: release() must be called even though + # handle_response_exception raised ServiceException + mock_response.release.assert_called_once() + + +@pytest.mark.asyncio +async def test_stream_releases_conn_on_success(): + """Ensure release() is called on successful stream completion.""" + mock_config = MagicMock() + mock_config.ssl_ca_cert = None + mock_config.cert_file = None + mock_config.key_file = None + mock_config.verify_ssl = True + mock_config.connection_pool_maxsize = 4 + mock_config.proxy = None + mock_config.proxy_headers = None + mock_config.timeout_millisec = 5000 + + client = RESTClientObject(configuration=mock_config) + mock_session = MagicMock() + client.pool_manager = mock_session + + class FakeContent: + async def iter_chunks(self): + yield (b'{"ok":true}\n', None) + + mock_response = MagicMock() + mock_response.status = 200 + mock_response.reason = "OK" + mock_response.data = None + mock_response.content = FakeContent() + + mock_context_manager = AsyncMock() + mock_context_manager.__aenter__.return_value = mock_response + mock_context_manager.__aexit__.return_value = None + + mock_session.request.return_value = mock_context_manager + + client.handle_response_exception = AsyncMock() + client.close = AsyncMock() + + results = [] + async for item in client.stream("GET", "http://example.com"): + results.append(item) + + assert results == [{"ok": True}] + client.handle_response_exception.assert_awaited_once() + mock_response.release.assert_called_once() diff --git a/test/sync/rest_test.py b/test/sync/rest_test.py index 2f4c109..c4f2a22 100644 --- a/test/sync/rest_test.py +++ b/test/sync/rest_test.py @@ -522,6 +522,96 @@ def release_conn(self): mock_pool_manager.request.assert_called_once() +def test_stream_releases_conn_on_error_status(): + """Ensure release_conn() is called even when handle_response_exception raises, + so the connection is returned to the pool and not leaked.""" + mock_config = MagicMock( + spec=[ + "verify_ssl", + "ssl_ca_cert", + "cert_file", + "key_file", + "assert_hostname", + "retries", + "socket_options", + "connection_pool_maxsize", + "timeout_millisec", + "proxy", + "proxy_headers", + ] + ) + mock_config.ssl_ca_cert = None + mock_config.cert_file = None + mock_config.key_file = None + mock_config.verify_ssl = True + mock_config.connection_pool_maxsize = 4 + mock_config.timeout_millisec = 5000 + mock_config.proxy = None + mock_config.proxy_headers = None + + client = RESTClientObject(configuration=mock_config) + mock_pool_manager = MagicMock() + client.pool_manager = mock_pool_manager + + mock_response = MagicMock() + mock_response.status = 500 + mock_response.reason = "Internal Server Error" + mock_response.stream.return_value = iter([]) # empty stream, no chunks + + mock_pool_manager.request.return_value = mock_response + + with pytest.raises(ServiceException): + # Must consume the generator to trigger the finally block + list(client.stream("GET", "http://example.com")) + + # The critical assertion: release_conn() must be called even though + # handle_response_exception raised ServiceException + mock_response.release_conn.assert_called_once() + + +def test_stream_releases_conn_on_success(): + """Ensure release_conn() is called on successful stream completion.""" + mock_config = MagicMock( + spec=[ + "verify_ssl", + "ssl_ca_cert", + "cert_file", + "key_file", + "assert_hostname", + "retries", + "socket_options", + "connection_pool_maxsize", + "timeout_millisec", + "proxy", + "proxy_headers", + ] + ) + mock_config.ssl_ca_cert = None + mock_config.cert_file = None + mock_config.key_file = None + mock_config.verify_ssl = True + mock_config.connection_pool_maxsize = 4 + mock_config.timeout_millisec = 5000 + mock_config.proxy = None + mock_config.proxy_headers = None + + client = RESTClientObject(configuration=mock_config) + mock_pool_manager = MagicMock() + client.pool_manager = mock_pool_manager + + mock_response = MagicMock() + mock_response.status = 200 + mock_response.reason = "OK" + mock_response.stream.return_value = iter([b'{"ok":true}\n']) + + mock_pool_manager.request.return_value = mock_response + + results = list(client.stream("GET", "http://example.com")) + + assert results == [{"ok": True}] + mock_response.release_conn.assert_called_once() + + # Tests for SSL Context Reuse (fix for OpenSSL 3.0+ performance issues) @patch("ssl.create_default_context") @patch("urllib3.PoolManager") @@ -731,3 +821,132 @@ def test_ssl_context_with_all_ssl_options(mock_pool_manager, mock_create_context call_kwargs = mock_pool_manager.call_args[1] assert call_kwargs["ssl_context"] == mock_ssl_context assert call_kwargs["maxsize"] == 8 + + +def test_request_does_not_clear_pool_after_request(): + """Ensure request() does not call close()/pool_manager.clear() after each request, + so pooled connections are preserved for reuse.""" + mock_config = MagicMock( + spec=[ + "verify_ssl", + "ssl_ca_cert", + "cert_file", + "key_file", + "assert_hostname", + "retries", + "socket_options", + "connection_pool_maxsize", + "timeout_millisec", + "proxy", + "proxy_headers", + ] + ) + mock_config.ssl_ca_cert = None + mock_config.cert_file = None + mock_config.key_file = None + mock_config.verify_ssl = True + mock_config.connection_pool_maxsize = 4 + mock_config.timeout_millisec = 5000 + mock_config.proxy = None + mock_config.proxy_headers = None + + client = RESTClientObject(configuration=mock_config) + mock_pool_manager = MagicMock() + client.pool_manager = mock_pool_manager + + mock_raw_response = MagicMock() + mock_raw_response.status = 200 + mock_raw_response.reason = "OK" + mock_raw_response.data = b'{"ok":true}' + + mock_pool_manager.request.return_value = mock_raw_response + + # Make multiple requests + client.request(method="GET", url="http://example.com", _preload_content=True) + client.request(method="GET", url="http://example.com", _preload_content=True) + + # pool_manager.clear() should never have been called + mock_pool_manager.clear.assert_not_called() + + +def test_request_closes_response_on_error(): + """Ensure that if handle_response_exception raises, the raw response is closed + so the connection is not leaked from the pool.""" + mock_config = MagicMock( + spec=[ + "verify_ssl", + "ssl_ca_cert", + "cert_file", + "key_file", + "assert_hostname", + "retries", + "socket_options", + "connection_pool_maxsize", + "timeout_millisec", + "proxy", + "proxy_headers", + ] + ) + mock_config.ssl_ca_cert = None + mock_config.cert_file = None + mock_config.key_file = None + mock_config.verify_ssl = True + mock_config.connection_pool_maxsize = 4 + mock_config.timeout_millisec = 5000 + mock_config.proxy = None + mock_config.proxy_headers = None + + client = RESTClientObject(configuration=mock_config) + mock_pool_manager = MagicMock() + client.pool_manager = mock_pool_manager + + mock_raw_response = MagicMock() + mock_raw_response.status = 500 + mock_raw_response.reason = "Internal Server Error" + mock_raw_response.data = b'{"error":"something went wrong"}' + mock_raw_response.getheaders.return_value = {} + + mock_pool_manager.request.return_value = mock_raw_response + + with pytest.raises(ServiceException): + client.request(method="GET", url="http://example.com", _preload_content=True) + + # Verify raw_response.close() was called to release the connection + mock_raw_response.close.assert_called_once() + + +def test_api_client_close_calls_rest_client_close(): + """Ensure the sync ApiClient.close() delegates to rest_client.close() + so pooled connections are cleaned up at shutdown.""" + from openfga_sdk.configuration import Configuration + from openfga_sdk.sync.api_client import ApiClient + + configuration = Configuration( + api_url="http://api.fga.example", + ) + + api_client = ApiClient(configuration) + mock_rest_client = MagicMock() + api_client.rest_client = mock_rest_client + + api_client.close() + + mock_rest_client.close.assert_called_once() + + +def test_api_client_context_manager_calls_close(): + """Ensure the sync ApiClient context manager calls close() on exit, + which in turn cleans up the REST client's connection pool.""" + from openfga_sdk.configuration import Configuration + from openfga_sdk.sync.api_client import ApiClient + + configuration = Configuration( + api_url="http://api.fga.example", + ) + + with ApiClient(configuration) as api_client: + mock_rest_client = MagicMock() + api_client.rest_client = mock_rest_client + + # After exiting the context manager, close() should have been called + mock_rest_client.close.assert_called_once()