Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions openfga_sdk/rest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import io
import json
import logging
import re
import ssl
import urllib

Expand Down Expand Up @@ -221,7 +220,7 @@ async def build_request(
args["url"] = f"{url}?{encoded_qs}"

if method in ["POST", "PUT", "PATCH", "OPTIONS", "DELETE"]:
if re.search("json", headers["Content-Type"], re.IGNORECASE):
if "json" in headers["Content-Type"].lower():
if body is not None:
body = json.dumps(body)
args["data"] = body
Expand Down Expand Up @@ -397,11 +396,15 @@ async def stream(
if isinstance(response, aiohttp.ClientResponse):
logger.debug("response body: %s", buffer.decode("utf-8"))

# Handle any HTTP errors that may have occurred
await self.handle_response_exception(response)

# Release the response object (required!)
response.release()
try:
# Handle any HTTP errors that may have occurred
await self.handle_response_exception(response)
finally:
# Release the response object back to the connection pool.
# This must always run, even if handle_response_exception raises,
# to avoid leaking the connection (preload_content=False means
# the connection is not auto-released).
response.release()

async def request(
self,
Expand Down
1 change: 1 addition & 0 deletions openfga_sdk/sync/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __exit__(self, exc_type, exc_value, traceback):
self.close()

def close(self):
self.rest_client.close()
if self._pool:
self._pool.close()
self._pool.join()
Expand Down
30 changes: 18 additions & 12 deletions openfga_sdk/sync/rest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import io
import json
import logging
import re
import ssl
import urllib

Expand Down Expand Up @@ -279,7 +278,7 @@ def build_request(

# Handle body/post_params for methods that send payloads
if method in ["POST", "PUT", "PATCH", "OPTIONS", "DELETE"]:
if re.search("json", headers["Content-Type"], re.IGNORECASE):
if "json" in headers["Content-Type"].lower():
if body is not None:
body = json.dumps(body)
args["body"] = body
Expand Down Expand Up @@ -437,11 +436,15 @@ def stream(
except json.JSONDecodeError:
logger.debug("Incomplete leftover data at end of stream.")

# Handle any HTTP errors that may have occurred
self.handle_response_exception(response)

# Release the response object (required!)
response.release_conn()
try:
# Handle any HTTP errors that may have occurred
self.handle_response_exception(response)
finally:
# Release the response object back to the connection pool.
# This must always run, even if handle_response_exception raises,
# to avoid leaking the connection (preload_content=False means
# urllib3 does not auto-release).
response.release_conn()

def request(
self,
Expand Down Expand Up @@ -494,10 +497,13 @@ def request(
# Log the response body
logger.debug("response body: %s", wrapped_response.data.decode("utf-8"))

# Handle any errors that may have occurred
self.handle_response_exception(raw_response)

# Release the connection back to the pool
self.close()
# Handle any errors that may have occurred. If an exception is raised,
# ensure the underlying response is closed so the connection is not
# leaked from the pool.
try:
self.handle_response_exception(raw_response)
except Exception:
raw_response.close()
raise

return wrapped_response or raw_response
5 changes: 3 additions & 2 deletions openfga_sdk/telemetry/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ def fromRequest(
_attributes[TelemetryAttributes.http_request_method] = http_method

if url is not None:
_hostname = urllib.parse.urlparse(url).hostname
_scheme = urllib.parse.urlparse(url).scheme
_parsed_url = urllib.parse.urlparse(url)
_hostname = _parsed_url.hostname
_scheme = _parsed_url.scheme

if type(_hostname) is str:
_attributes[TelemetryAttributes.http_host] = _hostname
Expand Down
6 changes: 4 additions & 2 deletions openfga_sdk/validation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import re


_ULID_REGEX = re.compile("^[0-7][0-9A-HJKMNP-TV-Z]{25}$")


def is_well_formed_ulid_string(ulid):
regex = re.compile("^[0-7][0-9A-HJKMNP-TV-Z]{25}$")
if not isinstance(ulid, str):
return False
match = regex.match(ulid)
match = _ULID_REGEX.match(ulid)
if match is None:
return False
return True
95 changes: 95 additions & 0 deletions test/rest_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,98 @@ async def iter_chunks(self):

client.handle_response_exception.assert_awaited_once()
mock_response.release.assert_called_once()


@pytest.mark.asyncio
async def test_stream_releases_conn_on_error_status():
"""Ensure release() is called even when handle_response_exception raises,
so the connection is returned to the pool and not leaked."""
mock_config = MagicMock()
mock_config.ssl_ca_cert = None
mock_config.cert_file = None
mock_config.key_file = None
mock_config.verify_ssl = True
mock_config.connection_pool_maxsize = 4
mock_config.proxy = None
mock_config.proxy_headers = None
mock_config.timeout_millisec = 5000

client = RESTClientObject(configuration=mock_config)
mock_session = MagicMock()
client.pool_manager = mock_session

class FakeContent:
async def iter_chunks(self):
yield (b'{"ok":true}\n', None)

mock_response = MagicMock()
mock_response.status = 500
mock_response.reason = "Internal Server Error"
mock_response.data = None
mock_response.content = FakeContent()

mock_context_manager = AsyncMock()
mock_context_manager.__aenter__.return_value = mock_response
mock_context_manager.__aexit__.return_value = None

mock_session.request.return_value = mock_context_manager

# Make handle_response_exception raise an exception
client.handle_response_exception = AsyncMock(
side_effect=ServiceException(status=500, reason="Internal Server Error")
)
client.close = AsyncMock()

results = []
with pytest.raises(ServiceException):
async for item in client.stream("GET", "http://example.com"):
results.append(item)

# The critical assertion: release() must be called even though
# handle_response_exception raised ServiceException
mock_response.release.assert_called_once()


@pytest.mark.asyncio
async def test_stream_releases_conn_on_success():
"""Ensure release() is called on successful stream completion."""
mock_config = MagicMock()
mock_config.ssl_ca_cert = None
mock_config.cert_file = None
mock_config.key_file = None
mock_config.verify_ssl = True
mock_config.connection_pool_maxsize = 4
mock_config.proxy = None
mock_config.proxy_headers = None
mock_config.timeout_millisec = 5000

client = RESTClientObject(configuration=mock_config)
mock_session = MagicMock()
client.pool_manager = mock_session

class FakeContent:
async def iter_chunks(self):
yield (b'{"ok":true}\n', None)

mock_response = MagicMock()
mock_response.status = 200
mock_response.reason = "OK"
mock_response.data = None
mock_response.content = FakeContent()

mock_context_manager = AsyncMock()
mock_context_manager.__aenter__.return_value = mock_response
mock_context_manager.__aexit__.return_value = None

mock_session.request.return_value = mock_context_manager

client.handle_response_exception = AsyncMock()
client.close = AsyncMock()

results = []
async for item in client.stream("GET", "http://example.com"):
results.append(item)

assert results == [{"ok": True}]
client.handle_response_exception.assert_awaited_once()
mock_response.release.assert_called_once()
Loading
Loading