diff --git a/msal/application.py b/msal/application.py index 3fc69461..f8ed857a 100644 --- a/msal/application.py +++ b/msal/application.py @@ -20,7 +20,7 @@ from .mex import send_request as mex_send_request from .wstrust_request import send_request as wst_send_request from .wstrust_response import * -from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER, _compute_ext_cache_key +from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER, _compute_ext_cache_key, _parse_claims_or_raise, _merge_claims import msal.telemetry from .region import _detect_region, _validate_region from .throttled_http_client import ThrottledHttpClient @@ -64,6 +64,31 @@ def _merge_claims_challenge_and_capabilities(capabilities, claims_challenge): return json.dumps(claims_dict) +def _stash_client_claims(forwarded_client_claims, data): + """Validate ``forwarded_client_claims`` and stash it into the request ``data``. + + ``forwarded_client_claims`` carries *client-originated* claims supplied by the + caller. The raw value is + stored in ``data`` (under the internal ``client_claims`` key) so that it + (a) contributes to the extended cache key -- isolating cache entries by + claims value -- and (b) is stripped from the request body by the oauth2 + layer (it reaches the wire only after being merged into the standard OAuth + ``claims`` parameter). ``data`` is mutated in place. + + Unlike ``claims_challenge`` (server-issued, which bypasses the cache), + ``forwarded_client_claims`` tokens are cached and keyed on the claims value. + A no-op when ``forwarded_client_claims`` is ``None``. 
+ """ + if forwarded_client_claims is None: + return + if not isinstance(forwarded_client_claims, str): + raise ValueError( + "forwarded_client_claims must be a string, got {}".format( + type(forwarded_client_claims).__name__)) + _parse_claims_or_raise(forwarded_client_claims) # Fail fast on malformed JSON + data["client_claims"] = forwarded_client_claims + + def _str2bytes(raw): # A conversion based on duck-typing rather than six.text_type try: @@ -424,6 +449,15 @@ def get_client_assertion(): "jti": a_random_uuid } + .. note:: + + This *constructor* ``client_claims`` (a ``dict`` signed into the + client-assertion JWT) is distinct from the per-request + ``forwarded_client_claims`` parameter (a JSON string of client-originated + claims forwarded in the token request) accepted by + ``acquire_token_for_client``, ``acquire_token_on_behalf_of``, + ``acquire_token_silent``, and the other token-acquisition methods. + :param str authority: A URL that identifies a token authority. It should be of the format ``https://login.microsoftonline.com/your_tenant`` @@ -1237,6 +1271,7 @@ def acquire_token_by_authorization_code( # values MUST be identical. nonce=None, claims_challenge=None, + forwarded_client_claims=None, **kwargs): """The second half of the Authorization Code Grant. @@ -1267,6 +1302,18 @@ def acquire_token_by_authorization_code( in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. + :param str forwarded_client_claims: + Optional. A JSON string of *client-originated* claims to include in + the token request. Unlike ``claims_challenge`` (server-issued, which + bypasses the cache), tokens acquired with ``forwarded_client_claims`` + **are cached** and keyed on the claims value. 
Send the *same* value on + every request that should share the cached token; omitting or changing + it routes to a different cache entry (a cache miss), so use stable, + non-dynamic values. The value is merged into the standard OAuth + ``claims`` request parameter sent on the wire. + + Not to be confused with the constructor ``client_claims`` parameter + (a ``dict`` of extra claims signed into the client-assertion JWT). :return: A dict representing the json response from Microsoft Entra: @@ -1286,14 +1333,18 @@ def acquire_token_by_authorization_code( with warnings.catch_warnings(record=True): telemetry_context = self._build_telemetry_context( self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID) + _data = kwargs.pop("data", {}) + _stash_client_claims(forwarded_client_claims, _data) response = _clean_up(self.client.obtain_token_by_authorization_code( code, redirect_uri=redirect_uri, scope=self._decorate_scope(scopes), headers=telemetry_context.generate_headers(), data=dict( - kwargs.pop("data", {}), - claims=_merge_claims_challenge_and_capabilities( - self._client_capabilities, claims_challenge)), + _data, + claims=_merge_claims( + _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + _data.get("client_claims"))), nonce=nonce, **kwargs)) if "access_token" in response: @@ -1474,6 +1525,7 @@ def acquire_token_silent( authority=None, # See get_authorization_request_url() force_refresh=False, # type: Optional[boolean] claims_challenge=None, + forwarded_client_claims=None, auth_scheme=None, **kwargs): """Acquire an access token for given account, without user interaction. 
@@ -1493,6 +1545,9 @@ def acquire_token_silent( """ if not account: return None # A backward-compatible NO-OP to drop the account=None usage + if forwarded_client_claims is not None: + kwargs["data"] = kwargs.get("data") or {} + _stash_client_claims(forwarded_client_claims, kwargs["data"]) result = _clean_up(self._acquire_token_silent_with_error( scopes, account, authority=authority, force_refresh=force_refresh, claims_challenge=claims_challenge, auth_scheme=auth_scheme, **kwargs)) @@ -1505,6 +1560,7 @@ def acquire_token_silent_with_error( authority=None, # See get_authorization_request_url() force_refresh=False, # type: Optional[boolean] claims_challenge=None, + forwarded_client_claims=None, auth_scheme=None, **kwargs): """Acquire an access token for given account, without user interaction. @@ -1532,6 +1588,15 @@ def acquire_token_silent_with_error( in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. + :param str forwarded_client_claims: + Optional. A JSON string of *client-originated* claims to include when + a cached token is missing and a network request is made. Tokens are + **cached** and keyed on the claims value (different values yield + separate cache entries), so send the *same* value on every call that + should reuse the cached token, and use stable, non-dynamic values. + + Not to be confused with the constructor ``client_claims`` parameter + (a ``dict`` of extra claims signed into the client-assertion JWT). :param object auth_scheme: You can provide an ``msal.auth_scheme.PopAuthScheme`` object so that MSAL will get a Proof-of-Possession (POP) token for you. 
@@ -1547,6 +1612,9 @@ def acquire_token_silent_with_error( """ if not account: return None # A backward-compatible NO-OP to drop the account=None usage + if forwarded_client_claims is not None: + kwargs["data"] = kwargs.get("data") or {} + _stash_client_claims(forwarded_client_claims, kwargs["data"]) return _clean_up(self._acquire_token_silent_with_error( scopes, account, authority=authority, force_refresh=force_refresh, claims_challenge=claims_challenge, auth_scheme=auth_scheme, **kwargs)) @@ -1809,6 +1877,9 @@ def _acquire_token_silent_by_finding_specific_refresh_token( telemetry_context = self._build_telemetry_context( self.ACQUIRE_TOKEN_SILENT_ID, correlation_id=correlation_id, refresh_reason=refresh_reason) + # Pop "data" once (rather than per-iteration) so client_claims and any + # other data fields apply consistently across all candidate RTs. + _data = kwargs.pop("data", {}) for entry in sorted( # Since unfit RTs would not be aggressively removed, # we start from newer RTs which are more likely fit. matches, @@ -1832,9 +1903,11 @@ def _acquire_token_silent_by_finding_specific_refresh_token( scope=scopes, headers=headers, data=dict( - kwargs.pop("data", {}), - claims=_merge_claims_challenge_and_capabilities( - self._client_capabilities, claims_challenge)), + _data, + claims=_merge_claims( + _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + _data.get("client_claims"))), **kwargs) telemetry_context.update_telemetry(response) if "error" not in response: @@ -2494,7 +2567,7 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app except that ``allow_broker`` parameter shall remain ``None``. """ - def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, **kwargs): + def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, forwarded_client_claims=None, **kwargs): """Acquires token for the current confidential client, not for an end user. 
Since MSAL Python 1.23, it will automatically look for token from cache, @@ -2518,6 +2591,21 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, scopes=["api://resource/.default"], fmi_path="SomeFmiPath/FmiCredentialPath", ) + :param str forwarded_client_claims: + Optional. A JSON string containing *client-originated* claims to + include in the token request. + + Unlike ``claims_challenge`` (which carries *server-issued* claims + challenges and bypasses the cache), tokens acquired with + ``forwarded_client_claims`` **are cached**, and the cache entry is keyed on the + claims value. Send the *same* value on every request that should share + the cached token; different values produce separate cache entries, so + use stable, non-dynamic values to avoid unbounded cache growth. The + value is merged into the standard OAuth ``claims`` request parameter + sent on the wire. + + Not to be confused with the constructor ``client_claims`` parameter + (a ``dict`` of extra claims signed into the client-assertion JWT). :return: A dict representing the json response from Microsoft Entra: - A successful response would contain "access_token" key, @@ -2533,6 +2621,13 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, "fmi_path must be a string, got {}".format(type(fmi_path).__name__)) kwargs["data"] = kwargs.get("data", {}) kwargs["data"]["fmi_path"] = fmi_path + if forwarded_client_claims is not None: + # Carry it in the request data so it contributes to the extended + # cache key (different claims => separate cache entries). It is + # merged into the "claims" body parameter in _acquire_token_for_client + # and stripped from the wire body by the oauth2 layer. 
+ kwargs["data"] = kwargs.get("data", {}) + _stash_client_claims(forwarded_client_claims, kwargs["data"]) return _clean_up(self._acquire_token_silent_with_error( scopes, None, claims_challenge=claims_challenge, **kwargs)) @@ -2552,13 +2647,20 @@ def _acquire_token_for_client( telemetry_context = self._build_telemetry_context( self.ACQUIRE_TOKEN_FOR_CLIENT_ID, refresh_reason=refresh_reason) client = self._regional_client or self.client + request_data = kwargs.pop("data", {}) + claims = _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge) + # Client-originated claims (set via forwarded_client_claims=) are merged into the + # same OAuth "claims" parameter and sent on the wire. The raw + # "client_claims" entry stays in request_data so it keys the cache; the + # oauth2 layer drops it from the actual request body. + client_claims = request_data.get("client_claims") + if client_claims: + claims = _merge_claims(claims, client_claims) response = client.obtain_token_for_client( scope=scopes, # This grant flow requires no scope decoration headers=telemetry_context.generate_headers(), - data=dict( - kwargs.pop("data", {}), - claims=_merge_claims_challenge_and_capabilities( - self._client_capabilities, claims_challenge)), + data=dict(request_data, claims=claims), **kwargs) telemetry_context.update_telemetry(response) return response @@ -2577,7 +2679,7 @@ def remove_tokens_for_client(self): self.token_cache.remove_at(at) # acquire_token_for_client() obtains no RTs, so we have no RT to remove - def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs): + def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, forwarded_client_claims=None, **kwargs): """Acquires token using on-behalf-of (OBO) flow. 
The current app is a middle-tier service which was called with a token @@ -2597,6 +2699,18 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. + :param str forwarded_client_claims: + Optional. A JSON string of *client-originated* claims to include in + the token request. Unlike ``claims_challenge`` (server-issued, which + bypasses the cache), tokens acquired with ``forwarded_client_claims`` + **are cached** and keyed on the claims value. Send the *same* value on + every request that should share the cached token; omitting or changing + it routes to a different cache entry (a cache miss), so use stable, + non-dynamic values. The value is merged into the standard OAuth + ``claims`` request parameter sent on the wire. + + Not to be confused with the constructor ``client_claims`` parameter + (a ``dict`` of extra claims signed into the client-assertion JWT). :return: A dict representing the json response from Microsoft Entra: @@ -2605,6 +2719,8 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No """ telemetry_context = self._build_telemetry_context( self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID) + _data = kwargs.pop("data", {}) + _stash_client_claims(forwarded_client_claims, _data) # The implementation is NOT based on Token Exchange (RFC 8693) response = _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521 user_assertion, @@ -2616,10 +2732,12 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No # so that the calling app could use id_token_claims to implement # their own cache mapping, which is likely needed in web apps. 
data=dict( - kwargs.pop("data", {}), + _data, requested_token_use="on_behalf_of", - claims=_merge_claims_challenge_and_capabilities( - self._client_capabilities, claims_challenge)), + claims=_merge_claims( + _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + _data.get("client_claims"))), headers=telemetry_context.generate_headers(), # TBD: Expose a login_hint (or ccs_routing_hint) param for web app **kwargs)) @@ -2630,7 +2748,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No def acquire_token_by_user_federated_identity_credential( self, scopes, assertion, username=None, user_object_id=None, - claims_challenge=None, **kwargs): + claims_challenge=None, forwarded_client_claims=None, **kwargs): """Acquires a user-scoped token using the ``user_fic`` grant type. This method exchanges a federated identity credential (typically an @@ -2653,6 +2771,18 @@ def acquire_token_by_user_federated_identity_credential( in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. + :param str forwarded_client_claims: + Optional. A JSON string of *client-originated* claims to include in + the token request. Unlike ``claims_challenge`` (server-issued, which + bypasses the cache), tokens acquired with ``forwarded_client_claims`` + **are cached** and keyed on the claims value. Send the *same* value on + every request that should share the cached token; omitting or changing + it routes to a different cache entry (a cache miss), so use stable, + non-dynamic values. The value is merged into the standard OAuth + ``claims`` request parameter sent on the wire. + + Not to be confused with the constructor ``client_claims`` parameter + (a ``dict`` of extra claims signed into the client-assertion JWT). 
:return: A dict representing the json response from Microsoft Entra: @@ -2677,6 +2807,8 @@ def acquire_token_by_user_federated_identity_credential( elif user_object_id: headers["X-AnchorMailbox"] = "Oid:{}@{}".format( user_object_id, self.authority.tenant) + _data = kwargs.pop("data", {}) + _stash_client_claims(forwarded_client_claims, _data) response = _clean_up(self.client.obtain_token_by_user_fic( scope=self._decorate_scope(scopes), assertion=assertion, @@ -2684,9 +2816,11 @@ def acquire_token_by_user_federated_identity_credential( user_object_id=user_object_id, headers=headers, data=dict( - kwargs.pop("data", {}), - claims=_merge_claims_challenge_and_capabilities( - self._client_capabilities, claims_challenge)), + _data, + claims=_merge_claims( + _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + _data.get("client_claims"))), **kwargs)) if "access_token" in response: response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP diff --git a/msal/managed_identity.py b/msal/managed_identity.py index b2fc446c..de10e535 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -12,7 +12,7 @@ from urllib.parse import urlparse # Python 3+ from collections import UserDict # Python 3+ from typing import List, Optional, Union # Needed in Python 3.7 & 3.8 -from .token_cache import TokenCache +from .token_cache import TokenCache, _compute_ext_cache_key, _parse_claims_or_raise from .individual_cache import _IndividualCache as IndividualCache from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser from .cloudshell import _is_running_in_cloud_shell @@ -26,6 +26,12 @@ class ManagedIdentityError(ValueError): pass +_CLIENT_CLAIMS_UNSUPPORTED_SOURCE = ( + "forwarded_client_claims is only supported for the IMDS (Azure VM) managed identity " + "source. 
The detected source ({source}) does not support forwarding " + "client-originated claims.") + + class ManagedIdentity(UserDict): """Feed an instance of this class to :class:`msal.ManagedIdentityClient` to acquire token for the specified managed identity. @@ -261,6 +267,7 @@ def acquire_token_for_client( *, resource: str, # If/when we support scope, resource will become optional claims_challenge: Optional[str] = None, + forwarded_client_claims: Optional[str] = None, ): """Acquire token for the managed identity. @@ -280,6 +287,21 @@ def acquire_token_for_client( even if the app developer did not opt in for the "CP1" client capability. Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token. + :param forwarded_client_claims: + Optional. + A string representation of a JSON object containing + *client-originated* claims to forward to the identity endpoint. + + Unlike ``claims_challenge`` (server-issued, which bypasses the cache), + tokens acquired with ``forwarded_client_claims`` **are cached**, and the cache + entry is keyed on the claims value. Send the *same* value on every + request that should share the cached token; different values produce + separate cache entries, so use stable, non-dynamic values to avoid + unbounded cache growth. + + Only the IMDS (Azure VM) managed identity source supports this + parameter; other sources raise an error. + .. 
note:: Known issue: When an Azure VM has only one user-assigned managed identity, @@ -294,6 +316,20 @@ def acquire_token_for_client( client_id_in_cache = self._managed_identity.get( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") now = time.time() + if forwarded_client_claims is not None: + if not isinstance(forwarded_client_claims, str): + raise ValueError( + "forwarded_client_claims must be a string, got {}".format( + type(forwarded_client_claims).__name__)) + _parse_claims_or_raise(forwarded_client_claims) # Fail fast on malformed JSON + # Reject unsupported sources before any cache read, so an unsupported + # source never returns a cached client-claims token. + _raise_if_claims_unsupported_source() + # Client-originated claims isolate the cache: a distinct claims value gets + # a distinct cache entry. (Server-issued claims_challenge, by contrast, + # bypasses the cache and is keyed normally.) + ext_cache_key = _compute_ext_cache_key( + {"client_claims": forwarded_client_claims}) if forwarded_client_claims else None if True: # Attempt cache search even if receiving claims_challenge, # because we want to locate the existing token (if any) and refresh it matches = self._token_cache.search( @@ -304,6 +340,7 @@ def acquire_token_for_client( environment=self.__instance, realm=self._tenant, home_account_id=None, + **({"ext_cache_key": ext_cache_key} if ext_cache_key else {}), ), ) for entry in matches: @@ -334,6 +371,7 @@ def acquire_token_for_client( access_token_to_refresh.encode("utf-8")).hexdigest() if access_token_to_refresh else None, client_capabilities=self._client_capabilities, + client_claims=forwarded_client_claims, ) if "access_token" in result: expires_in = result.get("expires_in", 3600) @@ -346,7 +384,7 @@ def acquire_token_for_client( self.__instance, self._tenant), response=result, params={}, - data={}, + data={"client_claims": forwarded_client_claims} if forwarded_client_claims else {}, )) if "refresh_in" in result: result["refresh_on"] = 
int(now + result["refresh_in"]) @@ -409,15 +447,42 @@ def get_managed_identity_source(): return DEFAULT_TO_VM +# Managed-identity sources that cannot forward client-originated claims. Keep in +# sync with the per-source guards inside _obtain_token (the backstop). Cloud Shell +# is intentionally absent: it falls through to the Azure VM / IMDS path, which +# does support claims. +_CLIENT_CLAIMS_UNSUPPORTED_SOURCES = { + SERVICE_FABRIC: "Service Fabric", + APP_SERVICE: "App Service", + MACHINE_LEARNING: "Machine Learning", + AZURE_ARC: "Azure Arc", +} + + +def _raise_if_claims_unsupported_source(): + """Fail fast -- before any cache read -- when the detected managed-identity + source cannot forward client-originated claims. ``_obtain_token`` enforces the + same rule per source as a backstop, but validating up front avoids a cache + lookup (and returning a cached token) for an unsupported source.""" + name = _CLIENT_CLAIMS_UNSUPPORTED_SOURCES.get(get_managed_identity_source()) + if name: + raise ManagedIdentityError( + _CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source=name)) + + def _obtain_token( http_client, managed_identity, resource, *, access_token_sha256_to_refresh: Optional[str] = None, client_capabilities: Optional[List[str]] = None, + client_claims: Optional[str] = None, ): if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ and "IDENTITY_SERVER_THUMBPRINT" in os.environ ): + if client_claims: + raise ManagedIdentityError( + _CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Service Fabric")) if managed_identity: logger.debug( "Ignoring managed_identity parameter. 
" @@ -434,6 +499,9 @@ def _obtain_token( client_capabilities=client_capabilities, ) if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ: + if client_claims: + raise ManagedIdentityError( + _CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="App Service")) return _obtain_token_on_app_service( http_client, os.environ["IDENTITY_ENDPOINT"], @@ -442,6 +510,9 @@ def _obtain_token( resource, ) if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ: + if client_claims: + raise ManagedIdentityError( + _CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Machine Learning")) # Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py return _obtain_token_on_machine_learning( http_client, @@ -452,6 +523,9 @@ def _obtain_token( ) arc_endpoint = _get_arc_endpoint() if arc_endpoint: + if client_claims: + raise ManagedIdentityError( + _CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Azure Arc")) if ManagedIdentity.is_user_assigned(managed_identity): raise ManagedIdentityError( # Note: Azure Identity for Python raised exception too "Invalid managed_identity parameter. 
" @@ -459,7 +533,8 @@ def _obtain_token( "See also " "https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service") return _obtain_token_on_arc(http_client, arc_endpoint, resource) - return _obtain_token_on_azure_vm(http_client, managed_identity, resource) + return _obtain_token_on_azure_vm( + http_client, managed_identity, resource, client_claims=client_claims) def _adjust_param(params, managed_identity, types_mapping=None): @@ -469,7 +544,8 @@ def _adjust_param(params, managed_identity, types_mapping=None): if id_name: params[id_name] = managed_identity[ManagedIdentity.ID] -def _obtain_token_on_azure_vm(http_client, managed_identity, resource): + +def _obtain_token_on_azure_vm(http_client, managed_identity, resource, client_claims=None): # Based on https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http logger.debug("Obtaining token via managed identity on Azure VM") params = { @@ -477,6 +553,10 @@ def _obtain_token_on_azure_vm(http_client, managed_identity, resource): "resource": resource, } _adjust_param(params, managed_identity) + if client_claims: + # Forward client-originated claims as-is; IMDS decides which keys it + # accepts (no client-side allow-list, matching the other MSALs). 
+ params["claims"] = client_claims # http_client.get url-encodes query params resp = http_client.get( os.getenv( "AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.254" diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 4590d52d..5b38519c 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -253,6 +253,14 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 _data.update(data or {}) # So the content in data param prevails _data = {k: v for k, v in _data.items() if v} # Clean up None values + # "client_claims" is a cache-key-only pseudo-parameter: callers merge its + # value into the standard "claims" body parameter upstream, and it is kept + # in the request data solely so it contributes to the extended cache key. + # It must not be sent on the wire. Popping it here (from this method's own + # local copy) keeps the wire body clean while the caller's data dict — used + # for the cache-add event — still carries it. + _data.pop("client_claims", None) + if _data.get('scope'): _data['scope'] = self._stringify(_data['scope']) diff --git a/msal/token_cache.py b/msal/token_cache.py index 0ca250df..c6b4a70a 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -83,8 +83,15 @@ def _compute_ext_cache_key(data): Returns an empty string when *data* has no hashable fields. - The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator): - sorted key+value pairs are concatenated and SHA256 hashed, then base64url encoded. + The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator, + post-collision-fix): length-prefixed key/value pairs (sorted by key) are + concatenated and SHA256 hashed, then base64url encoded. The length prefixes + make the encoding injective, so two distinct component sets can never collide + onto the same cache key. 
(MSAL .NET's ``ComputeAccessTokenExtCacheKey`` still + uses an unprefixed concatenation, so the hash is intentionally not + byte-identical to current .NET; the cache *key format* still matches both. + Caches are not shared across languages, so this only affects within-process + isolation, where injectivity is what matters.) """ if not data: return "" @@ -94,14 +101,76 @@ def _compute_ext_cache_key(data): } if not cache_components: return "" - # Sort keys for consistent hashing (matches Go implementation) + # Concatenate length-prefixed key/value pairs so component boundaries are + # unambiguous (matches Go's CacheExtKeyGenerator). A plain key+value + # concatenation with no separators can collide when one value happens to + # contain another component's key or value -- and client_claims is arbitrary + # caller-supplied JSON that may embed e.g. "fmi_path" at a boundary -- mapping + # two distinct component sets onto the same hash and returning the wrong + # cached token. Length prefixes make the encoding injective. key_str = "".join( - k + cache_components[k] for k in sorted(cache_components.keys()) + "{}:{}{}:{}".format(len(k), k, len(v), v) + for k, v in sorted(cache_components.items()) ) hash_bytes = hashlib.sha256(key_str.encode("utf-8")).digest() return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower() +def _parse_claims_or_raise(claims): + """Parse a claims JSON string into a dict, or raise a friendly ``ValueError``. + + The raw claims value is never included in the error message because it may + contain sensitive data. Mirrors MSAL .NET's ``ClaimsHelper.ParseClaimsOrThrow``. + """ + try: + parsed = json.loads(claims) + except (ValueError, TypeError) as ex: + # json.JSONDecodeError (malformed JSON) is a subclass of ValueError; + # TypeError is raised when *claims* is not a str/bytes/bytearray. Both + # are surfaced as the same friendly ValueError so every caller behaves + # consistently regardless of the bad input's type. 
+ raise ValueError( + "The claims value is not valid JSON. " + "See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter." + ) from ex + if not isinstance(parsed, dict): + # A valid JSON array, scalar, or the literal "null" is not a claims object. + raise ValueError( + "The claims value is not a valid JSON object. " + "See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter.") + return parsed + + +def _deep_merge_dict(base, overlay): + """Recursively merge ``overlay`` into ``base``, returning a new dict. + + Nested dicts are merged; for any other value type, ``overlay`` wins. + """ + result = dict(base) + for key, value in overlay.items(): + if (key in result + and isinstance(result[key], dict) and isinstance(value, dict)): + result[key] = _deep_merge_dict(result[key], value) + else: + result[key] = value + return result + + +def _merge_claims(claims_a, claims_b): + """Merge two claims JSON strings into a single JSON string. + + If either side is empty/None, the other is returned as-is. Mirrors MSAL + .NET's ``ClaimsHelper.MergeClaimsObjects``. + """ + if not claims_a: + return claims_b + if not claims_b: + return claims_a + merged = _deep_merge_dict( + _parse_claims_or_raise(claims_a), _parse_claims_or_raise(claims_b)) + return json.dumps(merged) + + def is_subdict_of(small, big): return dict(big, **small) == big diff --git a/tests/test_application.py b/tests/test_application.py index 31f77a71..a5d4ca8b 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -966,6 +966,449 @@ def test_fmi_token_does_not_interfere_with_non_fmi_token(self): "Non-FMI call should not return FMI-cached token") +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestAcquireTokenForClientWithClientClaims(unittest.TestCase): + """acquire_token_for_client(forwarded_client_claims=...) 
forwards client-originated claims + via the OAuth "claims" body parameter, caches the result, and keys the cache + entry on the claims value.""" + + _CLIENT_CLAIMS = '{"access_token": {"xms_az_nwperimid": {"essential": true}}}' + + def _build_app(self, **kwargs): + return ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant", + **kwargs) + + def test_client_claims_rejects_non_string_types(self): + app = self._build_app() + for bad_value in [123, True, ["claims"], {"a": "b"}, b"bytes"]: + with self.assertRaises(ValueError, + msg="forwarded_client_claims={!r} should raise".format(bad_value)): + app.acquire_token_for_client(["scope"], forwarded_client_claims=bad_value) + + def test_client_claims_rejects_invalid_json(self): + app = self._build_app() + for bad_value in ["not json", "[1, 2]", "null", "123"]: + with self.assertRaises(ValueError, + msg="forwarded_client_claims={!r} should raise".format(bad_value)): + app.acquire_token_for_client(["scope"], forwarded_client_claims=bad_value) + + def test_client_claims_sent_as_claims_on_the_wire(self): + app = self._build_app() + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + result = app.acquire_token_for_client( + ["scope"], forwarded_client_claims=self._CLIENT_CLAIMS, post=mock_post) + self.assertIn("access_token", result) + # The client claims are forwarded via the standard OAuth "claims" parameter + self.assertIn("claims", captured_data) + self.assertEqual( + {"access_token": {"xms_az_nwperimid": {"essential": True}}}, + json.loads(captured_data["claims"])) + # The cache-key-only pseudo-parameter must NOT leak onto the wire + self.assertNotIn("client_claims", captured_data, + "client_claims must not be sent in the HTTP request body") + + def 
test_client_claims_merged_with_client_capabilities(self): + app = self._build_app(client_capabilities=["CP1"]) + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + app.acquire_token_for_client( + ["scope"], forwarded_client_claims=self._CLIENT_CLAIMS, post=mock_post) + merged = json.loads(captured_data["claims"]) + self.assertEqual( + { + "xms_cc": {"values": ["CP1"]}, + "xms_az_nwperimid": {"essential": True}, + }, + merged["access_token"], + "client_claims must merge with capability-derived claims") + self.assertNotIn("client_claims", captured_data) + + def test_forwarded_client_claims_merged_with_claims_challenge(self): + # All three claim sources -- the server-issued claims_challenge, client + # capabilities, and forwarded_client_claims -- must combine into the + # single OAuth "claims" parameter that is sent on the wire. 
def test_forwarded_client_claims_win_on_leaf_conflict_with_challenge(self):
    """A claim set by both sources resolves to the client-originated value.

    Both the ``claims_challenge`` and ``forwarded_client_claims`` set
    ``acrs``. The value sent on the wire must be the client's, because it is
    merged in last, while the challenge-only ``nbf`` claim must survive.
    """
    app = self._build_app()
    captured_data = {}

    def mock_post(url, headers=None, data=None, *args, **kwargs):
        # Record the form body so the merged "claims" value can be inspected.
        captured_data.update(data or {})
        return MinimalResponse(status_code=200, text=json.dumps({
            "access_token": "an AT", "expires_in": 3600}))

    challenge = ('{"access_token": {"acrs": {"values": ["server"]},'
                 ' "nbf": {"essential": true}}}')
    client = '{"access_token": {"acrs": {"values": ["client"]}}}'
    app.acquire_token_for_client(
        ["scope"], claims_challenge=challenge,
        forwarded_client_claims=client, post=mock_post)
    merged = json.loads(captured_data["claims"])["access_token"]
    self.assertEqual(
        {"values": ["client"]}, merged["acrs"],
        "forwarded_client_claims must win a direct leaf conflict")
    self.assertEqual(
        {"essential": True}, merged["nbf"],
        "Disjoint claims from the challenge must be preserved")

def test_same_client_claims_hits_cache(self):
    """A repeated request with identical claims is answered from the cache.

    Kept as its own test so that a caching regression is reported under its
    own name and still runs when the merge assertions elsewhere fail.
    """
    app = self._build_app()
    call_count = [0]  # A list, so the nested function can mutate the counter

    def mock_post(url, headers=None, data=None, *args, **kwargs):
        call_count[0] += 1
        return MinimalResponse(status_code=200, text=json.dumps({
            "access_token": "an AT", "expires_in": 3600}))

    result1 = app.acquire_token_for_client(
        ["scope"], forwarded_client_claims=self._CLIENT_CLAIMS, post=mock_post)
    self.assertEqual(result1[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP)
    result2 = app.acquire_token_for_client(
        ["scope"], forwarded_client_claims=self._CLIENT_CLAIMS, post=mock_post)
    self.assertEqual(result2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE,
        "Same client_claims should return token from cache")
    self.assertEqual(1, call_count[0], "Second call should not hit the IdP")
def _build_user_token_response(
        access_token="user_at", uid="user_oid", utid="my_tenant",
        client_id="client_id", refresh_token=None):
    """Serialize a mock user-token response as a JSON string.

    The payload always carries an access token and an id_token built for
    ``client_id``/``uid``/``utid``; a refresh token is added only when a
    truthy ``refresh_token`` is supplied. Intended for tests that need an
    account to be created so that silent retrieval can be exercised.
    """
    id_token = build_id_token(
        aud=client_id, oid=uid, tid=utid, preferred_username="user@contoso.com")
    # Only forward the refresh token when the caller actually provided one.
    optional_fields = {"refresh_token": refresh_token} if refresh_token else {}
    payload = build_response(
        uid=uid, utid=utid, access_token=access_token,
        id_token=id_token, **optional_fields)
    return json.dumps(payload)
data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + app.acquire_token_on_behalf_of( + "assertion", ["s"], forwarded_client_claims=_CLIENT_CLAIMS, post=mock_post) + self.assertEqual( + { + "xms_cc": {"values": ["CP1"]}, + "xms_az_nwperimid": {"essential": True}, + }, + json.loads(captured_data["claims"])["access_token"]) + + def test_cached_token_is_isolated_by_client_claims(self): + app = self._build_app() + app.acquire_token_on_behalf_of( + "assertion", ["s"], forwarded_client_claims=_CLIENT_CLAIMS, + post=lambda *a, **k: MinimalResponse(status_code=200, + text=_build_user_token_response(access_token="obo_at"))) + accounts = app.get_accounts() + self.assertTrue(accounts, "OBO response should create an account") + hit = app.acquire_token_silent( + ["s"], accounts[0], forwarded_client_claims=_CLIENT_CLAIMS) + self.assertIsNotNone(hit) + self.assertEqual("obo_at", hit["access_token"]) + self.assertEqual(hit[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE) + self.assertIsNone( + app.acquire_token_silent( + ["s"], accounts[0], forwarded_client_claims=_OTHER_CLIENT_CLAIMS), + "Different client_claims must not read the cached token") + self.assertIsNone( + app.acquire_token_silent(["s"], accounts[0]), + "A plain silent call must not read a client_claims token") + + +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestAcquireTokenByAuthorizationCodeWithClientClaims(unittest.TestCase): + """acquire_token_by_authorization_code(forwarded_client_claims=...) 
forwards + client-originated claims and isolates the cached token.""" + + def _build_app(self, **kwargs): + return ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant", **kwargs) + + def test_client_claims_rejects_invalid_values(self): + app = self._build_app() + for bad_value in [123, ["claims"], b"bytes", "not json", "null"]: + with self.assertRaises(ValueError, + msg="forwarded_client_claims={!r} should raise".format(bad_value)): + app.acquire_token_by_authorization_code( + "code", ["s"], forwarded_client_claims=bad_value) + + def test_client_claims_sent_as_claims_on_the_wire(self): + app = self._build_app() + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + app.acquire_token_by_authorization_code( + "code", ["s"], forwarded_client_claims=_CLIENT_CLAIMS, post=mock_post) + self.assertIn("claims", captured_data) + self.assertEqual( + {"access_token": {"xms_az_nwperimid": {"essential": True}}}, + json.loads(captured_data["claims"])) + self.assertNotIn("client_claims", captured_data, + "client_claims must not be sent in the HTTP request body") + + def test_cached_token_is_isolated_by_client_claims(self): + app = self._build_app() + app.acquire_token_by_authorization_code( + "code", ["s"], forwarded_client_claims=_CLIENT_CLAIMS, + post=lambda *a, **k: MinimalResponse(status_code=200, + text=_build_user_token_response(access_token="authcode_at"))) + accounts = app.get_accounts() + self.assertTrue(accounts) + hit = app.acquire_token_silent( + ["s"], accounts[0], forwarded_client_claims=_CLIENT_CLAIMS) + self.assertIsNotNone(hit) + self.assertEqual("authcode_at", hit["access_token"]) + self.assertEqual(hit[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE) + self.assertIsNone( + app.acquire_token_silent(["s"], 
@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK)
class TestUserFicWithClientClaims(unittest.TestCase):
    """``forwarded_client_claims`` on the user-FIC flow.

    Covers ``acquire_token_by_user_federated_identity_credential``: input
    validation, the value sent on the wire, and cache isolation.
    """

    def _build_app(self, **kwargs):
        """Create the confidential client used by every test in this class."""
        return ConfidentialClientApplication(
            "agent_app_id",
            client_credential="secret",
            authority="https://login.microsoftonline.com/my_tenant",
            **kwargs)

    def test_client_claims_rejects_invalid_values(self):
        app = self._build_app()
        # Wrong types as well as strings that are not a JSON object.
        invalid_values = [123, ["claims"], b"bytes", "not json", "null"]
        for bad_value in invalid_values:
            failure_note = (
                "forwarded_client_claims={!r} should raise".format(bad_value))
            with self.assertRaises(ValueError, msg=failure_note):
                app.acquire_token_by_user_federated_identity_credential(
                    ["s"], assertion="t2", username="user@contoso.com",
                    forwarded_client_claims=bad_value)

    def test_client_claims_sent_as_claims_on_the_wire(self):
        app = self._build_app()
        sent_body = {}

        def recording_post(url, headers=None, data=None, *args, **kwargs):
            # Keep the form body so the assertions below can inspect it.
            sent_body.update(data or {})
            response_text = _build_user_token_response(client_id="agent_app_id")
            return MinimalResponse(status_code=200, text=response_text)

        app.acquire_token_by_user_federated_identity_credential(
            ["s"], assertion="t2", username="user@contoso.com",
            forwarded_client_claims=_CLIENT_CLAIMS, post=recording_post)

        self.assertIn("claims", sent_body)
        expected_claims = {
            "access_token": {"xms_az_nwperimid": {"essential": True}}}
        self.assertEqual(expected_claims, json.loads(sent_body["claims"]))
        self.assertNotIn(
            "client_claims", sent_body,
            "client_claims must not be sent in the HTTP request body")

    def test_cached_token_is_isolated_by_client_claims(self):
        app = self._build_app()

        def canned_post(*args, **kwargs):
            response_text = _build_user_token_response(
                access_token="fic_at", client_id="agent_app_id")
            return MinimalResponse(status_code=200, text=response_text)

        app.acquire_token_by_user_federated_identity_credential(
            ["s"], assertion="t2", username="user@contoso.com",
            forwarded_client_claims=_CLIENT_CLAIMS,
            post=canned_post)
        accounts = app.get_accounts()
        self.assertTrue(accounts)
        account = accounts[0]

        # Same claims: the token seeded above must be served from the cache.
        hit = app.acquire_token_silent(
            ["s"], account, forwarded_client_claims=_CLIENT_CLAIMS)
        self.assertIsNotNone(hit)
        self.assertEqual("fic_at", hit["access_token"])
        self.assertEqual(hit[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE)

        # No claims: the claims-keyed entry must stay invisible.
        plain = app.acquire_token_silent(["s"], account)
        self.assertIsNone(
            plain,
            "A plain silent call must not read a client_claims token")
class VmClientClaimsTestCase(ClientTestCase):
    """``forwarded_client_claims`` on the IMDS (Azure VM) source.

    The claims value is expected to travel as the "claims" query parameter
    and to key the cache entry.
    """

    _CLAIMS = '{"xms_az_nwperimid": {"essential": true}}'

    def _mock_get(self, token="AT"):
        """Patch the HTTP client's ``get`` to return a canned token response."""
        body = '{"access_token": "%s", "expires_in": "3600", "resource": "R"}' % token
        canned = MinimalResponse(status_code=200, text=body)
        return patch.object(self.app._http_client, "get", return_value=canned)

    def test_client_claims_sent_as_claims_query_param(self):
        with self._mock_get() as http_get:
            outcome = self.app.acquire_token_for_client(
                resource="R", forwarded_client_claims=self._CLAIMS)
        self.assertIn("access_token", outcome)
        sent_claims = http_get.call_args.kwargs["params"].get("claims")
        self.assertEqual(
            self._CLAIMS, sent_claims,
            "client_claims should be forwarded as the 'claims' query parameter")

    def test_no_claims_param_when_client_claims_absent(self):
        with self._mock_get() as http_get:
            self.app.acquire_token_for_client(resource="R")
        query = http_get.call_args.kwargs["params"]
        self.assertNotIn("claims", query)

    def test_non_nwperimid_claim_is_forwarded_not_rejected(self):
        # There is no client-side allow-list: any JSON-object claims value is
        # passed through unchanged and IMDS decides which keys it accepts.
        other = '{"some_other_claim": {"essential": true}}'
        with self._mock_get() as http_get:
            outcome = self.app.acquire_token_for_client(
                resource="R", forwarded_client_claims=other)
        self.assertIn("access_token", outcome)
        sent_claims = http_get.call_args.kwargs["params"].get("claims")
        self.assertEqual(
            other, sent_claims,
            "Non-nwperimid claims must be forwarded, not rejected")

    def test_invalid_json_claims_raises(self):
        malformed_values = ["not json", "[1, 2]", "null"]
        for bad in malformed_values:
            note = "{!r} should raise".format(bad)
            with self.assertRaises(ValueError, msg=note):
                self.app.acquire_token_for_client(
                    resource="R", forwarded_client_claims=bad)

    def test_non_string_client_claims_raises_value_error(self):
        # Non-str input (int, bytes, dict, ...) must surface as a friendly
        # ValueError instead of a raw TypeError from json.loads or an
        # inconsistent hash in the extended cache key.
        non_strings = [
            123, b'{"xms_az_nwperimid": {}}', {"xms_az_nwperimid": {}}]
        for bad in non_strings:
            note = "{!r} should raise ValueError".format(bad)
            with self.assertRaises(ValueError, msg=note):
                self.app.acquire_token_for_client(
                    resource="R", forwarded_client_claims=bad)

    def test_same_client_claims_hits_cache(self):
        with self._mock_get("AT1") as http_get:
            first = self.app.acquire_token_for_client(
                resource="R", forwarded_client_claims=self._CLAIMS)
            self.assertEqual("identity_provider", first["token_source"])
            second = self.app.acquire_token_for_client(
                resource="R", forwarded_client_claims=self._CLAIMS)
            self.assertEqual("cache", second["token_source"], "Should hit cache")
        self.assertEqual(1, http_get.call_count, "Second call must not hit IMDS")

    def test_different_client_claims_are_cached_separately(self):
        claims_a = '{"xms_az_nwperimid": {"values": ["A"]}}'
        claims_b = '{"xms_az_nwperimid": {"values": ["B"]}}'

        with self._mock_get("AT_A"):
            from_a = self.app.acquire_token_for_client(
                resource="R", forwarded_client_claims=claims_a)
        self.assertEqual("AT_A", from_a["access_token"])

        with self._mock_get("AT_B"):
            from_b = self.app.acquire_token_for_client(
                resource="R", forwarded_client_claims=claims_b)
        self.assertEqual("AT_B", from_b["access_token"])
        self.assertEqual(
            "identity_provider", from_b["token_source"],
            "Different client_claims must NOT share a cache entry")

        # Asking for A again must return the cached A token, not the mock's.
        with self._mock_get("unused"):
            from_a_again = self.app.acquire_token_for_client(
                resource="R", forwarded_client_claims=claims_a)
        self.assertEqual("AT_A", from_a_again["access_token"])
        self.assertEqual("cache", from_a_again["token_source"])

    def test_plain_request_does_not_return_client_claims_token(self):
        with self._mock_get("claims_AT"):
            self.app.acquire_token_for_client(
                resource="R", forwarded_client_claims=self._CLAIMS)
        with self._mock_get("plain_AT"):
            plain = self.app.acquire_token_for_client(resource="R")
        self.assertEqual("plain_AT", plain["access_token"])
        self.assertEqual(
            "identity_provider", plain["token_source"],
            "A plain request must not return a client_claims-cached token")
class TestClientClaimsCacheKey(unittest.TestCase):
    """``client_claims`` feeds the extended cache key; server-issued ``claims`` does not."""

    _SAMPLE = '{"a": 1}'

    def test_client_claims_produces_non_empty_hash(self):
        digest = _compute_ext_cache_key({"client_claims": self._SAMPLE})
        self.assertNotEqual("", digest)
        self.assertIsInstance(digest, str)

    def test_claims_is_excluded_but_client_claims_is_not(self):
        # "claims" carries the server-issued claims_challenge, which bypasses
        # the cache and so must not influence the key. The client-originated
        # "client_claims" must influence it.
        server_only = _compute_ext_cache_key({"claims": self._SAMPLE})
        self.assertEqual("", server_only)
        client_only = _compute_ext_cache_key({"client_claims": self._SAMPLE})
        self.assertNotEqual("", client_only)

    def test_same_client_claims_produce_same_hash(self):
        first = _compute_ext_cache_key({"client_claims": self._SAMPLE})
        second = _compute_ext_cache_key({"client_claims": self._SAMPLE})
        self.assertEqual(first, second)

    def test_different_client_claims_produce_different_hashes(self):
        one = _compute_ext_cache_key({"client_claims": '{"a": 1}'})
        two = _compute_ext_cache_key({"client_claims": '{"a": 2}'})
        self.assertNotEqual(one, two)

    def test_empty_client_claims_value_is_ignored(self):
        digest = _compute_ext_cache_key({"client_claims": ""})
        self.assertEqual("", digest)

    def test_length_prefixed_encoding_avoids_boundary_collision(self):
        # Mirrors Go's TestCacheKeyComponentHashNoBoundaryCollision. Joined
        # without separators, both component sets below flatten to "axbYbZ"
        # and would collide, returning the wrong cached token. A
        # length-prefixed encoding keeps them apart. This matters because
        # client_claims is arbitrary caller JSON that may embed another
        # component's key (e.g. "fmi_path").
        left = _compute_ext_cache_key({"a": "xbY", "b": "Z"})
        right = _compute_ext_cache_key({"a": "x", "b": "YbZ"})
        self.assertNotEqual(
            left, right,
            "distinct cache key components must not produce the same hash")

    def test_client_claims_and_fmi_path_do_not_collide_at_boundary(self):
        # client_claims and fmi_path can co-occur in acquire_token_for_client.
        # A claims value that happens to end with the other component's key
        # and value must still hash differently.
        left = _compute_ext_cache_key(
            {"client_claims": "Xfmi_pathY", "fmi_path": "Z"})
        right = _compute_ext_cache_key(
            {"client_claims": "X", "fmi_path": "Yfmi_pathZ"})
        self.assertNotEqual(left, right)
merged = json.loads(_merge_claims( + '{"access_token": {"xms_cc": {"values": ["cp1"]}}}', + '{"access_token": {"xms_az_nwperimid": {"essential": true}}}')) + self.assertEqual( + {"xms_cc": {"values": ["cp1"]}, "xms_az_nwperimid": {"essential": True}}, + merged["access_token"]) + + def test_merge_leaf_conflict_second_arg_wins(self): + # When both sides set the SAME leaf to different values, the second + # argument (client-originated claims) wins. _acquire_token_for_client + # passes server-issued claims first and forwarded_client_claims second, + # so the caller's value takes precedence on a direct conflict. + merged = json.loads(_merge_claims( + '{"access_token": {"acrs": {"values": ["server"]}}}', + '{"access_token": {"acrs": {"values": ["client"]}}}')) + self.assertEqual({"values": ["client"]}, merged["access_token"]["acrs"]) + + def test_merge_nested_conflict_preserves_disjoint_siblings(self): + # A conflict on one nested key must not drop the other side's disjoint + # sibling keys at the same level. + merged = json.loads(_merge_claims( + '{"access_token": {"a": {"values": ["server"]}, "keep_server": 1}}', + '{"access_token": {"a": {"values": ["client"]}, "keep_client": 2}}')) + self.assertEqual( + {"a": {"values": ["client"]}, "keep_server": 1, "keep_client": 2}, + merged["access_token"]) + + class TestExtCacheKeyIsolation(unittest.TestCase): """Tests that ext_cache_key provides proper cache isolation in TokenCache.""" @@ -444,40 +565,48 @@ def test_non_fmi_tokens_not_affected_by_fmi_cache(self): class TestCrossMsalCacheKeyCompatibility(unittest.TestCase): - """Verify that _compute_ext_cache_key produces hashes identical to MSAL Go - (CacheExtKeyGenerator) and MSAL .NET (CoreHelpers.ComputeAccessTokenExtCacheKey). + """Verify that _compute_ext_cache_key matches MSAL Go's CacheExtKeyGenerator + (post collision-fix, AzureAD/microsoft-authentication-library-for-go#629). - All three libraries use the same algorithm: + The algorithm: 1. 
Sort key-value pairs alphabetically by key (ordinal / case-sensitive) - 2. Concatenate them: "key1value1key2value2…" + 2. Concatenate length-prefixed pairs ("{len(k)}:{k}{len(v)}:{v}" per pair; + e.g. {"key1": "value1"} -> "4:key16:value1"). The length prefixes make + the encoding injective -- see TestClientClaimsCacheKey for the collision + guard. 3. SHA-256 hash 4. Base64url encode (no padding), lowercased - The expected hashes below are copied from: - - MSAL Go: authority_ext_cachekey_test.go (TestAppKeyWithCacheKeyComponent) - - MSAL .NET: CacheKeyExtensionTests.cs (RunHappyPathTest, CacheExtEnsurePopKeysFunctionAsync) + The expected hashes below are copied from MSAL Go's + authority_ext_cachekey_test.go (TestAppKeyWithCacheKeyComponent). + + NOTE: MSAL .NET's ComputeAccessTokenExtCacheKey still uses an *unprefixed* + concatenation, so these hashes are intentionally NOT byte-identical to current + .NET. The cache *key format* (the 'atext' segment layout, asserted below) still + matches both Go and .NET; only the trailing hash bytes differ. Caches are not + shared across languages, so within-process injectivity -- not cross-language + byte-parity -- is what matters for correctness. 
""" - def test_two_params_hash_matches_go_and_dotnet(self): - """Go/dotnet expected: bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi""" + def test_two_params_hash_matches_go(self): + """Go expected: latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike""" result = _compute_ext_cache_key({"key1": "value1", "key2": "value2"}) - self.assertEqual("bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi", result) + self.assertEqual("latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike", result) - def test_two_different_params_hash_matches_go_and_dotnet(self): - """Go/dotnet expected: 3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u""" + def test_two_different_params_hash_matches_go(self): + """Go expected: jjoe9jgfmdtnj0rzuetsqy7kzs2m1xfnjjxwsfxsrxq""" result = _compute_ext_cache_key({"key3": "value3", "key4": "value4"}) - self.assertEqual("3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u", result) + self.assertEqual("jjoe9jgfmdtnj0rzuetsqy7kzs2m1xfnjjxwsfxsrxq", result) - def test_five_params_hash_matches_go_and_dotnet(self): - """Go/dotnet expected (full hash): rn_gkpxxkkqjxcqnvnmr2duvxg66xanvkz6qfqpwp2e - Go test uses substring match 'gkpxxkkqjxcqnvnmr2duvxg66xanvkz6qfqpwp2e'.""" + def test_five_params_hash_matches_go(self): + """Go expected: prrdp31y37ufw3lo7hly0oimjjvg_34m9ji30ocu4tw""" result = _compute_ext_cache_key({ "key3": "value3", "key4": "value4", "key5": "value5", "key6": "value6", "key7": "value7", }) - self.assertEqual("rn_gkpxxkkqjxcqnvnmr2duvxg66xanvkz6qfqpwp2e", result) + self.assertEqual("prrdp31y37ufw3lo7hly0oimjjvg_34m9ji30ocu4tw", result) - def test_order_independence_matches_go_and_dotnet(self): + def test_order_independence_matches_go(self): """Same keys in different insertion order must produce the same hash (mirrors TestCacheKeyComponentHashConsistency in Go).""" h1 = _compute_ext_cache_key({"key3": "value3", "key4": "value4", @@ -498,9 +627,9 @@ def test_at_cache_key_uses_atext_credential_type(self): key = key_maker( home_account_id="hid", environment="env", client_id="cid", 
realm="realm", target="scope", - ext_cache_key="bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi") + ext_cache_key="latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike") self.assertEqual( - "hid-env-atext-cid-realm-scope-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi", + "hid-env-atext-cid-realm-scope-latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike", key) def test_at_cache_key_without_ext_uses_accesstoken(self): @@ -512,9 +641,10 @@ def test_at_cache_key_without_ext_uses_accesstoken(self): realm="realm", target="scope") self.assertEqual("hid-env-accesstoken-cid-realm-scope", key) - def test_dotnet_style_full_at_cache_key(self): - """Reproduce the exact cache key from MSAL .NET CacheKeyExtensionTests: - expectedCacheKey1 = '-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi' + def test_atext_full_at_cache_key_format(self): + """The AT cache key *format* matches MSAL .NET's CacheKeyExtensionTests + layout ('-{env}-atext-{clientId}-{tenant}-{scopes}-{hash}'); only the + trailing hash now follows Go's length-prefixed encoding (see class note). 
""" cache = TokenCache() key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] @@ -526,11 +656,11 @@ def test_dotnet_style_full_at_cache_key(self): realm="common", target="r1/scope1 r1/scope2", ext_cache_key=ext_hash) - expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi" + expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike" self.assertEqual(expected, key) - def test_dotnet_style_second_cache_key(self): - """Reproduce CacheKeyExtensionTests expectedCacheKey2.""" + def test_atext_second_full_at_cache_key_format(self): + """Second key-format vector (mirrors CacheKeyExtensionTests expectedCacheKey2).""" cache = TokenCache() key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] ext_hash = _compute_ext_cache_key({"key3": "value3", "key4": "value4"}) @@ -541,12 +671,12 @@ def test_dotnet_style_second_cache_key(self): realm="common", target="r1/scope1 r1/scope2", ext_cache_key=ext_hash) - expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u" + expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-jjoe9jgfmdtnj0rzuetsqy7kzs2m1xfnjjxwsfxsrxq" self.assertEqual(expected, key) def test_go_style_at_cache_key(self): - """Reproduce the Go AccessToken.Key() format: - Go test: 'testhid-env-atext-clientid-realm-user.read-{hash}' + """Reproduce the Go AccessToken.Key() format with Go's post-#629 hash: + 'testhid-env-atext-clientid-realm-user.read-{hash}'. 
""" cache = TokenCache() key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] @@ -558,5 +688,5 @@ def test_go_style_at_cache_key(self): realm="realm", target="user.read", ext_cache_key=ext_hash) - expected = "testhid-env-atext-clientid-realm-user.read-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi" + expected = "testhid-env-atext-clientid-realm-user.read-latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike" self.assertEqual(expected, key)