Skip to content
This repository was archived by the owner on Jun 23, 2023. It is now read-only.

Commit 08c572c

Browse files
author
Giuseppe De Marco
authored
Merge pull request #66 from IdentityPython/pick_auth
chore: pick_auth refactor
2 parents 524e4e9 + c69fbfc commit 08c572c

3 files changed

Lines changed: 40 additions & 48 deletions

File tree

src/oidcop/oauth2/authorization.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,11 +429,16 @@ def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs):
429429
if auth_id:
430430
return _context.authn_broker[auth_id]
431431

432+
res = None
432433
if acr:
433434
res = _context.authn_broker.pick(acr)
434435
else:
435-
res = pick_auth(_context, request)
436-
436+
try:
437+
res = pick_auth(_context, request)
438+
except Exception as exc:
439+
logger.exception(
440+
f"An error occurred while picking the authN broker: {exc}"
441+
)
437442
if res:
438443
return res
439444
else:

src/oidcop/user_authn/authn_context.py

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -108,58 +108,45 @@ def default(self):
108108
return None
109109

110110

111-
def pick_auth(endpoint_context, areq, all=False):
111+
def pick_auth(endpoint_context, areq, pick_all=False):
112112
"""
113113
Pick authentication method
114114
115115
:param areq: AuthorizationRequest instance
116116
:return: A dictionary with the authentication method and its authn class ref
117117
"""
118-
119118
acrs = []
120-
try:
121-
if len(endpoint_context.authn_broker) == 1:
122-
return endpoint_context.authn_broker.default()
123-
124-
if "acr_values" in areq:
125-
if not isinstance(areq["acr_values"], list):
126-
areq["acr_values"] = [areq["acr_values"]]
127-
acrs = areq["acr_values"]
128-
else: # same as any
129-
try:
130-
acrs = areq["claims"]["id_token"]["acr"]["values"]
131-
except KeyError:
132-
try:
133-
_ith = areq[verified_claim_name("id_token_hint")]
134-
except KeyError:
135-
try:
136-
_hint = areq["login_hint"]
137-
except KeyError:
138-
pass
139-
else:
140-
if endpoint_context.login_hint2acrs:
141-
acrs = endpoint_context.login_hint2acrs(_hint)
142-
else:
143-
try:
144-
acrs = [_ith["acr"]]
145-
except KeyError:
146-
pass
147-
148-
if not acrs:
149-
return endpoint_context.authn_broker.default()
150-
151-
for acr in acrs:
152-
res = endpoint_context.authn_broker.pick(acr)
153-
logger.debug("Picked AuthN broker for ACR %s: %s" % (str(acr), str(res)))
154-
if res:
155-
if all:
156-
return res
157-
else:
158-
# Return the first guess by pick.
159-
return res[0]
160-
161-
except KeyError as exc:
162-
logger.debug("An error occurred while picking the authN broker: %s" % str(exc))
119+
if len(endpoint_context.authn_broker) == 1:
120+
return endpoint_context.authn_broker.default()
121+
122+
if "acr_values" in areq:
123+
if not isinstance(areq["acr_values"], list):
124+
areq["acr_values"] = [areq["acr_values"]]
125+
acrs = areq["acr_values"]
126+
127+
else:
128+
try:
129+
acrs = areq["claims"]["id_token"]["acr"]["values"]
130+
except KeyError:
131+
_ith = verified_claim_name("id_token_hint")
132+
if areq.get(_ith):
133+
_ith = areq[verified_claim_name("id_token_hint")]
134+
if _ith.get("acr"):
135+
acrs = [_ith["acr"]]
136+
else:
137+
if areq.get("login_hint") and endpoint_context.login_hint2acrs:
138+
acrs = endpoint_context.login_hint2acrs(areq["login_hint"])
139+
140+
if not acrs:
141+
return endpoint_context.authn_broker.default()
142+
143+
for acr in acrs:
144+
res = endpoint_context.authn_broker.pick(acr)
145+
logger.debug(
146+
f"Picked AuthN broker for ACR {str(acr)}: {str(res)}"
147+
)
148+
if res:
149+
return res if pick_all else res[0]
163150

164151
return None
165152

tests/test_06_authn_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def test_pick_authn_one(self):
179179

180180
def test_pick_authn_all(self):
181181
request = {"acr_values": INTERNETPROTOCOLPASSWORD}
182-
res = pick_auth(self.server.server_get("endpoint_context"), request, all=True)
182+
res = pick_auth(self.server.server_get("endpoint_context"), request, pick_all=True)
183183
assert len(res) == 2
184184

185185

0 commit comments

Comments
 (0)