|
14 | 14 | from cryptojwt.utils import as_bytes
|
15 | 15 | from cryptojwt.utils import b64e
|
16 | 16 |
|
| 17 | +from idpyoidc.exception import ImproperlyConfigured |
17 | 18 | from idpyoidc.exception import ParameterError
|
18 | 19 | from idpyoidc.exception import URIError
|
19 | 20 | from idpyoidc.message import Message
|
|
39 | 40 | from idpyoidc.time_util import utc_time_sans_frac
|
40 | 41 | from idpyoidc.util import rndstr
|
41 | 42 | from idpyoidc.util import split_uri
|
| 43 | +from idpyoidc.util import importer |
42 | 44 |
|
43 | 45 | logger = logging.getLogger(__name__)
|
44 | 46 |
|
@@ -277,6 +279,53 @@ def check_unknown_scopes_policy(request_info, client_id, endpoint_context):
|
277 | 279 | raise UnAuthorizedClientScope()
|
278 | 280 |
|
279 | 281 |
|
| 282 | +def validate_resource_indicators_policy(request, context, **kwargs): |
| 283 | + if "resource" not in request: |
| 284 | + return oauth2.AuthorizationErrorResponse( |
| 285 | + error="invalid_target", |
| 286 | + error_description="Missing resource parameter", |
| 287 | + ) |
| 288 | + |
| 289 | + resource_servers_per_client = kwargs["resource_servers_per_client"] |
| 290 | + client_id = request["client_id"] |
| 291 | + |
| 292 | + if isinstance(resource_servers_per_client, dict) and client_id not in resource_servers_per_client: |
| 293 | + return oauth2.AuthorizationErrorResponse( |
| 294 | + error="invalid_target", |
| 295 | + error_description=f"Resources for client {client_id} not found", |
| 296 | + ) |
| 297 | + |
| 298 | + if isinstance(resource_servers_per_client, dict): |
| 299 | + permitted_resources = [res for res in resource_servers_per_client[client_id]] |
| 300 | + else: |
| 301 | + permitted_resources = [res for res in resource_servers_per_client] |
| 302 | + |
| 303 | + common_resources = list(set(request["resource"]).intersection(set(permitted_resources))) |
| 304 | + if not common_resources: |
| 305 | + return oauth2.AuthorizationErrorResponse( |
| 306 | + error="invalid_target", |
| 307 | + error_description=f"Invalid resource requested by client {client_id}", |
| 308 | + ) |
| 309 | + |
| 310 | + common_resources = [r for r in common_resources if r in context.cdb.keys()] |
| 311 | + if not common_resources: |
| 312 | + return oauth2.AuthorizationErrorResponse( |
| 313 | + error="invalid_target", |
| 314 | + error_description=f"Invalid resource requested by client {client_id}", |
| 315 | + ) |
| 316 | + |
| 317 | + if client_id not in common_resources: |
| 318 | + common_resources.append(client_id) |
| 319 | + |
| 320 | + request["resource"] = common_resources |
| 321 | + |
| 322 | + permitted_scopes = [context.cdb[r]["allowed_scopes"] for r in common_resources] |
| 323 | + permitted_scopes = [r for res in permitted_scopes for r in res] |
| 324 | + scopes = list(set(request.get("scope", [])).intersection(set(permitted_scopes))) |
| 325 | + request["scope"] = scopes |
| 326 | + return request |
| 327 | + |
| 328 | + |
280 | 329 | class Authorization(Endpoint):
|
281 | 330 | request_cls = oauth2.AuthorizationRequest
|
282 | 331 | response_cls = oauth2.AuthorizationResponse
|
@@ -304,6 +353,8 @@ class Authorization(Endpoint):
|
304 | 353 |
|
305 | 354 | def __init__(self, server_get, **kwargs):
|
306 | 355 | Endpoint.__init__(self, server_get, **kwargs)
|
| 356 | + |
| 357 | + self.resource_indicators_config = kwargs.get("resource_indicators", None) |
307 | 358 | self.post_parse_request.append(self._do_request_uri)
|
308 | 359 | self.post_parse_request.append(self._post_parse_request)
|
309 | 360 | self.allowed_request_algorithms = AllowedAlgorithms(ALG_PARAMS)
|
@@ -461,8 +512,45 @@ def _post_parse_request(self, request, client_id, endpoint_context, **kwargs):
|
461 | 512 | else:
|
462 | 513 | request["redirect_uri"] = redirect_uri
|
463 | 514 |
|
| 515 | + if ("resource_indicators" in _cinfo |
| 516 | + and "authorization_code" in _cinfo["resource_indicators"]): |
| 517 | + resource_indicators_config = _cinfo["resource_indicators"]["authorization_code"] |
| 518 | + else: |
| 519 | + resource_indicators_config = self.resource_indicators_config |
| 520 | + |
| 521 | + if resource_indicators_config is not None: |
| 522 | + if "policy" not in resource_indicators_config: |
| 523 | + policy = {"policy": {"callable": validate_resource_indicators_policy}} |
| 524 | + resource_indicators_config.update(policy) |
| 525 | + request = self._enforce_resource_indicators_policy(request, resource_indicators_config) |
| 526 | + |
464 | 527 | return request
|
465 | 528 |
|
| 529 | + def _enforce_resource_indicators_policy(self, request, config): |
| 530 | + _context = self.server_get("endpoint_context") |
| 531 | + |
| 532 | + policy = config["policy"] |
| 533 | + callable = policy["callable"] |
| 534 | + kwargs = policy.get("kwargs", {}) |
| 535 | + |
| 536 | + if kwargs.get("resource_servers_per_client", None) is None: |
| 537 | + kwargs["resource_servers_per_client"] = { |
| 538 | + request["client_id"]: request["client_id"] |
| 539 | + } |
| 540 | + |
| 541 | + if isinstance(callable, str): |
| 542 | + try: |
| 543 | + fn = importer(callable) |
| 544 | + except Exception: |
| 545 | + raise ImproperlyConfigured(f"Error importing {callable} policy callable") |
| 546 | + else: |
| 547 | + fn = callable |
| 548 | + try: |
| 549 | + return fn(request, context=_context, **kwargs) |
| 550 | + except Exception as e: |
| 551 | + logger.error(f"Error while executing the {fn} policy callable: {e}") |
| 552 | + return self.error_cls(error="server_error", error_description="Internal server error") |
| 553 | + |
466 | 554 | def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs):
|
467 | 555 | _context = self.server_get("endpoint_context")
|
468 | 556 | auth_id = kwargs.get("auth_method_id")
|
@@ -750,10 +838,17 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict
|
750 | 838 | _mngr = _context.session_manager
|
751 | 839 | _sinfo = _mngr.get_session_info(sid, grant=True)
|
752 | 840 |
|
| 841 | + scope = [] |
| 842 | + resource_scopes = [] |
753 | 843 | if request.get("scope"):
|
754 |
| - aresp["scope"] = _context.scopes_handler.filter_scopes( |
755 |
| - request["scope"], _sinfo["client_id"] |
756 |
| - ) |
| 844 | + scope = request.get("scope") |
| 845 | + if request.get("resource"): |
| 846 | + resource_scopes = [_context.cdb[s]["scope"] for s in request.get("resource") if s in _context.cdb.keys() and _context.cdb[s].get("scope")] |
| 847 | + resource_scopes = [item for sublist in resource_scopes for item in sublist] |
| 848 | + |
| 849 | + aresp["scope"] = _context.scopes_handler.filter_scopes( |
| 850 | + list(set(scope+resource_scopes)), _sinfo["client_id"] |
| 851 | + ) |
757 | 852 |
|
758 | 853 | rtype = set(request["response_type"][:])
|
759 | 854 | handled_response_type = []
|
|
0 commit comments