Skip to content
This repository was archived by the owner on Jun 23, 2023. It is now read-only.

Commit fbd1c62

Browse files
committed
Add get_claims_from_request
1 parent f091c20 commit fbd1c62

4 files changed

Lines changed: 87 additions & 44 deletions

File tree

src/oidcop/session/claims.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ class ClaimsInterface:
3131
def __init__(self, server_get):
3232
self.server_get = server_get
3333

34-
def authorization_request_claims(self,
35-
session_id: str,
36-
claims_release_point: Optional[str] = "") -> dict:
37-
_grant = self.server_get("endpoint_context").session_manager.get_grant(session_id)
38-
if _grant.authorization_request and "claims" in _grant.authorization_request:
39-
return _grant.authorization_request["claims"].get(claims_release_point, {})
34+
def authorization_request_claims(
35+
self,
36+
authorization_request: dict,
37+
claims_release_point: Optional[str] = "",
38+
) -> dict:
39+
if authorization_request and "claims" in authorization_request:
40+
return authorization_request["claims"].get(claims_release_point, {})
4041

4142
return {}
4243

@@ -70,16 +71,13 @@ def _get_module(self, usage, endpoint_context):
7071

7172
return module
7273

73-
def get_claims(self, session_id: str, scopes: str, claims_release_point: str) -> dict:
74-
"""
75-
76-
:param session_id: Session identifier
77-
:param scopes: Scopes
78-
:param claims_release_point: Where to release the claims. One of
79-
"userinfo"/"id_token"/"introspection"/"access_token"
80-
:return: Claims specification as a dictionary.
81-
"""
82-
74+
def get_claims_from_request(
75+
self,
76+
auth_req: dict,
77+
claims_release_point: str,
78+
scopes: str = None,
79+
client_id: str = None,
80+
):
8381
_context = self.server_get("endpoint_context")
8482
# which endpoint module configuration to get the base claims from
8583
module = self._get_module(claims_release_point, _context)
@@ -89,7 +87,8 @@ def get_claims(self, session_id: str, scopes: str, claims_release_point: str) ->
8987
else:
9088
return {}
9189

92-
user_id, client_id, grant_id = _context.session_manager.decrypt_session_id(session_id)
90+
if not client_id:
91+
client_id = auth_req.get("client_id")
9392

9493
# Can there be per client specification of which claims to use.
9594
if module.kwargs.get("enable_claims_per_client"):
@@ -112,28 +111,76 @@ def get_claims(self, session_id: str, scopes: str, claims_release_point: str) ->
112111
add_claims_by_scope = module.kwargs.get("add_claims_by_scope")
113112

114113
if add_claims_by_scope:
114+
if scopes is None:
115+
scopes = auth_req.get("scopes")
115116
if scopes:
116117
_claims = _context.scopes_handler.scopes_to_claims(scopes, client_id=client_id)
117118
claims.update(_claims)
118119

119120
# Bring in claims specification from the authorization request
120121
# This only goes for ID Token and user info
121-
request_claims = self.authorization_request_claims(session_id=session_id,
122-
claims_release_point=claims_release_point)
122+
request_claims = self.authorization_request_claims(
123+
authorization_request=auth_req,
124+
claims_release_point=claims_release_point
125+
)
123126

124127
# This will add claims that has not be added before and
125-
# set filters on those claims that also appears in one of the sources above
128+
# set filters on those claims that also appears in one of the sources
129+
# above
126130
if request_claims:
127131
claims.update(request_claims)
128132

129133
return claims
130134

131-
def get_claims_all_usage(self, session_id: str, scopes: str) -> dict:
135+
def get_claims(self, session_id: str, scopes: str, claims_release_point: str) -> dict:
136+
"""
137+
138+
:param session_id: Session identifier
139+
:param scopes: Scopes
140+
:param claims_release_point: Where to release the claims. One of
141+
"userinfo"/"id_token"/"introspection"/"access_token"
142+
:return: Claims specification as a dictionary.
143+
"""
144+
_context = self.server_get("endpoint_context")
145+
session_info = _context.session_manager.get_session_info(
146+
session_id, grant=True
147+
)
148+
client_id = session_info["client_id"]
149+
grant = session_info["grant"]
150+
151+
if grant.authorization_request:
152+
auth_req = grant.authorization_request
153+
else:
154+
auth_req = {}
155+
claims = self.get_claims_from_request(
156+
auth_req=auth_req,
157+
claims_release_point=claims_release_point,
158+
scopes=scopes,
159+
client_id=client_id,
160+
)
161+
162+
return claims
163+
164+
def get_claims_all_usage_from_request(
165+
self, auth_req: dict, scopes: str
166+
) -> dict:
132167
_claims = {}
133168
for usage in self.claims_release_points:
134-
_claims[usage] = self.get_claims(session_id, scopes, usage)
169+
_claims[usage] = self.get_claims_from_request(
170+
auth_req, usage, scopes
171+
)
135172
return _claims
136173

174+
def get_claims_all_usage(self, session_id: str, scopes: str) -> dict:
175+
grant = self.server_get(
176+
"endpoint_context"
177+
).session_manager.get_grant(session_id)
178+
if grant.authorization_request:
179+
auth_req = grant.authorization_request
180+
else:
181+
auth_req = {}
182+
return self.get_claims_all_usage_from_request(auth_req, scopes)
183+
137184
def get_user_claims(self, user_id: str, claims_restriction: dict) -> dict:
138185
"""
139186

tests/test_01_claims.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -141,33 +141,24 @@ def _create_session(self, auth_req, sub_type="public", sector_identifier=""):
141141
)
142142

143143
def test_authorization_request_id_token_claims(self):
144-
session_id = self._create_session(AREQ)
145-
146-
claims = self.claims_interface.authorization_request_claims(session_id, "id_token")
144+
claims = self.claims_interface.authorization_request_claims(AREQ, "id_token")
147145
assert claims == {}
148146

149147
def test_authorization_request_id_token_claims_2(self):
150-
session_id = self._create_session(AREQ_2)
151-
claims = self.claims_interface.authorization_request_claims(session_id, "id_token")
148+
claims = self.claims_interface.authorization_request_claims(AREQ_2, "id_token")
152149
assert claims
153150
assert set(claims.keys()) == {"nickname"}
154151

155152
def test_authorization_request_userinfo_claims(self):
156-
session_id = self._create_session(AREQ)
157-
158-
claims = self.claims_interface.authorization_request_claims(session_id, "userinfo")
153+
claims = self.claims_interface.authorization_request_claims(AREQ, "userinfo")
159154
assert claims == {}
160155

161156
def test_authorization_request_userinfo_claims_2(self):
162-
session_id = self._create_session(AREQ_2)
163-
164-
claims = self.claims_interface.authorization_request_claims(session_id, "userinfo")
157+
claims = self.claims_interface.authorization_request_claims(AREQ_2, "userinfo")
165158
assert claims == {}
166159

167160
def test_authorization_request_userinfo_claims_3(self):
168-
session_id = self._create_session(AREQ_3)
169-
170-
claims = self.claims_interface.authorization_request_claims(session_id, "userinfo")
161+
claims = self.claims_interface.authorization_request_claims(AREQ_3, "userinfo")
171162
assert set(claims.keys()) == {"name", "email", "email_verified"}
172163

173164
@pytest.mark.parametrize("usage", ["id_token", "userinfo", "introspection", "token"])

tests/test_06_session_manager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,12 @@ def _mint_token(self, token_class, grant, session_id, based_on=None):
200200
)
201201

202202
def test_grant(self):
203-
grant = Grant()
203+
sid = self._create_session(AUTH_REQ)
204+
grant = self.session_manager.get_grant(sid)
204205
assert grant.issued_token == []
205206
assert grant.is_active() is True
206207

207-
code = self._mint_token("authorization_code", grant, self.dummy_session_id)
208+
code = self._mint_token("authorization_code", grant, sid)
208209
assert isinstance(code, AuthorizationCode)
209210
assert code.is_active()
210211
assert len(grant.issued_token) == 1

tests/test_24_oidc_authorization_endpoint.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -825,14 +825,18 @@ def test_verify_response_type(self):
825825

826826
@pytest.mark.parametrize("exp_in", [360, "360", 0])
827827
def test_mint_token_exp_at(self, exp_in):
828-
grant = Grant()
829-
grant.usage_rules = {"authorization_code": {"expires_in": exp_in}}
830-
831-
DUMMY_SESSION_ID = self.session_manager.encrypted_session_id(
832-
"user_id", "client_id", "grant.id"
828+
request = AuthorizationRequest(
829+
client_id="client_1",
830+
response_type=["code"],
831+
redirect_uri="https://example.com/cb",
832+
state="state",
833+
scope="openid",
833834
)
835+
sid = self._create_session(request)
836+
grant = self.session_manager.get_grant(sid)
837+
grant.usage_rules = {"authorization_code": {"expires_in": exp_in}}
834838

835-
code = self.endpoint.mint_token("authorization_code", grant, DUMMY_SESSION_ID)
839+
code = self.endpoint.mint_token("authorization_code", grant, sid)
836840
if exp_in in [360, "360"]:
837841
assert code.expires_at
838842
else:

0 commit comments

Comments
 (0)