Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/sumo/wrapper/_auth_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Dict
from urllib.parse import parse_qs

import jwt
Expand Down Expand Up @@ -76,10 +77,10 @@ def get_token(self):
# ELSE
return result["access_token"]

def get_authorization(self):
def get_authorization(self) -> Dict:
token = self.get_token()
if token is None:
return ""
return {}

return {"Authorization": "Bearer " + token}

Expand All @@ -101,6 +102,13 @@ def has_case_token(self, case_uuid):
pass


class AuthProviderNone(AuthProvider):
def get_token(self):
raise Exception("No valid authorization provider found.")

pass


class AuthProviderSilent(AuthProvider):
def __init__(self, client_id, authority, resource_id):
super().__init__(resource_id)
Expand Down Expand Up @@ -423,7 +431,7 @@ def get_auth_provider(
refresh_token=None,
devicecode=False,
case_uuid=None,
):
) -> AuthProvider:
if refresh_token:
return AuthProviderRefreshToken(
refresh_token, client_id, authority, resource_id
Expand Down Expand Up @@ -472,6 +480,8 @@ def get_auth_provider(
]
):
return AuthProviderManaged(resource_id)
# ELSE
return AuthProviderNone(resource_id)


def cleanup_shared_keys():
Expand All @@ -481,7 +491,7 @@ def cleanup_shared_keys():
for f in os.listdir(tokendir):
ff = os.path.join(tokendir, f)
if os.path.isfile(ff):
(name, ext) = os.path.splitext(ff)
(_, ext) = os.path.splitext(ff)
if ext.lower() == ".sharedkey":
try:
with open(ff, "r") as file:
Expand Down
6 changes: 3 additions & 3 deletions src/sumo/wrapper/_logging.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from datetime import datetime
from datetime import datetime, timezone


class LogHandlerSumo(logging.Handler):
Expand All @@ -11,8 +11,8 @@ def __init__(self, sumo_client):
def emit(self, record):
try:
dt = (
datetime.now(datetime.timezone.utc)
.replace(microsecond=0)
datetime.now(timezone.utc)
.replace(microsecond=0, tzinfo=None)
.isoformat()
+ "Z"
)
Expand Down
4 changes: 2 additions & 2 deletions src/sumo/wrapper/_retry_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, stop_after=6, multiplier=0.5, exp_base=2):
self._exp_base = exp_base
return

def make_retryer(self):
def make_retryer(self) -> tn.Retrying:
return tn.Retrying(
stop=tn.stop_after_attempt(self._stop_after),
retry=(
Expand All @@ -63,7 +63,7 @@ def make_retryer(self):
before_sleep=_log_retry_info,
)

def make_retryer_async(self):
def make_retryer_async(self) -> tn.AsyncRetrying:
return tn.AsyncRetrying(
stop=tn.stop_after_attempt(self._stop_after),
retry=(
Expand Down
56 changes: 33 additions & 23 deletions src/sumo/wrapper/sumo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import re
import time
from typing import Dict, Optional, Tuple

import httpx
import jwt
Expand All @@ -25,10 +26,13 @@
class SumoClient:
"""Authenticate and perform requests to the Sumo API."""

_client: httpx.Client
_async_client: httpx.AsyncClient

def __init__(
self,
env: str,
token: str = None,
token: Optional[str] = None,
interactive: bool = False,
devicecode: bool = False,
verbosity: str = "CRITICAL",
Expand Down Expand Up @@ -119,26 +123,23 @@ def __init__(
def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, *_):
if not self._borrowed_client:
self._client.close()
self._client = None
return False

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(self, *_):
if not self._borrowed_async_client:
await self._async_client.aclose()
self._async_client = None
return False

def __del__(self):
if self._client is not None and not self._borrowed_client:
self._client.close()
pass
self._client = None
if self._async_client is not None and not self._borrowed_async_client:

async def closeit(client):
Expand All @@ -151,7 +152,6 @@ async def closeit(client):
except RuntimeError:
pass
pass
self._async_client = None

def authenticate(self):
if self.auth is None:
Expand Down Expand Up @@ -186,7 +186,7 @@ def blob_client(self) -> BlobClient:
)

@raise_for_status
def get(self, path: str, params: dict = None) -> dict:
def get(self, path: str, params: Optional[Dict] = None) -> httpx.Response:
"""Performs a GET-request to the Sumo API.

Args:
Expand Down Expand Up @@ -247,9 +247,9 @@ def _get():
def post(
self,
path: str,
blob: bytes = None,
json: dict = None,
params: dict = None,
blob: Optional[bytes] = None,
json: Optional[dict] = None,
params: Optional[dict] = None,
) -> httpx.Response:
"""Performs a POST-request to the Sumo API.

Expand Down Expand Up @@ -320,7 +320,10 @@ def _post():

@raise_for_status
def put(
self, path: str, blob: bytes = None, json: dict = None
self,
path: str,
blob: Optional[bytes] = None,
json: Optional[dict] = None,
) -> httpx.Response:
"""Performs a PUT-request to the Sumo API.

Expand Down Expand Up @@ -365,7 +368,9 @@ def _put():
return retryer(_put)

@raise_for_status
def delete(self, path: str, params: dict = None) -> dict:
def delete(
self, path: str, params: Optional[dict] = None
) -> httpx.Response:
"""Performs a DELETE-request to the Sumo API.

Args:
Expand Down Expand Up @@ -402,12 +407,12 @@ def _delete():

return retryer(_delete)

def _get_retry_details(self, response_in):
def _get_retry_details(self, response_in) -> Tuple[str, int]:
assert response_in.status_code == 202, (
"Incorrect status code; expcted 202"
)
headers = response_in.headers
location = headers.get("location")
location: str = headers.get("location")
assert location is not None, "Missing header: Location"
assert location.startswith(self.base_url)
retry_after = headers.get("retry-after")
Expand Down Expand Up @@ -440,7 +445,6 @@ def poll(
)
location, retry_after = self._get_retry_details(response)
pass
return None # should never get here.

def getLogger(self, name):
"""Gets a logger object that sends log objects into the message_log
Expand Down Expand Up @@ -495,7 +499,9 @@ def client_for_case(self, case_uuid):
return self

@raise_for_status_async
async def get_async(self, path: str, params: dict = None):
async def get_async(
self, path: str, params: Optional[dict] = None
) -> httpx.Response:
"""Performs an async GET-request to the Sumo API.

Args:
Expand Down Expand Up @@ -556,9 +562,9 @@ async def _get():
async def post_async(
self,
path: str,
blob: bytes = None,
json: dict = None,
params: dict = None,
blob: Optional[bytes] = None,
json: Optional[dict] = None,
params: Optional[dict] = None,
) -> httpx.Response:
"""Performs an async POST-request to the Sumo API.

Expand Down Expand Up @@ -630,7 +636,10 @@ async def _post():

@raise_for_status_async
async def put_async(
self, path: str, blob: bytes = None, json: dict = None
self,
path: str,
blob: Optional[bytes] = None,
json: Optional[dict] = None,
) -> httpx.Response:
"""Performs an async PUT-request to the Sumo API.

Expand Down Expand Up @@ -675,7 +684,9 @@ async def _put():
return await retryer(_put)

@raise_for_status_async
async def delete_async(self, path: str, params: dict = None) -> dict:
async def delete_async(
self, path: str, params: Optional[dict] = None
) -> httpx.Response:
"""Performs an async DELETE-request to the Sumo API.

Args:
Expand Down Expand Up @@ -736,4 +747,3 @@ async def poll_async(
)
location, retry_after = self._get_retry_details(response)
pass
return None # should never get here.
Loading