@@ -31,12 +31,13 @@ class ClaimsInterface:
3131 def __init__ (self , server_get ):
3232 self .server_get = server_get
3333
34- def authorization_request_claims (self ,
35- session_id : str ,
36- claims_release_point : Optional [str ] = "" ) -> dict :
37- _grant = self .server_get ("endpoint_context" ).session_manager .get_grant (session_id )
38- if _grant .authorization_request and "claims" in _grant .authorization_request :
39- return _grant .authorization_request ["claims" ].get (claims_release_point , {})
34+ def authorization_request_claims (
35+ self ,
36+ authorization_request : dict ,
37+ claims_release_point : Optional [str ] = "" ,
38+ ) -> dict :
39+ if authorization_request and "claims" in authorization_request :
40+ return authorization_request ["claims" ].get (claims_release_point , {})
4041
4142 return {}
4243
@@ -70,16 +71,13 @@ def _get_module(self, usage, endpoint_context):
7071
7172 return module
7273
73- def get_claims (self , session_id : str , scopes : str , claims_release_point : str ) -> dict :
74- """
75-
76- :param session_id: Session identifier
77- :param scopes: Scopes
78- :param claims_release_point: Where to release the claims. One of
79- "userinfo"/"id_token"/"introspection"/"access_token"
80- :return: Claims specification as a dictionary.
81- """
82-
74+ def get_claims_from_request (
75+ self ,
76+ auth_req : dict ,
77+ claims_release_point : str ,
78+ scopes : str = None ,
79+ client_id : str = None ,
80+ ):
8381 _context = self .server_get ("endpoint_context" )
8482 # which endpoint module configuration to get the base claims from
8583 module = self ._get_module (claims_release_point , _context )
@@ -89,7 +87,8 @@ def get_claims(self, session_id: str, scopes: str, claims_release_point: str) ->
8987 else :
9088 return {}
9189
92- user_id , client_id , grant_id = _context .session_manager .decrypt_session_id (session_id )
90+ if not client_id :
91+ client_id = auth_req .get ("client_id" )
9392
9493 # Can there be per client specification of which claims to use.
9594 if module .kwargs .get ("enable_claims_per_client" ):
@@ -112,28 +111,76 @@ def get_claims(self, session_id: str, scopes: str, claims_release_point: str) ->
112111 add_claims_by_scope = module .kwargs .get ("add_claims_by_scope" )
113112
114113 if add_claims_by_scope :
114+ if scopes is None :
115+ scopes = auth_req .get ("scopes" )
115116 if scopes :
116117 _claims = _context .scopes_handler .scopes_to_claims (scopes , client_id = client_id )
117118 claims .update (_claims )
118119
119120 # Bring in claims specification from the authorization request
120121 # This only goes for ID Token and user info
121- request_claims = self .authorization_request_claims (session_id = session_id ,
122- claims_release_point = claims_release_point )
122+ request_claims = self .authorization_request_claims (
123+ authorization_request = auth_req ,
124+ claims_release_point = claims_release_point
125+ )
123126
124127 # This will add claims that has not be added before and
125- # set filters on those claims that also appears in one of the sources above
128+ # set filters on those claims that also appears in one of the sources
129+ # above
126130 if request_claims :
127131 claims .update (request_claims )
128132
129133 return claims
130134
131- def get_claims_all_usage (self , session_id : str , scopes : str ) -> dict :
135+ def get_claims (self , session_id : str , scopes : str , claims_release_point : str ) -> dict :
136+ """
137+
138+ :param session_id: Session identifier
139+ :param scopes: Scopes
140+ :param claims_release_point: Where to release the claims. One of
141+ "userinfo"/"id_token"/"introspection"/"access_token"
142+ :return: Claims specification as a dictionary.
143+ """
144+ _context = self .server_get ("endpoint_context" )
145+ session_info = _context .session_manager .get_session_info (
146+ session_id , grant = True
147+ )
148+ client_id = session_info ["client_id" ]
149+ grant = session_info ["grant" ]
150+
151+ if grant .authorization_request :
152+ auth_req = grant .authorization_request
153+ else :
154+ auth_req = {}
155+ claims = self .get_claims_from_request (
156+ auth_req = auth_req ,
157+ claims_release_point = claims_release_point ,
158+ scopes = scopes ,
159+ client_id = client_id ,
160+ )
161+
162+ return claims
163+
164+ def get_claims_all_usage_from_request (
165+ self , auth_req : dict , scopes : str
166+ ) -> dict :
132167 _claims = {}
133168 for usage in self .claims_release_points :
134- _claims [usage ] = self .get_claims (session_id , scopes , usage )
169+ _claims [usage ] = self .get_claims_from_request (
170+ auth_req , usage , scopes
171+ )
135172 return _claims
136173
174+ def get_claims_all_usage (self , session_id : str , scopes : str ) -> dict :
175+ grant = self .server_get (
176+ "endpoint_context"
177+ ).session_manager .get_grant (session_id )
178+ if grant .authorization_request :
179+ auth_req = grant .authorization_request
180+ else :
181+ auth_req = {}
182+ return self .get_claims_all_usage_from_request (auth_req , scopes )
183+
137184 def get_user_claims (self , user_id : str , claims_restriction : dict ) -> dict :
138185 """
139186
0 commit comments