From 1e81544b626c5e2ddb05238f9103f6bb77182a93 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 26 Jun 2026 14:23:32 -0400 Subject: [PATCH 1/9] Add client_claims support to confidential client and managed identity flows Port of msal-dotnet PR 5999 (WithClaimsFromClient). Adds a `client_claims` keyword argument to `ConfidentialClientApplication.acquire_token_for_client` and `ManagedIdentityClient.acquire_token_for_client` for forwarding client-originated claims (e.g. the network security perimeter `xms_az_nwperimid` claim) to ESTS/IMDS. Unlike `claims_challenge` (server-issued, bypasses the cache), `client_claims` tokens are cached and the cache entry is keyed on the claims value, reusing the existing `_compute_ext_cache_key` mechanism (the `fmi_path` precedent). - token_cache: add `_parse_claims_or_raise`, `_deep_merge_dict`, `_merge_claims` helpers; `claims` stays excluded from the ext cache key while `client_claims` participates in it. - oauth2: strip the cache-key-only `client_claims` pseudo-parameter from the wire body while preserving it for the cache-add event. - application: validate `client_claims`, merge it into the OAuth `claims` body parameter, and isolate the cache by claims value. - managed_identity: support `client_claims` on the IMDS (Azure VM) source only (sent as the `claims` query parameter); other sources raise; MSIv1 restricts the claims JSON to only the `xms_az_nwperimid` key. Adds unit tests covering cache isolation, wire shape, claim merging, source restrictions, and MSIv1 validation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- msal/application.py | 43 +++++++++++-- msal/managed_identity.py | 77 ++++++++++++++++++++-- msal/oauth2cli/oauth2.py | 8 +++ msal/token_cache.py | 51 +++++++++++++++ tests/test_application.py | 132 ++++++++++++++++++++++++++++++++++++++ tests/test_mi.py | 97 ++++++++++++++++++++++++++++ tests/test_token_cache.py | 71 +++++++++++++++++++- 7 files changed, 468 insertions(+), 11 deletions(-) diff --git a/msal/application.py b/msal/application.py index 3fc69461..e680e3d2 100644 --- a/msal/application.py +++ b/msal/application.py @@ -20,7 +20,7 @@ from .mex import send_request as mex_send_request from .wstrust_request import send_request as wst_send_request from .wstrust_response import * -from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER, _compute_ext_cache_key +from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER, _compute_ext_cache_key, _parse_claims_or_raise, _merge_claims import msal.telemetry from .region import _detect_region, _validate_region from .throttled_http_client import ThrottledHttpClient @@ -2494,7 +2494,7 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app except that ``allow_broker`` parameter shall remain ``None``. """ - def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, **kwargs): + def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, client_claims=None, **kwargs): """Acquires token for the current confidential client, not for an end user. Since MSAL Python 1.23, it will automatically look for token from cache, @@ -2518,6 +2518,18 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, scopes=["api://resource/.default"], fmi_path="SomeFmiPath/FmiCredentialPath", ) + :param str client_claims: + Optional. A JSON string containing *client-originated* claims to + include in the token request (for example a network security + perimeter ``xms_az_nwperimid`` claim). + + Unlike ``claims_challenge`` (which carries *server-issued* claims + challenges and bypasses the cache), tokens acquired with + ``client_claims`` **are cached**, and the cache entry is keyed on the + claims value. Different ``client_claims`` values produce separate + cache entries, so use stable, non-dynamic values to avoid unbounded + cache growth. The value is merged into the standard OAuth ``claims`` + request parameter sent on the wire. :return: A dict representing the json response from Microsoft Entra: - A successful response would contain "access_token" key, @@ -2533,6 +2545,18 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, "fmi_path must be a string, got {}".format(type(fmi_path).__name__)) kwargs["data"] = kwargs.get("data", {}) kwargs["data"]["fmi_path"] = fmi_path + if client_claims is not None: + if not isinstance(client_claims, str): + raise ValueError( + "client_claims must be a string, got {}".format( + type(client_claims).__name__)) + _parse_claims_or_raise(client_claims) # Fail fast on malformed JSON + # Carry it in the request data so it contributes to the extended + # cache key (different claims => separate cache entries). It is + # merged into the "claims" body parameter in _acquire_token_for_client + # and stripped from the wire body by the oauth2 layer. + kwargs["data"] = kwargs.get("data", {}) + kwargs["data"]["client_claims"] = client_claims return _clean_up(self._acquire_token_silent_with_error( scopes, None, claims_challenge=claims_challenge, **kwargs)) @@ -2552,13 +2576,20 @@ def _acquire_token_for_client( telemetry_context = self._build_telemetry_context( self.ACQUIRE_TOKEN_FOR_CLIENT_ID, refresh_reason=refresh_reason) client = self._regional_client or self.client + request_data = kwargs.pop("data", {}) + claims = _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge) + # Client-originated claims (set via client_claims=) are merged into the + # same OAuth "claims" parameter and sent on the wire. The raw + # "client_claims" entry stays in request_data so it keys the cache; the + # oauth2 layer drops it from the actual request body. + client_claims = request_data.get("client_claims") + if client_claims: + claims = _merge_claims(claims, client_claims) response = client.obtain_token_for_client( scope=scopes, # This grant flow requires no scope decoration headers=telemetry_context.generate_headers(), - data=dict( - kwargs.pop("data", {}), - claims=_merge_claims_challenge_and_capabilities( - self._client_capabilities, claims_challenge)), + data=dict(request_data, claims=claims), **kwargs) telemetry_context.update_telemetry(response) return response diff --git a/msal/managed_identity.py b/msal/managed_identity.py index b2fc446c..479423bb 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -12,7 +12,7 @@ from urllib.parse import urlparse # Python 3+ from collections import UserDict # Python 3+ from typing import List, Optional, Union # Needed in Python 3.7 & 3.8 -from .token_cache import TokenCache +from .token_cache import TokenCache, _compute_ext_cache_key, _parse_claims_or_raise from .individual_cache import _IndividualCache as IndividualCache from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser from .cloudshell import _is_running_in_cloud_shell @@ -26,6 +26,14 @@ class ManagedIdentityError(ValueError): pass +_XMS_AZ_NWPERIMID = "xms_az_nwperimid" + +_CLIENT_CLAIMS_UNSUPPORTED_SOURCE = ( + "client_claims is only supported for the IMDS (Azure VM) managed identity " + "source. The detected source ({source}) does not support forwarding " + "client-originated claims.") + + class ManagedIdentity(UserDict): """Feed an instance of this class to :class:`msal.ManagedIdentityClient` to acquire token for the specified managed identity. @@ -261,6 +269,7 @@ def acquire_token_for_client( *, resource: str, # If/when we support scope, resource will become optional claims_challenge: Optional[str] = None, + client_claims: Optional[str] = None, ): """Acquire token for the managed identity. @@ -280,6 +289,22 @@ def acquire_token_for_client( even if the app developer did not opt in for the "CP1" client capability. Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token. + :param client_claims: + Optional. + A string representation of a JSON object containing + *client-originated* claims to forward to the identity endpoint + (for example a network security perimeter ``xms_az_nwperimid`` claim). + + Unlike ``claims_challenge`` (server-issued, which bypasses the cache), + tokens acquired with ``client_claims`` **are cached**, and the cache + entry is keyed on the claims value. Different ``client_claims`` values + produce separate cache entries, so use stable, non-dynamic values to + avoid unbounded cache growth. + + Only the IMDS (Azure VM) managed identity source supports this + parameter; other sources raise an error. On IMDS v1, the claims JSON + may contain only the ``xms_az_nwperimid`` key. + .. note:: Known issue: When an Azure VM has only one user-assigned managed identity, @@ -294,6 +319,13 @@ def acquire_token_for_client( client_id_in_cache = self._managed_identity.get( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") now = time.time() + if client_claims is not None: + _parse_claims_or_raise(client_claims) # Fail fast on malformed JSON + # Client-originated claims isolate the cache: a distinct claims value gets + # a distinct cache entry. (Server-issued claims_challenge, by contrast, + # bypasses the cache and is keyed normally.) + ext_cache_key = _compute_ext_cache_key( + {"client_claims": client_claims}) if client_claims else None if True: # Attempt cache search even if receiving claims_challenge, # because we want to locate the existing token (if any) and refresh it matches = self._token_cache.search( @@ -304,6 +336,7 @@ def acquire_token_for_client( environment=self.__instance, realm=self._tenant, home_account_id=None, + **({"ext_cache_key": ext_cache_key} if ext_cache_key else {}), ), ) for entry in matches: @@ -334,6 +367,7 @@ def acquire_token_for_client( access_token_to_refresh.encode("utf-8")).hexdigest() if access_token_to_refresh else None, client_capabilities=self._client_capabilities, + client_claims=client_claims, ) if "access_token" in result: expires_in = result.get("expires_in", 3600) @@ -346,7 +380,7 @@ def acquire_token_for_client( self.__instance, self._tenant), response=result, params={}, - data={}, + data={"client_claims": client_claims} if client_claims else {}, )) if "refresh_in" in result: result["refresh_on"] = int(now + result["refresh_in"]) @@ -414,10 +448,14 @@ def _obtain_token( *, access_token_sha256_to_refresh: Optional[str] = None, client_capabilities: Optional[List[str]] = None, + client_claims: Optional[str] = None, ): if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ and "IDENTITY_SERVER_THUMBPRINT" in os.environ ): + if client_claims: + raise ManagedIdentityError( + _CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Service Fabric")) if managed_identity: logger.debug( "Ignoring managed_identity parameter. " @@ -434,6 +472,9 @@ def _obtain_token( client_capabilities=client_capabilities, ) if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ: + if client_claims: + raise ManagedIdentityError( + _CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="App Service")) return _obtain_token_on_app_service( http_client, os.environ["IDENTITY_ENDPOINT"], @@ -442,6 +483,9 @@ def _obtain_token( resource, ) if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ: + if client_claims: + raise ManagedIdentityError( + _CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Machine Learning")) # Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py return _obtain_token_on_machine_learning( http_client, @@ -452,6 +496,9 @@ def _obtain_token( ) arc_endpoint = _get_arc_endpoint() if arc_endpoint: + if client_claims: + raise ManagedIdentityError( + _CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Azure Arc")) if ManagedIdentity.is_user_assigned(managed_identity): raise ManagedIdentityError( # Note: Azure Identity for Python raised exception too "Invalid managed_identity parameter. " @@ -459,7 +506,8 @@ def _obtain_token( "See also " "https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service") return _obtain_token_on_arc(http_client, arc_endpoint, resource) - return _obtain_token_on_azure_vm(http_client, managed_identity, resource) + return _obtain_token_on_azure_vm( + http_client, managed_identity, resource, client_claims=client_claims) def _adjust_param(params, managed_identity, types_mapping=None): @@ -469,7 +517,24 @@ def _adjust_param(params, managed_identity, types_mapping=None): if id_name: params[id_name] = managed_identity[ManagedIdentity.ID] -def _obtain_token_on_azure_vm(http_client, managed_identity, resource): +def _validate_msiv1_claims(client_claims): + """MSIv1 (IMDS v1) only supports the single ``xms_az_nwperimid`` custom claim. + + Any other top-level key makes IMDS return HTTP 400 with no useful diagnostic, + so validate early and raise a clear error. Mirrors MSAL .NET's + ``AbstractManagedIdentity.ValidateMsiv1Claims``. + """ + parsed = _parse_claims_or_raise(client_claims) + for key in parsed: + if key != _XMS_AZ_NWPERIMID: + raise ManagedIdentityError( + "MSIv1 (IMDS v1) only supports the `{expected}` custom claim. " + "The claims JSON contained the unsupported key `{actual}`. " + "Remove all keys other than `{expected}` when using client_claims " + "with MSIv1.".format(expected=_XMS_AZ_NWPERIMID, actual=key)) + + +def _obtain_token_on_azure_vm(http_client, managed_identity, resource, client_claims=None): # Based on https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http logger.debug("Obtaining token via managed identity on Azure VM") params = { @@ -477,6 +542,10 @@ def _obtain_token_on_azure_vm(http_client, managed_identity, resource): "resource": resource, } _adjust_param(params, managed_identity) + if client_claims: + # IMDS v1 (MSIv1) only supports the single xms_az_nwperimid claim. + _validate_msiv1_claims(client_claims) + params["claims"] = client_claims # http_client.get url-encodes query params resp = http_client.get( os.getenv( "AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.254" diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 4590d52d..5b38519c 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -253,6 +253,14 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 _data.update(data or {}) # So the content in data param prevails _data = {k: v for k, v in _data.items() if v} # Clean up None values + # "client_claims" is a cache-key-only pseudo-parameter: callers merge its + # value into the standard "claims" body parameter upstream, and it is kept + # in the request data solely so it contributes to the extended cache key. + # It must not be sent on the wire. Popping it here (from this method's own + # local copy) keeps the wire body clean while the caller's data dict — used + # for the cache-add event — still carries it. + _data.pop("client_claims", None) + if _data.get('scope'): _data['scope'] = self._stringify(_data['scope']) diff --git a/msal/token_cache.py b/msal/token_cache.py index 0ca250df..a19db6a9 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -102,6 +102,57 @@ def _compute_ext_cache_key(data): return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower() +def _parse_claims_or_raise(claims): + """Parse a claims JSON string into a dict, or raise a friendly ``ValueError``. + + The raw claims value is never included in the error message because it may + contain sensitive data. Mirrors MSAL .NET's ``ClaimsHelper.ParseClaimsOrThrow``. + """ + try: + parsed = json.loads(claims) + except ValueError as ex: # json.JSONDecodeError is a subclass of ValueError + raise ValueError( + "The claims value is not valid JSON. " + "See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter." + ) from ex + if not isinstance(parsed, dict): + # A valid JSON array, scalar, or the literal "null" is not a claims object. + raise ValueError( + "The claims value is not a valid JSON object. " + "See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter.") + return parsed + + +def _deep_merge_dict(base, overlay): + """Recursively merge ``overlay`` into ``base``, returning a new dict. + + Nested dicts are merged; for any other value type, ``overlay`` wins. + """ + result = dict(base) + for key, value in overlay.items(): + if (key in result + and isinstance(result[key], dict) and isinstance(value, dict)): + result[key] = _deep_merge_dict(result[key], value) + else: + result[key] = value + return result + + +def _merge_claims(claims_a, claims_b): + """Merge two claims JSON strings into a single JSON string. + + If either side is empty/None, the other is returned as-is. Mirrors MSAL + .NET's ``ClaimsHelper.MergeClaimsObjects``. + """ + if not claims_a: + return claims_b + if not claims_b: + return claims_a + merged = _deep_merge_dict( + _parse_claims_or_raise(claims_a), _parse_claims_or_raise(claims_b)) + return json.dumps(merged) + + def is_subdict_of(small, big): return dict(big, **small) == big diff --git a/tests/test_application.py b/tests/test_application.py index 31f77a71..11179858 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -966,6 +966,138 @@ def test_fmi_token_does_not_interfere_with_non_fmi_token(self): "Non-FMI call should not return FMI-cached token") +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestAcquireTokenForClientWithClientClaims(unittest.TestCase): + """acquire_token_for_client(client_claims=...) forwards client-originated claims + via the OAuth "claims" body parameter, caches the result, and keys the cache + entry on the claims value.""" + + _CLIENT_CLAIMS = '{"access_token": {"xms_az_nwperimid": {"essential": true}}}' + + def _build_app(self, **kwargs): + return ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant", + **kwargs) + + def test_client_claims_rejects_non_string_types(self): + app = self._build_app() + for bad_value in [123, True, ["claims"], {"a": "b"}, b"bytes"]: + with self.assertRaises(ValueError, + msg="client_claims={!r} should raise".format(bad_value)): + app.acquire_token_for_client(["scope"], client_claims=bad_value) + + def test_client_claims_rejects_invalid_json(self): + app = self._build_app() + for bad_value in ["not json", "[1, 2]", "null", "123"]: + with self.assertRaises(ValueError, + msg="client_claims={!r} should raise".format(bad_value)): + app.acquire_token_for_client(["scope"], client_claims=bad_value) + + def test_client_claims_sent_as_claims_on_the_wire(self): + app = self._build_app() + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + result = app.acquire_token_for_client( + ["scope"], client_claims=self._CLIENT_CLAIMS, post=mock_post) + self.assertIn("access_token", result) + # The client claims are forwarded via the standard OAuth "claims" parameter + self.assertIn("claims", captured_data) + self.assertEqual( + {"access_token": {"xms_az_nwperimid": {"essential": True}}}, + json.loads(captured_data["claims"])) + # The cache-key-only pseudo-parameter must NOT leak onto the wire + self.assertNotIn("client_claims", captured_data, + "client_claims must not be sent in the HTTP request body") + + def test_client_claims_merged_with_client_capabilities(self): + app = self._build_app(client_capabilities=["CP1"]) + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + app.acquire_token_for_client( + ["scope"], client_claims=self._CLIENT_CLAIMS, post=mock_post) + merged = json.loads(captured_data["claims"]) + self.assertEqual( + { + "xms_cc": {"values": ["CP1"]}, + "xms_az_nwperimid": {"essential": True}, + }, + merged["access_token"], + "client_claims must merge with capability-derived claims") + self.assertNotIn("client_claims", captured_data) + + def test_same_client_claims_returns_cached_token(self): + app = self._build_app() + call_count = [0] + + def mock_post(url, headers=None, data=None, *args, **kwargs): + call_count[0] += 1 + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + result1 = app.acquire_token_for_client( + ["scope"], client_claims=self._CLIENT_CLAIMS, post=mock_post) + self.assertEqual(result1[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP) + result2 = app.acquire_token_for_client( + ["scope"], client_claims=self._CLIENT_CLAIMS, post=mock_post) + self.assertEqual(result2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE, + "Same client_claims should return token from cache") + self.assertEqual(1, call_count[0], "Second call should not hit the IdP") + + def test_different_client_claims_are_cached_separately(self): + app = self._build_app() + + def mock_post_factory(token_value): + def mock_post(url, headers=None, data=None, *args, **kwargs): + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": token_value, "expires_in": 3600})) + return mock_post + + claims_a = '{"access_token": {"xms_az_nwperimid": {"values": ["A"]}}}' + claims_b = '{"access_token": {"xms_az_nwperimid": {"values": ["B"]}}}' + + result_a = app.acquire_token_for_client( + ["scope"], client_claims=claims_a, post=mock_post_factory("AT_A")) + self.assertEqual("AT_A", result_a["access_token"]) + + result_b = app.acquire_token_for_client( + ["scope"], client_claims=claims_b, post=mock_post_factory("AT_B")) + self.assertEqual("AT_B", result_b["access_token"]) + self.assertEqual(result_b[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP, + "Different client_claims must NOT share a cache entry") + + result_a2 = app.acquire_token_for_client( + ["scope"], client_claims=claims_a, post=mock_post_factory("unused")) + self.assertEqual("AT_A", result_a2["access_token"]) + self.assertEqual(result_a2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE) + + def test_client_claims_token_does_not_interfere_with_plain_token(self): + app = self._build_app() + app.acquire_token_for_client( + ["scope"], client_claims=self._CLIENT_CLAIMS, + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "claims_AT", "expires_in": 3600}))) + result = app.acquire_token_for_client( + ["scope"], + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "plain_AT", "expires_in": 3600}))) + self.assertEqual("plain_AT", result["access_token"]) + self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP, + "A plain request must not return a client_claims-cached token") + + @patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) class TestRemoveTokensForClient(unittest.TestCase): def test_remove_tokens_for_client_should_remove_client_tokens_only(self): diff --git a/tests/test_mi.py b/tests/test_mi.py index fd3834c8..8d53493d 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -262,6 +262,83 @@ def test_vm_resource_id_parameter_should_be_msi_res_id(self): uuid.UUID(corr_id) +class VmClientClaimsTestCase(ClientTestCase): + """client_claims is only supported on the IMDS (Azure VM) source, where it is + forwarded as the "claims" query parameter and keys the cache entry.""" + + _CLAIMS = '{"xms_az_nwperimid": {"essential": true}}' + + def _mock_get(self, token="AT"): + return patch.object(self.app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "%s", "expires_in": "3600", "resource": "R"}' % token, + )) + + def test_client_claims_sent_as_claims_query_param(self): + with self._mock_get() as mock_get: + result = self.app.acquire_token_for_client( + resource="R", client_claims=self._CLAIMS) + self.assertIn("access_token", result) + self.assertEqual( + self._CLAIMS, mock_get.call_args.kwargs["params"].get("claims"), + "client_claims should be forwarded as the 'claims' query parameter") + + def test_no_claims_param_when_client_claims_absent(self): + with self._mock_get() as mock_get: + self.app.acquire_token_for_client(resource="R") + self.assertNotIn("claims", mock_get.call_args.kwargs["params"]) + + def test_msiv1_rejects_non_nwperimid_claim(self): + with self.assertRaises(ManagedIdentityError): + self.app.acquire_token_for_client( + resource="R", + client_claims='{"some_other_claim": {"essential": true}}') + + def test_invalid_json_claims_raises(self): + for bad in ["not json", "[1, 2]", "null"]: + with self.assertRaises(ValueError, msg="{!r} should raise".format(bad)): + self.app.acquire_token_for_client(resource="R", client_claims=bad) + + def test_same_client_claims_hits_cache(self): + with self._mock_get("AT1") as mock_get: + r1 = self.app.acquire_token_for_client( + resource="R", client_claims=self._CLAIMS) + self.assertEqual("identity_provider", r1["token_source"]) + r2 = self.app.acquire_token_for_client( + resource="R", client_claims=self._CLAIMS) + self.assertEqual("cache", r2["token_source"], "Should hit cache") + self.assertEqual(1, mock_get.call_count, "Second call must not hit IMDS") + + def test_different_client_claims_are_cached_separately(self): + claims_a = '{"xms_az_nwperimid": {"values": ["A"]}}' + claims_b = '{"xms_az_nwperimid": {"values": ["B"]}}' + with self._mock_get("AT_A"): + ra = self.app.acquire_token_for_client( + resource="R", client_claims=claims_a) + self.assertEqual("AT_A", ra["access_token"]) + with self._mock_get("AT_B"): + rb = self.app.acquire_token_for_client( + resource="R", client_claims=claims_b) + self.assertEqual("AT_B", rb["access_token"]) + self.assertEqual("identity_provider", rb["token_source"], + "Different client_claims must NOT share a cache entry") + with self._mock_get("unused"): + ra2 = self.app.acquire_token_for_client( + resource="R", client_claims=claims_a) + self.assertEqual("AT_A", ra2["access_token"]) + self.assertEqual("cache", ra2["token_source"]) + + def test_plain_request_does_not_return_client_claims_token(self): + with self._mock_get("claims_AT"): + self.app.acquire_token_for_client( + resource="R", client_claims=self._CLAIMS) + with self._mock_get("plain_AT"): + result = self.app.acquire_token_for_client(resource="R") + self.assertEqual("plain_AT", result["access_token"]) + self.assertEqual("identity_provider", result["token_source"], + "A plain request must not return a client_claims-cached token") + + @patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"}) class AppServiceTestCase(ClientTestCase): @@ -302,6 +379,11 @@ def test_app_service_resource_id_parameter_should_be_mi_res_id(self): headers={'X-IDENTITY-HEADER': 'foo', 'Metadata': 'true'}, ) + def test_client_claims_not_supported_on_app_service(self): + with self.assertRaises(ManagedIdentityError): + self.app.acquire_token_for_client( + resource="R", client_claims='{"xms_az_nwperimid": {"essential": true}}') + @patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"}) class MachineLearningTestCase(ClientTestCase): @@ -327,6 +409,11 @@ def test_machine_learning_error_should_be_normalized(self): }, self.app.acquire_token_for_client(resource="R")) self.assertEqual({}, self.app._token_cache._cache) + def test_client_claims_not_supported_on_machine_learning(self): + with self.assertRaises(ManagedIdentityError): + self.app.acquire_token_for_client( + resource="R", client_claims='{"xms_az_nwperimid": {"essential": true}}') + @patch.dict(os.environ, { "IDENTITY_ENDPOINT": "http://localhost", @@ -398,6 +485,11 @@ def test_sf_error_should_be_normalized(self): }, self.app.acquire_token_for_client(resource="R")) self.assertEqual({}, self.app._token_cache._cache) + def test_client_claims_not_supported_on_service_fabric(self): + with self.assertRaises(ManagedIdentityError): + self.app.acquire_token_for_client( + resource="R", client_claims='{"xms_az_nwperimid": {"essential": true}}') + @patch.dict(os.environ, { "IDENTITY_ENDPOINT": "http://localhost/token", @@ -451,6 +543,11 @@ def test_arc_error_should_be_normalized(self, mocked_stat): if sys.platform in _supported_arc_platforms_and_their_prefixes: self.fail("Should not raise ArcPlatformNotSupportedError") + def test_client_claims_not_supported_on_arc(self, mocked_stat): + with self.assertRaises(ManagedIdentityError): + self.app.acquire_token_for_client( + resource="R", client_claims='{"xms_az_nwperimid": {"essential": true}}') + class GetManagedIdentitySourceTestCase(unittest.TestCase): diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index d7dfe8de..61a5206c 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -4,7 +4,10 @@ import time import warnings -from msal.token_cache import TokenCache, SerializableTokenCache, _compute_ext_cache_key +from msal.token_cache import ( + TokenCache, SerializableTokenCache, _compute_ext_cache_key, + _parse_claims_or_raise, _merge_claims, +) from tests import unittest @@ -362,6 +365,72 @@ def test_non_excluded_fields_are_included_in_hash(self): self.assertNotEqual(h1, h2, "Non-excluded fields should change the hash") +class TestClientClaimsCacheKey(unittest.TestCase): + """client_claims must drive the extended cache key; server-issued claims must not.""" + + def test_client_claims_produces_non_empty_hash(self): + result = _compute_ext_cache_key({"client_claims": '{"a": 1}'}) + self.assertNotEqual("", result) + self.assertIsInstance(result, str) + + def test_claims_is_excluded_but_client_claims_is_not(self): + # Server-issued "claims" (claims_challenge) must not affect the key, because + # it bypasses the cache. Client-originated "client_claims" must affect it. + self.assertEqual("", _compute_ext_cache_key({"claims": '{"a": 1}'})) + self.assertNotEqual("", _compute_ext_cache_key({"client_claims": '{"a": 1}'})) + + def test_same_client_claims_produce_same_hash(self): + self.assertEqual( + _compute_ext_cache_key({"client_claims": '{"a": 1}'}), + _compute_ext_cache_key({"client_claims": '{"a": 1}'})) + + def test_different_client_claims_produce_different_hashes(self): + self.assertNotEqual( + _compute_ext_cache_key({"client_claims": '{"a": 1}'}), + _compute_ext_cache_key({"client_claims": '{"a": 2}'})) + + def test_empty_client_claims_value_is_ignored(self): + self.assertEqual("", _compute_ext_cache_key({"client_claims": ""})) + + +class TestClaimsHelpers(unittest.TestCase): + """Tests for the shared _parse_claims_or_raise / _merge_claims helpers.""" + + def test_parse_valid_object(self): + self.assertEqual({"a": 1}, _parse_claims_or_raise('{"a": 1}')) + + def test_parse_rejects_non_object_and_malformed(self): + for bad in ["not json", "[1, 2]", "null", "123", '"a string"', "true"]: + with self.assertRaises(ValueError, msg="{!r} should raise".format(bad)): + _parse_claims_or_raise(bad) + + def test_parse_error_does_not_leak_raw_claims(self): + # A malformed payload that contains a secret-looking value + secret = '{"super": "secret-value-123"' # missing closing brace + with self.assertRaises(ValueError) as ctx: + _parse_claims_or_raise(secret) + self.assertNotIn("secret-value-123", str(ctx.exception), + "Error message must never echo the raw claims content") + + def test_merge_returns_other_when_one_side_is_empty(self): + self.assertEqual({"a": 1}, json.loads(_merge_claims(None, '{"a": 1}'))) + self.assertEqual({"a": 1}, json.loads(_merge_claims('{"a": 1}', None))) + self.assertEqual({"a": 1}, json.loads(_merge_claims("", '{"a": 1}'))) + self.assertEqual({"a": 1}, json.loads(_merge_claims('{"a": 1}', ""))) + + def test_merge_of_two_empties_is_falsy(self): + self.assertFalse(_merge_claims(None, None)) + self.assertFalse(_merge_claims("", "")) + + def test_merge_deep_merges_objects(self): + merged = json.loads(_merge_claims( + '{"access_token": {"xms_cc": {"values": ["cp1"]}}}', + '{"access_token": {"xms_az_nwperimid": {"essential": true}}}')) + self.assertEqual( + {"xms_cc": {"values": ["cp1"]}, "xms_az_nwperimid": {"essential": True}}, + merged["access_token"]) + + class TestExtCacheKeyIsolation(unittest.TestCase): """Tests that ext_cache_key provides proper cache isolation in TokenCache.""" From d9aa21bc78ca5df3ff143b6fe5eb15553b401a20 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 26 Jun 2026 14:56:44 -0400 Subject: [PATCH 2/9] Extend client_claims to all confidential client flows (OBO, FIC, auth code, silent) Phase 1 added client_claims to acquire_token_for_client. This extends it to the remaining confidential client flows so client-originated claims are forwarded and cache-isolated consistently, mirroring msal-dotnet PR 5999's WithClaimsFromClient (which applies to all confidential client builders): - acquire_token_on_behalf_of (OBO) - acquire_token_by_user_federated_identity_credential (FIC) - acquire_token_by_authorization_code - acquire_token_silent / acquire_token_silent_with_error (cache-read isolation plus refresh-token request merge) A shared _stash_client_claims() helper validates the value and stashes it into the request data, so it (a) contributes to the extended cache key and (b) is merged into the OAuth "claims" parameter while being stripped from the wire body. Adds unit tests for each flow (wire shape, validation, cache isolation). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- msal/application.py | 108 +++++++++++++--- tests/test_application.py | 259 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 353 insertions(+), 14 deletions(-) diff --git a/msal/application.py b/msal/application.py index e680e3d2..87057d68 100644 --- a/msal/application.py +++ b/msal/application.py @@ -64,6 +64,30 @@ def _merge_claims_challenge_and_capabilities(capabilities, claims_challenge): return json.dumps(claims_dict) +def _stash_client_claims(client_claims, data): + """Validate ``client_claims`` and stash it into the request ``data`` dict. + + ``client_claims`` carries *client-originated* claims (for example a network + security perimeter ``xms_az_nwperimid`` claim). The raw value is stored in + ``data`` so that it (a) contributes to the extended cache key -- isolating + cache entries by claims value -- and (b) is stripped from the request body + by the oauth2 layer (it reaches the wire only after being merged into the + standard OAuth ``claims`` parameter). ``data`` is mutated in place. + + Unlike ``claims_challenge`` (server-issued, which bypasses the cache), + ``client_claims`` tokens are cached and keyed on the claims value. A no-op + when ``client_claims`` is ``None``. + """ + if client_claims is None: + return + if not isinstance(client_claims, str): + raise ValueError( + "client_claims must be a string, got {}".format( + type(client_claims).__name__)) + _parse_claims_or_raise(client_claims) # Fail fast on malformed JSON + data["client_claims"] = client_claims + + def _str2bytes(raw): # A conversion based on duck-typing rather than six.text_type try: @@ -1237,6 +1261,7 @@ def acquire_token_by_authorization_code( # values MUST be identical. nonce=None, claims_challenge=None, + client_claims=None, **kwargs): """The second half of the Authorization Code Grant. @@ -1267,6 +1292,14 @@ def acquire_token_by_authorization_code( in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. + :param str client_claims: + Optional. A JSON string of *client-originated* claims (for example + a network security perimeter ``xms_az_nwperimid`` claim) to include + in the token request. Unlike ``claims_challenge`` (server-issued, + which bypasses the cache), tokens acquired with ``client_claims`` + **are cached** and keyed on the claims value, so use stable, + non-dynamic values. The value is merged into the standard OAuth + ``claims`` request parameter sent on the wire. :return: A dict representing the json response from Microsoft Entra: @@ -1286,14 +1319,18 @@ def acquire_token_by_authorization_code( with warnings.catch_warnings(record=True): telemetry_context = self._build_telemetry_context( self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID) + _data = kwargs.pop("data", {}) + _stash_client_claims(client_claims, _data) response = _clean_up(self.client.obtain_token_by_authorization_code( code, redirect_uri=redirect_uri, scope=self._decorate_scope(scopes), headers=telemetry_context.generate_headers(), data=dict( - kwargs.pop("data", {}), - claims=_merge_claims_challenge_and_capabilities( - self._client_capabilities, claims_challenge)), + _data, + claims=_merge_claims( + _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + _data.get("client_claims"))), nonce=nonce, **kwargs)) if "access_token" in response: @@ -1474,6 +1511,7 @@ def acquire_token_silent( authority=None, # See get_authorization_request_url() force_refresh=False, # type: Optional[boolean] claims_challenge=None, + client_claims=None, auth_scheme=None, **kwargs): """Acquire an access token for given account, without user interaction. @@ -1493,6 +1531,9 @@ def acquire_token_silent( """ if not account: return None # A backward-compatible NO-OP to drop the account=None usage + if client_claims is not None: + kwargs["data"] = kwargs.get("data") or {} + _stash_client_claims(client_claims, kwargs["data"]) result = _clean_up(self._acquire_token_silent_with_error( scopes, account, authority=authority, force_refresh=force_refresh, claims_challenge=claims_challenge, auth_scheme=auth_scheme, **kwargs)) @@ -1505,6 +1546,7 @@ def acquire_token_silent_with_error( authority=None, # See get_authorization_request_url() force_refresh=False, # type: Optional[boolean] claims_challenge=None, + client_claims=None, auth_scheme=None, **kwargs): """Acquire an access token for given account, without user interaction. @@ -1532,6 +1574,12 @@ def acquire_token_silent_with_error( in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. + :param str client_claims: + Optional. A JSON string of *client-originated* claims (for example + a network security perimeter ``xms_az_nwperimid`` claim) to include + when a cached token is missing and a network request is made. Tokens + are **cached** and keyed on the claims value (different values yield + separate cache entries), so use stable, non-dynamic values. :param object auth_scheme: You can provide an ``msal.auth_scheme.PopAuthScheme`` object so that MSAL will get a Proof-of-Possession (POP) token for you. @@ -1547,6 +1595,9 @@ def acquire_token_silent_with_error( """ if not account: return None # A backward-compatible NO-OP to drop the account=None usage + if client_claims is not None: + kwargs["data"] = kwargs.get("data") or {} + _stash_client_claims(client_claims, kwargs["data"]) return _clean_up(self._acquire_token_silent_with_error( scopes, account, authority=authority, force_refresh=force_refresh, claims_challenge=claims_challenge, auth_scheme=auth_scheme, **kwargs)) @@ -1809,6 +1860,9 @@ def _acquire_token_silent_by_finding_specific_refresh_token( telemetry_context = self._build_telemetry_context( self.ACQUIRE_TOKEN_SILENT_ID, correlation_id=correlation_id, refresh_reason=refresh_reason) + # Pop "data" once (rather than per-iteration) so client_claims and any + # other data fields apply consistently across all candidate RTs. + _data = kwargs.pop("data", {}) for entry in sorted( # Since unfit RTs would not be aggressively removed, # we start from newer RTs which are more likely fit. matches, @@ -1832,9 +1886,11 @@ def _acquire_token_silent_by_finding_specific_refresh_token( scope=scopes, headers=headers, data=dict( - kwargs.pop("data", {}), - claims=_merge_claims_challenge_and_capabilities( - self._client_capabilities, claims_challenge)), + _data, + claims=_merge_claims( + _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + _data.get("client_claims"))), **kwargs) telemetry_context.update_telemetry(response) if "error" not in response: @@ -2608,7 +2664,7 @@ def remove_tokens_for_client(self): self.token_cache.remove_at(at) # acquire_token_for_client() obtains no RTs, so we have no RT to remove - def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs): + def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, client_claims=None, **kwargs): """Acquires token using on-behalf-of (OBO) flow. The current app is a middle-tier service which was called with a token @@ -2628,6 +2684,14 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. + :param str client_claims: + Optional. A JSON string of *client-originated* claims (for example + a network security perimeter ``xms_az_nwperimid`` claim) to include + in the token request. Unlike ``claims_challenge`` (server-issued, + which bypasses the cache), tokens acquired with ``client_claims`` + **are cached** and keyed on the claims value, so use stable, + non-dynamic values. The value is merged into the standard OAuth + ``claims`` request parameter sent on the wire. :return: A dict representing the json response from Microsoft Entra: @@ -2636,6 +2700,8 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No """ telemetry_context = self._build_telemetry_context( self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID) + _data = kwargs.pop("data", {}) + _stash_client_claims(client_claims, _data) # The implementation is NOT based on Token Exchange (RFC 8693) response = _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521 user_assertion, @@ -2647,10 +2713,12 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No # so that the calling app could use id_token_claims to implement # their own cache mapping, which is likely needed in web apps. data=dict( - kwargs.pop("data", {}), + _data, requested_token_use="on_behalf_of", - claims=_merge_claims_challenge_and_capabilities( - self._client_capabilities, claims_challenge)), + claims=_merge_claims( + _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + _data.get("client_claims"))), headers=telemetry_context.generate_headers(), # TBD: Expose a login_hint (or ccs_routing_hint) param for web app **kwargs)) @@ -2661,7 +2729,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No def acquire_token_by_user_federated_identity_credential( self, scopes, assertion, username=None, user_object_id=None, - claims_challenge=None, **kwargs): + claims_challenge=None, client_claims=None, **kwargs): """Acquires a user-scoped token using the ``user_fic`` grant type. This method exchanges a federated identity credential (typically an @@ -2684,6 +2752,14 @@ def acquire_token_by_user_federated_identity_credential( in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. + :param str client_claims: + Optional. A JSON string of *client-originated* claims (for example + a network security perimeter ``xms_az_nwperimid`` claim) to include + in the token request. Unlike ``claims_challenge`` (server-issued, + which bypasses the cache), tokens acquired with ``client_claims`` + **are cached** and keyed on the claims value, so use stable, + non-dynamic values. The value is merged into the standard OAuth + ``claims`` request parameter sent on the wire. :return: A dict representing the json response from Microsoft Entra: @@ -2708,6 +2784,8 @@ def acquire_token_by_user_federated_identity_credential( elif user_object_id: headers["X-AnchorMailbox"] = "Oid:{}@{}".format( user_object_id, self.authority.tenant) + _data = kwargs.pop("data", {}) + _stash_client_claims(client_claims, _data) response = _clean_up(self.client.obtain_token_by_user_fic( scope=self._decorate_scope(scopes), assertion=assertion, @@ -2715,9 +2793,11 @@ def acquire_token_by_user_federated_identity_credential( user_object_id=user_object_id, headers=headers, data=dict( - kwargs.pop("data", {}), - claims=_merge_claims_challenge_and_capabilities( - self._client_capabilities, claims_challenge)), + _data, + claims=_merge_claims( + _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + _data.get("client_claims"))), **kwargs)) if "access_token" in response: response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP diff --git a/tests/test_application.py b/tests/test_application.py index 11179858..5c3bd8cc 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1098,6 +1098,265 @@ def test_client_claims_token_does_not_interfere_with_plain_token(self): "A plain request must not return a client_claims-cached token") +def _build_user_token_response( + access_token="user_at", uid="user_oid", utid="my_tenant", + client_id="client_id", refresh_token=None): + """A mock user-token response (AT + id_token + client_info), optionally with + a refresh token, so that an account is created and silent retrieval works.""" + extra = {"id_token": build_id_token( + aud=client_id, oid=uid, tid=utid, preferred_username="user@contoso.com")} + if refresh_token: + extra["refresh_token"] = refresh_token + return json.dumps(build_response( + uid=uid, utid=utid, access_token=access_token, **extra)) + + +_CLIENT_CLAIMS = '{"access_token": {"xms_az_nwperimid": {"essential": true}}}' +_OTHER_CLIENT_CLAIMS = '{"access_token": {"xms_az_nwperimid": {"values": ["other"]}}}' + + +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestAcquireTokenOnBehalfOfWithClientClaims(unittest.TestCase): + """acquire_token_on_behalf_of(client_claims=...) forwards client-originated + claims via the OAuth "claims" parameter and isolates the cached token.""" + + def _build_app(self, **kwargs): + return ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant", **kwargs) + + def test_client_claims_rejects_invalid_values(self): + app = self._build_app() + for bad_value in [123, True, ["claims"], b"bytes", "not json", "null", "[1,2]"]: + with self.assertRaises(ValueError, + msg="client_claims={!r} should raise".format(bad_value)): + app.acquire_token_on_behalf_of( + "assertion", ["s"], client_claims=bad_value) + + def test_client_claims_sent_as_claims_on_the_wire(self): + app = self._build_app() + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + app.acquire_token_on_behalf_of( + "assertion", ["s"], client_claims=_CLIENT_CLAIMS, post=mock_post) + self.assertIn("claims", captured_data) + self.assertEqual( + {"access_token": {"xms_az_nwperimid": {"essential": True}}}, + json.loads(captured_data["claims"])) + self.assertNotIn("client_claims", captured_data, + "client_claims must not be sent in the HTTP request body") + + def test_client_claims_merged_with_client_capabilities(self): + app = self._build_app(client_capabilities=["CP1"]) + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + app.acquire_token_on_behalf_of( + "assertion", ["s"], client_claims=_CLIENT_CLAIMS, post=mock_post) + self.assertEqual( + { + "xms_cc": {"values": ["CP1"]}, + "xms_az_nwperimid": {"essential": True}, + }, + json.loads(captured_data["claims"])["access_token"]) + + def test_cached_token_is_isolated_by_client_claims(self): + app = self._build_app() + app.acquire_token_on_behalf_of( + "assertion", ["s"], client_claims=_CLIENT_CLAIMS, + post=lambda *a, **k: MinimalResponse(status_code=200, + text=_build_user_token_response(access_token="obo_at"))) + accounts = app.get_accounts() + self.assertTrue(accounts, "OBO response should create an account") + hit = app.acquire_token_silent( + ["s"], accounts[0], client_claims=_CLIENT_CLAIMS) + self.assertIsNotNone(hit) + self.assertEqual("obo_at", hit["access_token"]) + self.assertEqual(hit[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE) + self.assertIsNone( + app.acquire_token_silent( + ["s"], accounts[0], client_claims=_OTHER_CLIENT_CLAIMS), + "Different client_claims must not read the cached token") + self.assertIsNone( + app.acquire_token_silent(["s"], accounts[0]), + "A plain silent call must not read a client_claims token") + + +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestAcquireTokenByAuthorizationCodeWithClientClaims(unittest.TestCase): + """acquire_token_by_authorization_code(client_claims=...) forwards + client-originated claims and isolates the cached token.""" + + def _build_app(self, **kwargs): + return ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant", **kwargs) + + def test_client_claims_rejects_invalid_values(self): + app = self._build_app() + for bad_value in [123, ["claims"], b"bytes", "not json", "null"]: + with self.assertRaises(ValueError, + msg="client_claims={!r} should raise".format(bad_value)): + app.acquire_token_by_authorization_code( + "code", ["s"], client_claims=bad_value) + + def test_client_claims_sent_as_claims_on_the_wire(self): + app = self._build_app() + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + app.acquire_token_by_authorization_code( + "code", ["s"], client_claims=_CLIENT_CLAIMS, post=mock_post) + self.assertIn("claims", captured_data) + self.assertEqual( + {"access_token": {"xms_az_nwperimid": {"essential": True}}}, + json.loads(captured_data["claims"])) + self.assertNotIn("client_claims", captured_data, + "client_claims must not be sent in the HTTP request body") + + def test_cached_token_is_isolated_by_client_claims(self): + app = self._build_app() + app.acquire_token_by_authorization_code( + "code", ["s"], client_claims=_CLIENT_CLAIMS, + post=lambda *a, **k: MinimalResponse(status_code=200, + text=_build_user_token_response(access_token="authcode_at"))) + accounts = app.get_accounts() + self.assertTrue(accounts) + hit = app.acquire_token_silent( + ["s"], accounts[0], client_claims=_CLIENT_CLAIMS) + self.assertIsNotNone(hit) + self.assertEqual("authcode_at", hit["access_token"]) + self.assertEqual(hit[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE) + self.assertIsNone( + app.acquire_token_silent(["s"], accounts[0]), + "A plain silent call must not read a client_claims token") + + +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestUserFicWithClientClaims(unittest.TestCase): + """acquire_token_by_user_federated_identity_credential(client_claims=...) + forwards client-originated claims and isolates the cached token.""" + + def _build_app(self, **kwargs): + return ConfidentialClientApplication( + "agent_app_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant", **kwargs) + + def test_client_claims_rejects_invalid_values(self): + app = self._build_app() + for bad_value in [123, ["claims"], b"bytes", "not json", "null"]: + with self.assertRaises(ValueError, + msg="client_claims={!r} should raise".format(bad_value)): + app.acquire_token_by_user_federated_identity_credential( + ["s"], assertion="t2", username="user@contoso.com", + client_claims=bad_value) + + def test_client_claims_sent_as_claims_on_the_wire(self): + app = self._build_app() + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=_build_user_token_response( + client_id="agent_app_id")) + + app.acquire_token_by_user_federated_identity_credential( + ["s"], assertion="t2", username="user@contoso.com", + client_claims=_CLIENT_CLAIMS, post=mock_post) + self.assertIn("claims", captured_data) + self.assertEqual( + {"access_token": {"xms_az_nwperimid": {"essential": True}}}, + json.loads(captured_data["claims"])) + self.assertNotIn("client_claims", captured_data, + "client_claims must not be sent in the HTTP request body") + + def test_cached_token_is_isolated_by_client_claims(self): + app = self._build_app() + app.acquire_token_by_user_federated_identity_credential( + ["s"], assertion="t2", username="user@contoso.com", + client_claims=_CLIENT_CLAIMS, + post=lambda *a, **k: MinimalResponse(status_code=200, + text=_build_user_token_response( + access_token="fic_at", client_id="agent_app_id"))) + accounts = app.get_accounts() + self.assertTrue(accounts) + hit = app.acquire_token_silent( + ["s"], accounts[0], client_claims=_CLIENT_CLAIMS) + self.assertIsNotNone(hit) + self.assertEqual("fic_at", hit["access_token"]) + self.assertEqual(hit[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE) + self.assertIsNone( + app.acquire_token_silent(["s"], accounts[0]), + "A plain silent call must not read a client_claims token") + + +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestAcquireTokenSilentWithClientClaims(unittest.TestCase): + """acquire_token_silent(client_claims=...) isolates cache reads and merges + the claims into the refresh-token request sent on the wire.""" + + def _build_app(self, **kwargs): + return ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant", **kwargs) + + def _seed_account_with_rt(self, app): + app.acquire_token_on_behalf_of( + "assertion", ["s"], + post=lambda *a, **k: MinimalResponse(status_code=200, + text=_build_user_token_response( + access_token="seed_at", refresh_token="seed_rt"))) + accounts = app.get_accounts() + self.assertTrue(accounts) + return accounts[0] + + def test_client_claims_merged_into_refresh_request(self): + app = self._build_app() + account = self._seed_account_with_rt(app) + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=_build_user_token_response( + access_token="refreshed_at", refresh_token="seed_rt")) + + result = app.acquire_token_silent( + ["s"], account, force_refresh=True, + client_claims=_CLIENT_CLAIMS, post=mock_post) + self.assertIsNotNone(result) + self.assertEqual("refresh_token", captured_data.get("grant_type")) + self.assertIn("claims", captured_data) + self.assertEqual( + {"access_token": {"xms_az_nwperimid": {"essential": True}}}, + json.loads(captured_data["claims"])) + self.assertNotIn("client_claims", captured_data, + "client_claims must not leak onto the refresh request body") + + def test_both_silent_entry_points_validate_client_claims(self): + app = self._build_app() + account = self._seed_account_with_rt(app) + for bad_value in [123, ["claims"], "not json", "null"]: + with self.assertRaises(ValueError): + app.acquire_token_silent( + ["s"], account, client_claims=bad_value) + with self.assertRaises(ValueError): + app.acquire_token_silent_with_error( + ["s"], account, client_claims=bad_value) + + @patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) class TestRemoveTokensForClient(unittest.TestCase): def test_remove_tokens_for_client_should_remove_client_tokens_only(self): From dd8351ac649c2427b897fce085d265ffe5cdc350 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 26 Jun 2026 15:30:49 -0400 Subject: [PATCH 3/9] Address Copilot review: harden client_claims validation and docs - token_cache._parse_claims_or_raise now also catches TypeError (raised when the input is not a str/bytes) and surfaces the same friendly ValueError, so every caller behaves consistently regardless of the bad input's type. - ManagedIdentityClient.acquire_token_for_client now rejects non-string client_claims with a ValueError (mirroring the confidential-client flows), preventing a raw TypeError leak and inconsistent extended-cache-key hashing. - ConfidentialClientApplication.acquire_token_for_client now reuses the shared _stash_client_claims() helper instead of duplicating the validate-and-stash logic, removing the risk of the two paths diverging. - Add cross-referencing docstring notes disambiguating the per-request client_claims (a JSON string forwarded in the request) from the pre-existing constructor client_claims (a dict of claims signed into the client-assertion JWT). - Add unit tests for non-string client_claims on managed identity and for non-string inputs to _parse_claims_or_raise. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- msal/application.py | 36 ++++++++++++++++++++++++++++++------ msal/managed_identity.py | 4 ++++ msal/token_cache.py | 6 +++++- tests/test_mi.py | 9 +++++++++ tests/test_token_cache.py | 9 +++++++++ 5 files changed, 57 insertions(+), 7 deletions(-) diff --git a/msal/application.py b/msal/application.py index 87057d68..2def085d 100644 --- a/msal/application.py +++ b/msal/application.py @@ -448,6 +448,15 @@ def get_client_assertion(): "jti": a_random_uuid } + .. note:: + + This *constructor* ``client_claims`` (a ``dict`` signed into the + client-assertion JWT) is distinct from the per-request + ``client_claims`` parameter (a JSON string of client-originated + claims forwarded in the token request) accepted by + ``acquire_token_for_client``, ``acquire_token_on_behalf_of``, + ``acquire_token_silent``, and the other token-acquisition methods. + :param str authority: A URL that identifies a token authority. It should be of the format ``https://login.microsoftonline.com/your_tenant`` @@ -1301,6 +1310,10 @@ def acquire_token_by_authorization_code( non-dynamic values. The value is merged into the standard OAuth ``claims`` request parameter sent on the wire. + This per-request ``client_claims`` (a JSON string) is distinct from + the ``client_claims`` *constructor* parameter, which is a ``dict`` + of extra claims signed into the client-assertion JWT. + :return: A dict representing the json response from Microsoft Entra: - A successful response would contain "access_token" key, @@ -1580,6 +1593,10 @@ def acquire_token_silent_with_error( when a cached token is missing and a network request is made. Tokens are **cached** and keyed on the claims value (different values yield separate cache entries), so use stable, non-dynamic values. + + This per-request ``client_claims`` (a JSON string) is distinct from + the ``client_claims`` *constructor* parameter, which is a ``dict`` + of extra claims signed into the client-assertion JWT. :param object auth_scheme: You can provide an ``msal.auth_scheme.PopAuthScheme`` object so that MSAL will get a Proof-of-Possession (POP) token for you. @@ -2586,6 +2603,10 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, cache entries, so use stable, non-dynamic values to avoid unbounded cache growth. The value is merged into the standard OAuth ``claims`` request parameter sent on the wire. + + This per-request ``client_claims`` (a JSON string) is distinct from + the ``client_claims`` *constructor* parameter, which is a ``dict`` + of extra claims signed into the client-assertion JWT. :return: A dict representing the json response from Microsoft Entra: - A successful response would contain "access_token" key, @@ -2602,17 +2623,12 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, kwargs["data"] = kwargs.get("data", {}) kwargs["data"]["fmi_path"] = fmi_path if client_claims is not None: - if not isinstance(client_claims, str): - raise ValueError( - "client_claims must be a string, got {}".format( - type(client_claims).__name__)) - _parse_claims_or_raise(client_claims) # Fail fast on malformed JSON # Carry it in the request data so it contributes to the extended # cache key (different claims => separate cache entries). It is # merged into the "claims" body parameter in _acquire_token_for_client # and stripped from the wire body by the oauth2 layer. kwargs["data"] = kwargs.get("data", {}) - kwargs["data"]["client_claims"] = client_claims + _stash_client_claims(client_claims, kwargs["data"]) return _clean_up(self._acquire_token_silent_with_error( scopes, None, claims_challenge=claims_challenge, **kwargs)) @@ -2693,6 +2709,10 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No non-dynamic values. The value is merged into the standard OAuth ``claims`` request parameter sent on the wire. + This per-request ``client_claims`` (a JSON string) is distinct from + the ``client_claims`` *constructor* parameter, which is a ``dict`` + of extra claims signed into the client-assertion JWT. + :return: A dict representing the json response from Microsoft Entra: - A successful response would contain "access_token" key, @@ -2761,6 +2781,10 @@ def acquire_token_by_user_federated_identity_credential( non-dynamic values. The value is merged into the standard OAuth ``claims`` request parameter sent on the wire. + This per-request ``client_claims`` (a JSON string) is distinct from + the ``client_claims`` *constructor* parameter, which is a ``dict`` + of extra claims signed into the client-assertion JWT. + :return: A dict representing the json response from Microsoft Entra: - A successful response would contain "access_token" key, diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 479423bb..bc452a0d 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -320,6 +320,10 @@ def acquire_token_for_client( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") now = time.time() if client_claims is not None: + if not isinstance(client_claims, str): + raise ValueError( + "client_claims must be a string, got {}".format( + type(client_claims).__name__)) _parse_claims_or_raise(client_claims) # Fail fast on malformed JSON # Client-originated claims isolate the cache: a distinct claims value gets # a distinct cache entry. (Server-issued claims_challenge, by contrast, diff --git a/msal/token_cache.py b/msal/token_cache.py index a19db6a9..7bc21a43 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -110,7 +110,11 @@ def _parse_claims_or_raise(claims): """ try: parsed = json.loads(claims) - except ValueError as ex: # json.JSONDecodeError is a subclass of ValueError + except (ValueError, TypeError) as ex: + # json.JSONDecodeError (malformed JSON) is a subclass of ValueError; + # TypeError is raised when *claims* is not a str/bytes/bytearray. Both + # are surfaced as the same friendly ValueError so every caller behaves + # consistently regardless of the bad input's type. raise ValueError( "The claims value is not valid JSON. " "See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter." diff --git a/tests/test_mi.py b/tests/test_mi.py index 8d53493d..d1f5c572 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -299,6 +299,15 @@ def test_invalid_json_claims_raises(self): with self.assertRaises(ValueError, msg="{!r} should raise".format(bad)): self.app.acquire_token_for_client(resource="R", client_claims=bad) + def test_non_string_client_claims_raises_value_error(self): + # A non-str client_claims (int, bytes, dict, ...) must raise a friendly + # ValueError rather than leaking a raw TypeError from json.loads or + # hashing inconsistently into the extended cache key. + for bad in [123, b'{"xms_az_nwperimid": {}}', {"xms_az_nwperimid": {}}]: + with self.assertRaises( + ValueError, msg="{!r} should raise ValueError".format(bad)): + self.app.acquire_token_for_client(resource="R", client_claims=bad) + def test_same_client_claims_hits_cache(self): with self._mock_get("AT1") as mock_get: r1 = self.app.acquire_token_for_client( diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 61a5206c..149d0aba 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -404,6 +404,15 @@ def test_parse_rejects_non_object_and_malformed(self): with self.assertRaises(ValueError, msg="{!r} should raise".format(bad)): _parse_claims_or_raise(bad) + def test_parse_rejects_non_string_types(self): + # Non-str/bytes inputs make json.loads raise TypeError; the helper must + # surface the same friendly ValueError so callers behave consistently + # regardless of the bad input's type. + for bad in [123, None, 1.5, True, ["a"], {"a": 1}]: + with self.assertRaises( + ValueError, msg="{!r} should raise ValueError".format(bad)): + _parse_claims_or_raise(bad) + def test_parse_error_does_not_leak_raw_claims(self): # A malformed payload that contains a secret-looking value secret = '{"super": "secret-value-123"' # missing closing brace From 6b8fac0622f07d730e3d950a884b6e41f07b32a2 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Tue, 30 Jun 2026 17:47:38 -0400 Subject: [PATCH 4/9] Rename per-request param to forwarded_client_claims; add merge test Address reviewer feedback on PR #937: - Rename the new per-request parameter `client_claims` -> `forwarded_client_claims` across all confidential-client flows (acquire_token_for_client, on_behalf_of, user FIC, auth code, silent) and the Managed Identity acquire_token_for_client. This removes the naming collision with the pre-existing `client_claims` *constructor* parameter (a dict signed into the client-assertion JWT), which a second reviewer also flagged as confusing. The public keyword is the only thing renamed. The internal request-data key "client_claims" (used by the oauth2 wire-strip, _compute_ext_cache_key cache isolation, and _merge_claims) and the private Managed Identity plumbing keep their existing names, so cache keying and wire behavior are unchanged. - Add test_forwarded_client_claims_merged_with_claims_challenge, covering the previously untested three-way merge of server-issued claims_challenge + client capabilities + forwarded_client_claims into the single OAuth "claims" wire parameter. 200 passed across test_token_cache.py, test_mi.py, test_application.py. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- msal/application.py | 112 ++++++++++++++++++-------------------- msal/managed_identity.py | 28 +++++----- tests/test_application.py | 101 +++++++++++++++++++++------------- tests/test_mi.py | 28 +++++----- 4 files changed, 146 insertions(+), 123 deletions(-) diff --git a/msal/application.py b/msal/application.py index 2def085d..74dd4dc8 100644 --- a/msal/application.py +++ b/msal/application.py @@ -64,28 +64,29 @@ def _merge_claims_challenge_and_capabilities(capabilities, claims_challenge): return json.dumps(claims_dict) -def _stash_client_claims(client_claims, data): - """Validate ``client_claims`` and stash it into the request ``data`` dict. +def _stash_client_claims(forwarded_client_claims, data): + """Validate ``forwarded_client_claims`` and stash it into the request ``data``. - ``client_claims`` carries *client-originated* claims (for example a network - security perimeter ``xms_az_nwperimid`` claim). The raw value is stored in - ``data`` so that it (a) contributes to the extended cache key -- isolating - cache entries by claims value -- and (b) is stripped from the request body - by the oauth2 layer (it reaches the wire only after being merged into the - standard OAuth ``claims`` parameter). ``data`` is mutated in place. + ``forwarded_client_claims`` carries *client-originated* claims (for example a + network security perimeter ``xms_az_nwperimid`` claim). The raw value is + stored in ``data`` (under the internal ``client_claims`` key) so that it + (a) contributes to the extended cache key -- isolating cache entries by + claims value -- and (b) is stripped from the request body by the oauth2 + layer (it reaches the wire only after being merged into the standard OAuth + ``claims`` parameter). ``data`` is mutated in place. Unlike ``claims_challenge`` (server-issued, which bypasses the cache), - ``client_claims`` tokens are cached and keyed on the claims value. A no-op - when ``client_claims`` is ``None``. + ``forwarded_client_claims`` tokens are cached and keyed on the claims value. + A no-op when ``forwarded_client_claims`` is ``None``. """ - if client_claims is None: + if forwarded_client_claims is None: return - if not isinstance(client_claims, str): + if not isinstance(forwarded_client_claims, str): raise ValueError( - "client_claims must be a string, got {}".format( - type(client_claims).__name__)) - _parse_claims_or_raise(client_claims) # Fail fast on malformed JSON - data["client_claims"] = client_claims + "forwarded_client_claims must be a string, got {}".format( + type(forwarded_client_claims).__name__)) + _parse_claims_or_raise(forwarded_client_claims) # Fail fast on malformed JSON + data["client_claims"] = forwarded_client_claims def _str2bytes(raw): @@ -452,7 +453,7 @@ def get_client_assertion(): This *constructor* ``client_claims`` (a ``dict`` signed into the client-assertion JWT) is distinct from the per-request - ``client_claims`` parameter (a JSON string of client-originated + ``forwarded_client_claims`` parameter (a JSON string of client-originated claims forwarded in the token request) accepted by ``acquire_token_for_client``, ``acquire_token_on_behalf_of``, ``acquire_token_silent``, and the other token-acquisition methods. @@ -1270,7 +1271,7 @@ def acquire_token_by_authorization_code( # values MUST be identical. nonce=None, claims_challenge=None, - client_claims=None, + forwarded_client_claims=None, **kwargs): """The second half of the Authorization Code Grant. @@ -1301,18 +1302,17 @@ def acquire_token_by_authorization_code( in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. - :param str client_claims: + :param str forwarded_client_claims: Optional. A JSON string of *client-originated* claims (for example a network security perimeter ``xms_az_nwperimid`` claim) to include in the token request. Unlike ``claims_challenge`` (server-issued, - which bypasses the cache), tokens acquired with ``client_claims`` + which bypasses the cache), tokens acquired with ``forwarded_client_claims`` **are cached** and keyed on the claims value, so use stable, non-dynamic values. The value is merged into the standard OAuth ``claims`` request parameter sent on the wire. - This per-request ``client_claims`` (a JSON string) is distinct from - the ``client_claims`` *constructor* parameter, which is a ``dict`` - of extra claims signed into the client-assertion JWT. + Not to be confused with the constructor ``client_claims`` parameter + (a ``dict`` of extra claims signed into the client-assertion JWT). :return: A dict representing the json response from Microsoft Entra: @@ -1333,7 +1333,7 @@ def acquire_token_by_authorization_code( telemetry_context = self._build_telemetry_context( self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID) _data = kwargs.pop("data", {}) - _stash_client_claims(client_claims, _data) + _stash_client_claims(forwarded_client_claims, _data) response = _clean_up(self.client.obtain_token_by_authorization_code( code, redirect_uri=redirect_uri, scope=self._decorate_scope(scopes), @@ -1524,7 +1524,7 @@ def acquire_token_silent( authority=None, # See get_authorization_request_url() force_refresh=False, # type: Optional[boolean] claims_challenge=None, - client_claims=None, + forwarded_client_claims=None, auth_scheme=None, **kwargs): """Acquire an access token for given account, without user interaction. @@ -1544,9 +1544,9 @@ def acquire_token_silent( """ if not account: return None # A backward-compatible NO-OP to drop the account=None usage - if client_claims is not None: + if forwarded_client_claims is not None: kwargs["data"] = kwargs.get("data") or {} - _stash_client_claims(client_claims, kwargs["data"]) + _stash_client_claims(forwarded_client_claims, kwargs["data"]) result = _clean_up(self._acquire_token_silent_with_error( scopes, account, authority=authority, force_refresh=force_refresh, claims_challenge=claims_challenge, auth_scheme=auth_scheme, **kwargs)) @@ -1559,7 +1559,7 @@ def acquire_token_silent_with_error( authority=None, # See get_authorization_request_url() force_refresh=False, # type: Optional[boolean] claims_challenge=None, - client_claims=None, + forwarded_client_claims=None, auth_scheme=None, **kwargs): """Acquire an access token for given account, without user interaction. @@ -1587,16 +1587,15 @@ def acquire_token_silent_with_error( in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. - :param str client_claims: + :param str forwarded_client_claims: Optional. A JSON string of *client-originated* claims (for example a network security perimeter ``xms_az_nwperimid`` claim) to include when a cached token is missing and a network request is made. Tokens are **cached** and keyed on the claims value (different values yield separate cache entries), so use stable, non-dynamic values. - This per-request ``client_claims`` (a JSON string) is distinct from - the ``client_claims`` *constructor* parameter, which is a ``dict`` - of extra claims signed into the client-assertion JWT. + Not to be confused with the constructor ``client_claims`` parameter + (a ``dict`` of extra claims signed into the client-assertion JWT). :param object auth_scheme: You can provide an ``msal.auth_scheme.PopAuthScheme`` object so that MSAL will get a Proof-of-Possession (POP) token for you. @@ -1612,9 +1611,9 @@ def acquire_token_silent_with_error( """ if not account: return None # A backward-compatible NO-OP to drop the account=None usage - if client_claims is not None: + if forwarded_client_claims is not None: kwargs["data"] = kwargs.get("data") or {} - _stash_client_claims(client_claims, kwargs["data"]) + _stash_client_claims(forwarded_client_claims, kwargs["data"]) return _clean_up(self._acquire_token_silent_with_error( scopes, account, authority=authority, force_refresh=force_refresh, claims_challenge=claims_challenge, auth_scheme=auth_scheme, **kwargs)) @@ -2567,7 +2566,7 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app except that ``allow_broker`` parameter shall remain ``None``. """ - def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, client_claims=None, **kwargs): + def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, forwarded_client_claims=None, **kwargs): """Acquires token for the current confidential client, not for an end user. Since MSAL Python 1.23, it will automatically look for token from cache, @@ -2591,22 +2590,21 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, scopes=["api://resource/.default"], fmi_path="SomeFmiPath/FmiCredentialPath", ) - :param str client_claims: + :param str forwarded_client_claims: Optional. A JSON string containing *client-originated* claims to include in the token request (for example a network security perimeter ``xms_az_nwperimid`` claim). Unlike ``claims_challenge`` (which carries *server-issued* claims challenges and bypasses the cache), tokens acquired with - ``client_claims`` **are cached**, and the cache entry is keyed on the - claims value. Different ``client_claims`` values produce separate + ``forwarded_client_claims`` **are cached**, and the cache entry is keyed on the + claims value. Different ``forwarded_client_claims`` values produce separate cache entries, so use stable, non-dynamic values to avoid unbounded cache growth. The value is merged into the standard OAuth ``claims`` request parameter sent on the wire. - This per-request ``client_claims`` (a JSON string) is distinct from - the ``client_claims`` *constructor* parameter, which is a ``dict`` - of extra claims signed into the client-assertion JWT. + Not to be confused with the constructor ``client_claims`` parameter + (a ``dict`` of extra claims signed into the client-assertion JWT). :return: A dict representing the json response from Microsoft Entra: - A successful response would contain "access_token" key, @@ -2622,13 +2620,13 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, "fmi_path must be a string, got {}".format(type(fmi_path).__name__)) kwargs["data"] = kwargs.get("data", {}) kwargs["data"]["fmi_path"] = fmi_path - if client_claims is not None: + if forwarded_client_claims is not None: # Carry it in the request data so it contributes to the extended # cache key (different claims => separate cache entries). It is # merged into the "claims" body parameter in _acquire_token_for_client # and stripped from the wire body by the oauth2 layer. kwargs["data"] = kwargs.get("data", {}) - _stash_client_claims(client_claims, kwargs["data"]) + _stash_client_claims(forwarded_client_claims, kwargs["data"]) return _clean_up(self._acquire_token_silent_with_error( scopes, None, claims_challenge=claims_challenge, **kwargs)) @@ -2651,7 +2649,7 @@ def _acquire_token_for_client( request_data = kwargs.pop("data", {}) claims = _merge_claims_challenge_and_capabilities( self._client_capabilities, claims_challenge) - # Client-originated claims (set via client_claims=) are merged into the + # Client-originated claims (set via forwarded_client_claims=) are merged into the # same OAuth "claims" parameter and sent on the wire. The raw # "client_claims" entry stays in request_data so it keys the cache; the # oauth2 layer drops it from the actual request body. @@ -2680,7 +2678,7 @@ def remove_tokens_for_client(self): self.token_cache.remove_at(at) # acquire_token_for_client() obtains no RTs, so we have no RT to remove - def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, client_claims=None, **kwargs): + def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, forwarded_client_claims=None, **kwargs): """Acquires token using on-behalf-of (OBO) flow. The current app is a middle-tier service which was called with a token @@ -2700,18 +2698,17 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. - :param str client_claims: + :param str forwarded_client_claims: Optional. A JSON string of *client-originated* claims (for example a network security perimeter ``xms_az_nwperimid`` claim) to include in the token request. Unlike ``claims_challenge`` (server-issued, - which bypasses the cache), tokens acquired with ``client_claims`` + which bypasses the cache), tokens acquired with ``forwarded_client_claims`` **are cached** and keyed on the claims value, so use stable, non-dynamic values. The value is merged into the standard OAuth ``claims`` request parameter sent on the wire. - This per-request ``client_claims`` (a JSON string) is distinct from - the ``client_claims`` *constructor* parameter, which is a ``dict`` - of extra claims signed into the client-assertion JWT. + Not to be confused with the constructor ``client_claims`` parameter + (a ``dict`` of extra claims signed into the client-assertion JWT). :return: A dict representing the json response from Microsoft Entra: @@ -2721,7 +2718,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No telemetry_context = self._build_telemetry_context( self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID) _data = kwargs.pop("data", {}) - _stash_client_claims(client_claims, _data) + _stash_client_claims(forwarded_client_claims, _data) # The implementation is NOT based on Token Exchange (RFC 8693) response = _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521 user_assertion, @@ -2749,7 +2746,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No def acquire_token_by_user_federated_identity_credential( self, scopes, assertion, username=None, user_object_id=None, - claims_challenge=None, client_claims=None, **kwargs): + claims_challenge=None, forwarded_client_claims=None, **kwargs): """Acquires a user-scoped token using the ``user_fic`` grant type. This method exchanges a federated identity credential (typically an @@ -2772,18 +2769,17 @@ def acquire_token_by_user_federated_identity_credential( in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. - :param str client_claims: + :param str forwarded_client_claims: Optional. A JSON string of *client-originated* claims (for example a network security perimeter ``xms_az_nwperimid`` claim) to include in the token request. Unlike ``claims_challenge`` (server-issued, - which bypasses the cache), tokens acquired with ``client_claims`` + which bypasses the cache), tokens acquired with ``forwarded_client_claims`` **are cached** and keyed on the claims value, so use stable, non-dynamic values. The value is merged into the standard OAuth ``claims`` request parameter sent on the wire. - This per-request ``client_claims`` (a JSON string) is distinct from - the ``client_claims`` *constructor* parameter, which is a ``dict`` - of extra claims signed into the client-assertion JWT. + Not to be confused with the constructor ``client_claims`` parameter + (a ``dict`` of extra claims signed into the client-assertion JWT). :return: A dict representing the json response from Microsoft Entra: @@ -2809,7 +2805,7 @@ def acquire_token_by_user_federated_identity_credential( headers["X-AnchorMailbox"] = "Oid:{}@{}".format( user_object_id, self.authority.tenant) _data = kwargs.pop("data", {}) - _stash_client_claims(client_claims, _data) + _stash_client_claims(forwarded_client_claims, _data) response = _clean_up(self.client.obtain_token_by_user_fic( scope=self._decorate_scope(scopes), assertion=assertion, diff --git a/msal/managed_identity.py b/msal/managed_identity.py index bc452a0d..e4e88b77 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -29,7 +29,7 @@ class ManagedIdentityError(ValueError): _XMS_AZ_NWPERIMID = "xms_az_nwperimid" _CLIENT_CLAIMS_UNSUPPORTED_SOURCE = ( - "client_claims is only supported for the IMDS (Azure VM) managed identity " + "forwarded_client_claims is only supported for the IMDS (Azure VM) managed identity " "source. The detected source ({source}) does not support forwarding " "client-originated claims.") @@ -269,7 +269,7 @@ def acquire_token_for_client( *, resource: str, # If/when we support scope, resource will become optional claims_challenge: Optional[str] = None, - client_claims: Optional[str] = None, + forwarded_client_claims: Optional[str] = None, ): """Acquire token for the managed identity. @@ -289,15 +289,15 @@ def acquire_token_for_client( even if the app developer did not opt in for the "CP1" client capability. Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token. - :param client_claims: + :param forwarded_client_claims: Optional. A string representation of a JSON object containing *client-originated* claims to forward to the identity endpoint (for example a network security perimeter ``xms_az_nwperimid`` claim). Unlike ``claims_challenge`` (server-issued, which bypasses the cache), - tokens acquired with ``client_claims`` **are cached**, and the cache - entry is keyed on the claims value. Different ``client_claims`` values + tokens acquired with ``forwarded_client_claims`` **are cached**, and the cache + entry is keyed on the claims value. Different ``forwarded_client_claims`` values produce separate cache entries, so use stable, non-dynamic values to avoid unbounded cache growth. @@ -319,17 +319,17 @@ def acquire_token_for_client( client_id_in_cache = self._managed_identity.get( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") now = time.time() - if client_claims is not None: - if not isinstance(client_claims, str): + if forwarded_client_claims is not None: + if not isinstance(forwarded_client_claims, str): raise ValueError( - "client_claims must be a string, got {}".format( - type(client_claims).__name__)) - _parse_claims_or_raise(client_claims) # Fail fast on malformed JSON + "forwarded_client_claims must be a string, got {}".format( + type(forwarded_client_claims).__name__)) + _parse_claims_or_raise(forwarded_client_claims) # Fail fast on malformed JSON # Client-originated claims isolate the cache: a distinct claims value gets # a distinct cache entry. (Server-issued claims_challenge, by contrast, # bypasses the cache and is keyed normally.) ext_cache_key = _compute_ext_cache_key( - {"client_claims": client_claims}) if client_claims else None + {"client_claims": forwarded_client_claims}) if forwarded_client_claims else None if True: # Attempt cache search even if receiving claims_challenge, # because we want to locate the existing token (if any) and refresh it matches = self._token_cache.search( @@ -371,7 +371,7 @@ def acquire_token_for_client( access_token_to_refresh.encode("utf-8")).hexdigest() if access_token_to_refresh else None, client_capabilities=self._client_capabilities, - client_claims=client_claims, + client_claims=forwarded_client_claims, ) if "access_token" in result: expires_in = result.get("expires_in", 3600) @@ -384,7 +384,7 @@ def acquire_token_for_client( self.__instance, self._tenant), response=result, params={}, - data={"client_claims": client_claims} if client_claims else {}, + data={"client_claims": forwarded_client_claims} if forwarded_client_claims else {}, )) if "refresh_in" in result: result["refresh_on"] = int(now + result["refresh_in"]) @@ -534,7 +534,7 @@ def _validate_msiv1_claims(client_claims): raise ManagedIdentityError( "MSIv1 (IMDS v1) only supports the `{expected}` custom claim. " "The claims JSON contained the unsupported key `{actual}`. " - "Remove all keys other than `{expected}` when using client_claims " + "Remove all keys other than `{expected}` when using forwarded_client_claims " "with MSIv1.".format(expected=_XMS_AZ_NWPERIMID, actual=key)) diff --git a/tests/test_application.py b/tests/test_application.py index 5c3bd8cc..053d9ef0 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -968,7 +968,7 @@ def test_fmi_token_does_not_interfere_with_non_fmi_token(self): @patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) class TestAcquireTokenForClientWithClientClaims(unittest.TestCase): - """acquire_token_for_client(client_claims=...) forwards client-originated claims + """acquire_token_for_client(forwarded_client_claims=...) forwards client-originated claims via the OAuth "claims" body parameter, caches the result, and keys the cache entry on the claims value.""" @@ -984,15 +984,15 @@ def test_client_claims_rejects_non_string_types(self): app = self._build_app() for bad_value in [123, True, ["claims"], {"a": "b"}, b"bytes"]: with self.assertRaises(ValueError, - msg="client_claims={!r} should raise".format(bad_value)): - app.acquire_token_for_client(["scope"], client_claims=bad_value) + msg="forwarded_client_claims={!r} should raise".format(bad_value)): + app.acquire_token_for_client(["scope"], forwarded_client_claims=bad_value) def test_client_claims_rejects_invalid_json(self): app = self._build_app() for bad_value in ["not json", "[1, 2]", "null", "123"]: with self.assertRaises(ValueError, - msg="client_claims={!r} should raise".format(bad_value)): - app.acquire_token_for_client(["scope"], client_claims=bad_value) + msg="forwarded_client_claims={!r} should raise".format(bad_value)): + app.acquire_token_for_client(["scope"], forwarded_client_claims=bad_value) def test_client_claims_sent_as_claims_on_the_wire(self): app = self._build_app() @@ -1004,7 +1004,7 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): "access_token": "an AT", "expires_in": 3600})) result = app.acquire_token_for_client( - ["scope"], client_claims=self._CLIENT_CLAIMS, post=mock_post) + ["scope"], forwarded_client_claims=self._CLIENT_CLAIMS, post=mock_post) self.assertIn("access_token", result) # The client claims are forwarded via the standard OAuth "claims" parameter self.assertIn("claims", captured_data) @@ -1025,7 +1025,7 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): "access_token": "an AT", "expires_in": 3600})) app.acquire_token_for_client( - ["scope"], client_claims=self._CLIENT_CLAIMS, post=mock_post) + ["scope"], forwarded_client_claims=self._CLIENT_CLAIMS, post=mock_post) merged = json.loads(captured_data["claims"]) self.assertEqual( { @@ -1036,6 +1036,33 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): "client_claims must merge with capability-derived claims") self.assertNotIn("client_claims", captured_data) + def test_forwarded_client_claims_merged_with_claims_challenge(self): + # All three claim sources -- the server-issued claims_challenge, client + # capabilities, and forwarded_client_claims -- must combine into the + # single OAuth "claims" parameter that is sent on the wire. + app = self._build_app(client_capabilities=["CP1"]) + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + challenge = '{"access_token": {"nbf": {"essential": true, "value": "1601000000"}}}' + app.acquire_token_for_client( + ["scope"], claims_challenge=challenge, + forwarded_client_claims=self._CLIENT_CLAIMS, post=mock_post) + merged = json.loads(captured_data["claims"]) + self.assertEqual( + { + "nbf": {"essential": True, "value": "1601000000"}, + "xms_cc": {"values": ["CP1"]}, + "xms_az_nwperimid": {"essential": True}, + }, + merged["access_token"], + "claims_challenge, capabilities, and forwarded_client_claims must all merge") + self.assertNotIn("client_claims", captured_data) + def test_same_client_claims_returns_cached_token(self): app = self._build_app() call_count = [0] @@ -1046,10 +1073,10 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): "access_token": "an AT", "expires_in": 3600})) result1 = app.acquire_token_for_client( - ["scope"], client_claims=self._CLIENT_CLAIMS, post=mock_post) + ["scope"], forwarded_client_claims=self._CLIENT_CLAIMS, post=mock_post) self.assertEqual(result1[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP) result2 = app.acquire_token_for_client( - ["scope"], client_claims=self._CLIENT_CLAIMS, post=mock_post) + ["scope"], forwarded_client_claims=self._CLIENT_CLAIMS, post=mock_post) self.assertEqual(result2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE, "Same client_claims should return token from cache") self.assertEqual(1, call_count[0], "Second call should not hit the IdP") @@ -1067,24 +1094,24 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): claims_b = '{"access_token": {"xms_az_nwperimid": {"values": ["B"]}}}' result_a = app.acquire_token_for_client( - ["scope"], client_claims=claims_a, post=mock_post_factory("AT_A")) + ["scope"], forwarded_client_claims=claims_a, post=mock_post_factory("AT_A")) self.assertEqual("AT_A", result_a["access_token"]) result_b = app.acquire_token_for_client( - ["scope"], client_claims=claims_b, post=mock_post_factory("AT_B")) + ["scope"], forwarded_client_claims=claims_b, post=mock_post_factory("AT_B")) self.assertEqual("AT_B", result_b["access_token"]) self.assertEqual(result_b[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP, "Different client_claims must NOT share a cache entry") result_a2 = app.acquire_token_for_client( - ["scope"], client_claims=claims_a, post=mock_post_factory("unused")) + ["scope"], forwarded_client_claims=claims_a, post=mock_post_factory("unused")) self.assertEqual("AT_A", result_a2["access_token"]) self.assertEqual(result_a2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE) def test_client_claims_token_does_not_interfere_with_plain_token(self): app = self._build_app() app.acquire_token_for_client( - ["scope"], client_claims=self._CLIENT_CLAIMS, + ["scope"], forwarded_client_claims=self._CLIENT_CLAIMS, post=lambda url, **kwargs: MinimalResponse( status_code=200, text=json.dumps({ "access_token": "claims_AT", "expires_in": 3600}))) @@ -1117,7 +1144,7 @@ def _build_user_token_response( @patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) class TestAcquireTokenOnBehalfOfWithClientClaims(unittest.TestCase): - """acquire_token_on_behalf_of(client_claims=...) forwards client-originated + """acquire_token_on_behalf_of(forwarded_client_claims=...) forwards client-originated claims via the OAuth "claims" parameter and isolates the cached token.""" def _build_app(self, **kwargs): @@ -1129,9 +1156,9 @@ def test_client_claims_rejects_invalid_values(self): app = self._build_app() for bad_value in [123, True, ["claims"], b"bytes", "not json", "null", "[1,2]"]: with self.assertRaises(ValueError, - msg="client_claims={!r} should raise".format(bad_value)): + msg="forwarded_client_claims={!r} should raise".format(bad_value)): app.acquire_token_on_behalf_of( - "assertion", ["s"], client_claims=bad_value) + "assertion", ["s"], forwarded_client_claims=bad_value) def test_client_claims_sent_as_claims_on_the_wire(self): app = self._build_app() @@ -1143,7 +1170,7 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): "access_token": "an AT", "expires_in": 3600})) app.acquire_token_on_behalf_of( - "assertion", ["s"], client_claims=_CLIENT_CLAIMS, post=mock_post) + "assertion", ["s"], forwarded_client_claims=_CLIENT_CLAIMS, post=mock_post) self.assertIn("claims", captured_data) self.assertEqual( {"access_token": {"xms_az_nwperimid": {"essential": True}}}, @@ -1161,7 +1188,7 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): "access_token": "an AT", "expires_in": 3600})) app.acquire_token_on_behalf_of( - "assertion", ["s"], client_claims=_CLIENT_CLAIMS, post=mock_post) + "assertion", ["s"], forwarded_client_claims=_CLIENT_CLAIMS, post=mock_post) self.assertEqual( { "xms_cc": {"values": ["CP1"]}, @@ -1172,19 +1199,19 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): def test_cached_token_is_isolated_by_client_claims(self): app = self._build_app() app.acquire_token_on_behalf_of( - "assertion", ["s"], client_claims=_CLIENT_CLAIMS, + "assertion", ["s"], forwarded_client_claims=_CLIENT_CLAIMS, post=lambda *a, **k: MinimalResponse(status_code=200, text=_build_user_token_response(access_token="obo_at"))) accounts = app.get_accounts() self.assertTrue(accounts, "OBO response should create an account") hit = app.acquire_token_silent( - ["s"], accounts[0], client_claims=_CLIENT_CLAIMS) + ["s"], accounts[0], forwarded_client_claims=_CLIENT_CLAIMS) self.assertIsNotNone(hit) self.assertEqual("obo_at", hit["access_token"]) self.assertEqual(hit[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE) self.assertIsNone( app.acquire_token_silent( - ["s"], accounts[0], client_claims=_OTHER_CLIENT_CLAIMS), + ["s"], accounts[0], forwarded_client_claims=_OTHER_CLIENT_CLAIMS), "Different client_claims must not read the cached token") self.assertIsNone( app.acquire_token_silent(["s"], accounts[0]), @@ -1193,7 +1220,7 @@ def test_cached_token_is_isolated_by_client_claims(self): @patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) class TestAcquireTokenByAuthorizationCodeWithClientClaims(unittest.TestCase): - """acquire_token_by_authorization_code(client_claims=...) forwards + """acquire_token_by_authorization_code(forwarded_client_claims=...) forwards client-originated claims and isolates the cached token.""" def _build_app(self, **kwargs): @@ -1205,9 +1232,9 @@ def test_client_claims_rejects_invalid_values(self): app = self._build_app() for bad_value in [123, ["claims"], b"bytes", "not json", "null"]: with self.assertRaises(ValueError, - msg="client_claims={!r} should raise".format(bad_value)): + msg="forwarded_client_claims={!r} should raise".format(bad_value)): app.acquire_token_by_authorization_code( - "code", ["s"], client_claims=bad_value) + "code", ["s"], forwarded_client_claims=bad_value) def test_client_claims_sent_as_claims_on_the_wire(self): app = self._build_app() @@ -1219,7 +1246,7 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): "access_token": "an AT", "expires_in": 3600})) app.acquire_token_by_authorization_code( - "code", ["s"], client_claims=_CLIENT_CLAIMS, post=mock_post) + "code", ["s"], forwarded_client_claims=_CLIENT_CLAIMS, post=mock_post) self.assertIn("claims", captured_data) self.assertEqual( {"access_token": {"xms_az_nwperimid": {"essential": True}}}, @@ -1230,13 +1257,13 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): def test_cached_token_is_isolated_by_client_claims(self): app = self._build_app() app.acquire_token_by_authorization_code( - "code", ["s"], client_claims=_CLIENT_CLAIMS, + "code", ["s"], forwarded_client_claims=_CLIENT_CLAIMS, post=lambda *a, **k: MinimalResponse(status_code=200, text=_build_user_token_response(access_token="authcode_at"))) accounts = app.get_accounts() self.assertTrue(accounts) hit = app.acquire_token_silent( - ["s"], accounts[0], client_claims=_CLIENT_CLAIMS) + ["s"], accounts[0], forwarded_client_claims=_CLIENT_CLAIMS) self.assertIsNotNone(hit) self.assertEqual("authcode_at", hit["access_token"]) self.assertEqual(hit[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE) @@ -1247,7 +1274,7 @@ def test_cached_token_is_isolated_by_client_claims(self): @patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) class TestUserFicWithClientClaims(unittest.TestCase): - """acquire_token_by_user_federated_identity_credential(client_claims=...) + """acquire_token_by_user_federated_identity_credential(forwarded_client_claims=...) forwards client-originated claims and isolates the cached token.""" def _build_app(self, **kwargs): @@ -1259,10 +1286,10 @@ def test_client_claims_rejects_invalid_values(self): app = self._build_app() for bad_value in [123, ["claims"], b"bytes", "not json", "null"]: with self.assertRaises(ValueError, - msg="client_claims={!r} should raise".format(bad_value)): + msg="forwarded_client_claims={!r} should raise".format(bad_value)): app.acquire_token_by_user_federated_identity_credential( ["s"], assertion="t2", username="user@contoso.com", - client_claims=bad_value) + forwarded_client_claims=bad_value) def test_client_claims_sent_as_claims_on_the_wire(self): app = self._build_app() @@ -1275,7 +1302,7 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): app.acquire_token_by_user_federated_identity_credential( ["s"], assertion="t2", username="user@contoso.com", - client_claims=_CLIENT_CLAIMS, post=mock_post) + forwarded_client_claims=_CLIENT_CLAIMS, post=mock_post) self.assertIn("claims", captured_data) self.assertEqual( {"access_token": {"xms_az_nwperimid": {"essential": True}}}, @@ -1287,14 +1314,14 @@ def test_cached_token_is_isolated_by_client_claims(self): app = self._build_app() app.acquire_token_by_user_federated_identity_credential( ["s"], assertion="t2", username="user@contoso.com", - client_claims=_CLIENT_CLAIMS, + forwarded_client_claims=_CLIENT_CLAIMS, post=lambda *a, **k: MinimalResponse(status_code=200, text=_build_user_token_response( access_token="fic_at", client_id="agent_app_id"))) accounts = app.get_accounts() self.assertTrue(accounts) hit = app.acquire_token_silent( - ["s"], accounts[0], client_claims=_CLIENT_CLAIMS) + ["s"], accounts[0], forwarded_client_claims=_CLIENT_CLAIMS) self.assertIsNotNone(hit) self.assertEqual("fic_at", hit["access_token"]) self.assertEqual(hit[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE) @@ -1305,7 +1332,7 @@ def test_cached_token_is_isolated_by_client_claims(self): @patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) class TestAcquireTokenSilentWithClientClaims(unittest.TestCase): - """acquire_token_silent(client_claims=...) isolates cache reads and merges + """acquire_token_silent(forwarded_client_claims=...) isolates cache reads and merges the claims into the refresh-token request sent on the wire.""" def _build_app(self, **kwargs): @@ -1335,7 +1362,7 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): result = app.acquire_token_silent( ["s"], account, force_refresh=True, - client_claims=_CLIENT_CLAIMS, post=mock_post) + forwarded_client_claims=_CLIENT_CLAIMS, post=mock_post) self.assertIsNotNone(result) self.assertEqual("refresh_token", captured_data.get("grant_type")) self.assertIn("claims", captured_data) @@ -1351,10 +1378,10 @@ def test_both_silent_entry_points_validate_client_claims(self): for bad_value in [123, ["claims"], "not json", "null"]: with self.assertRaises(ValueError): app.acquire_token_silent( - ["s"], account, client_claims=bad_value) + ["s"], account, forwarded_client_claims=bad_value) with self.assertRaises(ValueError): app.acquire_token_silent_with_error( - ["s"], account, client_claims=bad_value) + ["s"], account, forwarded_client_claims=bad_value) @patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) diff --git a/tests/test_mi.py b/tests/test_mi.py index d1f5c572..b75aaf20 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -277,7 +277,7 @@ def _mock_get(self, token="AT"): def test_client_claims_sent_as_claims_query_param(self): with self._mock_get() as mock_get: result = self.app.acquire_token_for_client( - resource="R", client_claims=self._CLAIMS) + resource="R", forwarded_client_claims=self._CLAIMS) self.assertIn("access_token", result) self.assertEqual( self._CLAIMS, mock_get.call_args.kwargs["params"].get("claims"), @@ -292,12 +292,12 @@ def test_msiv1_rejects_non_nwperimid_claim(self): with self.assertRaises(ManagedIdentityError): self.app.acquire_token_for_client( resource="R", - client_claims='{"some_other_claim": {"essential": true}}') + forwarded_client_claims='{"some_other_claim": {"essential": true}}') def test_invalid_json_claims_raises(self): for bad in ["not json", "[1, 2]", "null"]: with self.assertRaises(ValueError, msg="{!r} should raise".format(bad)): - self.app.acquire_token_for_client(resource="R", client_claims=bad) + self.app.acquire_token_for_client(resource="R", forwarded_client_claims=bad) def test_non_string_client_claims_raises_value_error(self): # A non-str client_claims (int, bytes, dict, ...) must raise a friendly @@ -306,15 +306,15 @@ def test_non_string_client_claims_raises_value_error(self): for bad in [123, b'{"xms_az_nwperimid": {}}', {"xms_az_nwperimid": {}}]: with self.assertRaises( ValueError, msg="{!r} should raise ValueError".format(bad)): - self.app.acquire_token_for_client(resource="R", client_claims=bad) + self.app.acquire_token_for_client(resource="R", forwarded_client_claims=bad) def test_same_client_claims_hits_cache(self): with self._mock_get("AT1") as mock_get: r1 = self.app.acquire_token_for_client( - resource="R", client_claims=self._CLAIMS) + resource="R", forwarded_client_claims=self._CLAIMS) self.assertEqual("identity_provider", r1["token_source"]) r2 = self.app.acquire_token_for_client( - resource="R", client_claims=self._CLAIMS) + resource="R", forwarded_client_claims=self._CLAIMS) self.assertEqual("cache", r2["token_source"], "Should hit cache") self.assertEqual(1, mock_get.call_count, "Second call must not hit IMDS") @@ -323,24 +323,24 @@ def test_different_client_claims_are_cached_separately(self): claims_b = '{"xms_az_nwperimid": {"values": ["B"]}}' with self._mock_get("AT_A"): ra = self.app.acquire_token_for_client( - resource="R", client_claims=claims_a) + resource="R", forwarded_client_claims=claims_a) self.assertEqual("AT_A", ra["access_token"]) with self._mock_get("AT_B"): rb = self.app.acquire_token_for_client( - resource="R", client_claims=claims_b) + resource="R", forwarded_client_claims=claims_b) self.assertEqual("AT_B", rb["access_token"]) self.assertEqual("identity_provider", rb["token_source"], "Different client_claims must NOT share a cache entry") with self._mock_get("unused"): ra2 = self.app.acquire_token_for_client( - resource="R", client_claims=claims_a) + resource="R", forwarded_client_claims=claims_a) self.assertEqual("AT_A", ra2["access_token"]) self.assertEqual("cache", ra2["token_source"]) def test_plain_request_does_not_return_client_claims_token(self): with self._mock_get("claims_AT"): self.app.acquire_token_for_client( - resource="R", client_claims=self._CLAIMS) + resource="R", forwarded_client_claims=self._CLAIMS) with self._mock_get("plain_AT"): result = self.app.acquire_token_for_client(resource="R") self.assertEqual("plain_AT", result["access_token"]) @@ -391,7 +391,7 @@ def test_app_service_resource_id_parameter_should_be_mi_res_id(self): def test_client_claims_not_supported_on_app_service(self): with self.assertRaises(ManagedIdentityError): self.app.acquire_token_for_client( - resource="R", client_claims='{"xms_az_nwperimid": {"essential": true}}') + resource="R", forwarded_client_claims='{"xms_az_nwperimid": {"essential": true}}') @patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"}) @@ -421,7 +421,7 @@ def test_machine_learning_error_should_be_normalized(self): def test_client_claims_not_supported_on_machine_learning(self): with self.assertRaises(ManagedIdentityError): self.app.acquire_token_for_client( - resource="R", client_claims='{"xms_az_nwperimid": {"essential": true}}') + resource="R", forwarded_client_claims='{"xms_az_nwperimid": {"essential": true}}') @patch.dict(os.environ, { @@ -497,7 +497,7 @@ def test_sf_error_should_be_normalized(self): def test_client_claims_not_supported_on_service_fabric(self): with self.assertRaises(ManagedIdentityError): self.app.acquire_token_for_client( - resource="R", client_claims='{"xms_az_nwperimid": {"essential": true}}') + resource="R", forwarded_client_claims='{"xms_az_nwperimid": {"essential": true}}') @patch.dict(os.environ, { @@ -555,7 +555,7 @@ def test_arc_error_should_be_normalized(self, mocked_stat): def test_client_claims_not_supported_on_arc(self, mocked_stat): with self.assertRaises(ManagedIdentityError): self.app.acquire_token_for_client( - resource="R", client_claims='{"xms_az_nwperimid": {"essential": true}}') + resource="R", forwarded_client_claims='{"xms_az_nwperimid": {"essential": true}}') class GetManagedIdentitySourceTestCase(unittest.TestCase): From 4fc3639226535083b0b2137cd4615cf047f22556 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Tue, 30 Jun 2026 20:00:15 -0400 Subject: [PATCH 5/9] Harden forwarded_client_claims (cross-MSAL review fixes) Apply six fixes derived from the sibling MSAL PRs (go #629, java #1039, js #8686) to the forwarded_client_claims port: 1. Make _compute_ext_cache_key injective. Switch from separator-less key+value concatenation to length-prefixed pairs ("{len(k)}:{k}{len(v)}:{v}"), matching Go's post-collision-fix CacheExtKeyGenerator. Without this, fmi_path + client_claims (which now co-occur in acquire_token_for_client) could collide and return the wrong cached token. Adds boundary-collision regression tests. NOTE: hashes are now intentionally not byte-identical to MSAL .NET (which still uses unprefixed concat); caches are not shared across languages, so within-process injectivity is what matters. 2. Remove the MSIv1 client-side allow-list (_validate_msiv1_claims). Forward any JSON-object claims value as-is and let IMDS decide which keys it accepts, matching go/java. 3. Validate the managed-identity source before the cache read. Reject unsupported sources (Service Fabric, App Service, Machine Learning, Azure Arc) up front so an unsupported source never returns a cached client-claims token. _obtain_token keeps its per-source guards as a backstop. 4. Add merge-conflict precedence tests: on a direct leaf conflict the client-originated value wins (merged last); disjoint claims are preserved. 5. Drop the first-party xms_az_nwperimid example from public docstrings; use generic "client-originated claims" wording. 6. Document that the same forwarded_client_claims value must be sent on every request that should share the cached token (it is part of the cache key). 204 tests pass across test_token_cache.py, test_mi.py, test_application.py. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- msal/application.py | 59 ++++++++++---------- msal/managed_identity.py | 61 ++++++++++++--------- msal/token_cache.py | 22 ++++++-- tests/test_application.py | 27 ++++++++- tests/test_mi.py | 17 ++++-- tests/test_token_cache.py | 112 ++++++++++++++++++++++++++++---------- 6 files changed, 203 insertions(+), 95 deletions(-) diff --git a/msal/application.py b/msal/application.py index 74dd4dc8..f8ed857a 100644 --- a/msal/application.py +++ b/msal/application.py @@ -67,8 +67,8 @@ def _merge_claims_challenge_and_capabilities(capabilities, claims_challenge): def _stash_client_claims(forwarded_client_claims, data): """Validate ``forwarded_client_claims`` and stash it into the request ``data``. - ``forwarded_client_claims`` carries *client-originated* claims (for example a - network security perimeter ``xms_az_nwperimid`` claim). The raw value is + ``forwarded_client_claims`` carries *client-originated* claims supplied by the + caller. The raw value is stored in ``data`` (under the internal ``client_claims`` key) so that it (a) contributes to the extended cache key -- isolating cache entries by claims value -- and (b) is stripped from the request body by the oauth2 @@ -1303,11 +1303,12 @@ def acquire_token_by_authorization_code( returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. :param str forwarded_client_claims: - Optional. A JSON string of *client-originated* claims (for example - a network security perimeter ``xms_az_nwperimid`` claim) to include - in the token request. Unlike ``claims_challenge`` (server-issued, - which bypasses the cache), tokens acquired with ``forwarded_client_claims`` - **are cached** and keyed on the claims value, so use stable, + Optional. A JSON string of *client-originated* claims to include in + the token request. Unlike ``claims_challenge`` (server-issued, which + bypasses the cache), tokens acquired with ``forwarded_client_claims`` + **are cached** and keyed on the claims value. Send the *same* value on + every request that should share the cached token; omitting or changing + it routes to a different cache entry (a cache miss), so use stable, non-dynamic values. The value is merged into the standard OAuth ``claims`` request parameter sent on the wire. @@ -1588,11 +1589,11 @@ def acquire_token_silent_with_error( returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. :param str forwarded_client_claims: - Optional. A JSON string of *client-originated* claims (for example - a network security perimeter ``xms_az_nwperimid`` claim) to include - when a cached token is missing and a network request is made. Tokens - are **cached** and keyed on the claims value (different values yield - separate cache entries), so use stable, non-dynamic values. + Optional. A JSON string of *client-originated* claims to include when + a cached token is missing and a network request is made. Tokens are + **cached** and keyed on the claims value (different values yield + separate cache entries), so send the *same* value on every call that + should reuse the cached token, and use stable, non-dynamic values. Not to be confused with the constructor ``client_claims`` parameter (a ``dict`` of extra claims signed into the client-assertion JWT). @@ -2592,16 +2593,16 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, ) :param str forwarded_client_claims: Optional. A JSON string containing *client-originated* claims to - include in the token request (for example a network security - perimeter ``xms_az_nwperimid`` claim). + include in the token request. Unlike ``claims_challenge`` (which carries *server-issued* claims challenges and bypasses the cache), tokens acquired with ``forwarded_client_claims`` **are cached**, and the cache entry is keyed on the - claims value. Different ``forwarded_client_claims`` values produce separate - cache entries, so use stable, non-dynamic values to avoid unbounded - cache growth. The value is merged into the standard OAuth ``claims`` - request parameter sent on the wire. + claims value. Send the *same* value on every request that should share + the cached token; different values produce separate cache entries, so + use stable, non-dynamic values to avoid unbounded cache growth. The + value is merged into the standard OAuth ``claims`` request parameter + sent on the wire. Not to be confused with the constructor ``client_claims`` parameter (a ``dict`` of extra claims signed into the client-assertion JWT). @@ -2699,11 +2700,12 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. :param str forwarded_client_claims: - Optional. A JSON string of *client-originated* claims (for example - a network security perimeter ``xms_az_nwperimid`` claim) to include - in the token request. Unlike ``claims_challenge`` (server-issued, - which bypasses the cache), tokens acquired with ``forwarded_client_claims`` - **are cached** and keyed on the claims value, so use stable, + Optional. A JSON string of *client-originated* claims to include in + the token request. Unlike ``claims_challenge`` (server-issued, which + bypasses the cache), tokens acquired with ``forwarded_client_claims`` + **are cached** and keyed on the claims value. Send the *same* value on + every request that should share the cached token; omitting or changing + it routes to a different cache entry (a cache miss), so use stable, non-dynamic values. The value is merged into the standard OAuth ``claims`` request parameter sent on the wire. @@ -2770,11 +2772,12 @@ def acquire_token_by_user_federated_identity_credential( returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. :param str forwarded_client_claims: - Optional. A JSON string of *client-originated* claims (for example - a network security perimeter ``xms_az_nwperimid`` claim) to include - in the token request. Unlike ``claims_challenge`` (server-issued, - which bypasses the cache), tokens acquired with ``forwarded_client_claims`` - **are cached** and keyed on the claims value, so use stable, + Optional. A JSON string of *client-originated* claims to include in + the token request. Unlike ``claims_challenge`` (server-issued, which + bypasses the cache), tokens acquired with ``forwarded_client_claims`` + **are cached** and keyed on the claims value. Send the *same* value on + every request that should share the cached token; omitting or changing + it routes to a different cache entry (a cache miss), so use stable, non-dynamic values. The value is merged into the standard OAuth ``claims`` request parameter sent on the wire. diff --git a/msal/managed_identity.py b/msal/managed_identity.py index e4e88b77..de10e535 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -26,8 +26,6 @@ class ManagedIdentityError(ValueError): pass -_XMS_AZ_NWPERIMID = "xms_az_nwperimid" - _CLIENT_CLAIMS_UNSUPPORTED_SOURCE = ( "forwarded_client_claims is only supported for the IMDS (Azure VM) managed identity " "source. The detected source ({source}) does not support forwarding " @@ -292,18 +290,17 @@ def acquire_token_for_client( :param forwarded_client_claims: Optional. A string representation of a JSON object containing - *client-originated* claims to forward to the identity endpoint - (for example a network security perimeter ``xms_az_nwperimid`` claim). + *client-originated* claims to forward to the identity endpoint. Unlike ``claims_challenge`` (server-issued, which bypasses the cache), tokens acquired with ``forwarded_client_claims`` **are cached**, and the cache - entry is keyed on the claims value. Different ``forwarded_client_claims`` values - produce separate cache entries, so use stable, non-dynamic values to - avoid unbounded cache growth. + entry is keyed on the claims value. Send the *same* value on every + request that should share the cached token; different values produce + separate cache entries, so use stable, non-dynamic values to avoid + unbounded cache growth. Only the IMDS (Azure VM) managed identity source supports this - parameter; other sources raise an error. On IMDS v1, the claims JSON - may contain only the ``xms_az_nwperimid`` key. + parameter; other sources raise an error. .. note:: @@ -325,6 +322,9 @@ def acquire_token_for_client( "forwarded_client_claims must be a string, got {}".format( type(forwarded_client_claims).__name__)) _parse_claims_or_raise(forwarded_client_claims) # Fail fast on malformed JSON + # Reject unsupported sources before any cache read, so an unsupported + # source never returns a cached client-claims token. + _raise_if_claims_unsupported_source() # Client-originated claims isolate the cache: a distinct claims value gets # a distinct cache entry. (Server-issued claims_challenge, by contrast, # bypasses the cache and is keyed normally.) @@ -447,6 +447,29 @@ def get_managed_identity_source(): return DEFAULT_TO_VM +# Managed-identity sources that cannot forward client-originated claims. Keep in +# sync with the per-source guards inside _obtain_token (the backstop). Cloud Shell +# is intentionally absent: it falls through to the Azure VM / IMDS path, which +# does support claims. +_CLIENT_CLAIMS_UNSUPPORTED_SOURCES = { + SERVICE_FABRIC: "Service Fabric", + APP_SERVICE: "App Service", + MACHINE_LEARNING: "Machine Learning", + AZURE_ARC: "Azure Arc", +} + + +def _raise_if_claims_unsupported_source(): + """Fail fast -- before any cache read -- when the detected managed-identity + source cannot forward client-originated claims. ``_obtain_token`` enforces the + same rule per source as a backstop, but validating up front avoids a cache + lookup (and returning a cached token) for an unsupported source.""" + name = _CLIENT_CLAIMS_UNSUPPORTED_SOURCES.get(get_managed_identity_source()) + if name: + raise ManagedIdentityError( + _CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source=name)) + + def _obtain_token( http_client, managed_identity, resource, *, @@ -521,22 +544,6 @@ def _adjust_param(params, managed_identity, types_mapping=None): if id_name: params[id_name] = managed_identity[ManagedIdentity.ID] -def _validate_msiv1_claims(client_claims): - """MSIv1 (IMDS v1) only supports the single ``xms_az_nwperimid`` custom claim. - - Any other top-level key makes IMDS return HTTP 400 with no useful diagnostic, - so validate early and raise a clear error. Mirrors MSAL .NET's - ``AbstractManagedIdentity.ValidateMsiv1Claims``. - """ - parsed = _parse_claims_or_raise(client_claims) - for key in parsed: - if key != _XMS_AZ_NWPERIMID: - raise ManagedIdentityError( - "MSIv1 (IMDS v1) only supports the `{expected}` custom claim. " - "The claims JSON contained the unsupported key `{actual}`. " - "Remove all keys other than `{expected}` when using forwarded_client_claims " - "with MSIv1.".format(expected=_XMS_AZ_NWPERIMID, actual=key)) - def _obtain_token_on_azure_vm(http_client, managed_identity, resource, client_claims=None): # Based on https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http @@ -547,8 +554,8 @@ def _obtain_token_on_azure_vm(http_client, managed_identity, resource, client_cl } _adjust_param(params, managed_identity) if client_claims: - # IMDS v1 (MSIv1) only supports the single xms_az_nwperimid claim. - _validate_msiv1_claims(client_claims) + # Forward client-originated claims as-is; IMDS decides which keys it + # accepts (no client-side allow-list, matching the other MSALs). params["claims"] = client_claims # http_client.get url-encodes query params resp = http_client.get( os.getenv( diff --git a/msal/token_cache.py b/msal/token_cache.py index 7bc21a43..c6b4a70a 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -83,8 +83,15 @@ def _compute_ext_cache_key(data): Returns an empty string when *data* has no hashable fields. - The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator): - sorted key+value pairs are concatenated and SHA256 hashed, then base64url encoded. + The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator, + post-collision-fix): length-prefixed key/value pairs (sorted by key) are + concatenated and SHA256 hashed, then base64url encoded. The length prefixes + make the encoding injective, so two distinct component sets can never collide + onto the same cache key. (MSAL .NET's ``ComputeAccessTokenExtCacheKey`` still + uses an unprefixed concatenation, so the hash is intentionally not + byte-identical to current .NET; the cache *key format* still matches both. + Caches are not shared across languages, so this only affects within-process + isolation, where injectivity is what matters.) """ if not data: return "" @@ -94,9 +101,16 @@ def _compute_ext_cache_key(data): } if not cache_components: return "" - # Sort keys for consistent hashing (matches Go implementation) + # Concatenate length-prefixed key/value pairs so component boundaries are + # unambiguous (matches Go's CacheExtKeyGenerator). A plain key+value + # concatenation with no separators can collide when one value happens to + # contain another component's key or value -- and client_claims is arbitrary + # caller-supplied JSON that may embed e.g. "fmi_path" at a boundary -- mapping + # two distinct component sets onto the same hash and returning the wrong + # cached token. Length prefixes make the encoding injective. key_str = "".join( - k + cache_components[k] for k in sorted(cache_components.keys()) + "{}:{}{}:{}".format(len(k), k, len(v), v) + for k, v in sorted(cache_components.items()) ) hash_bytes = hashlib.sha256(key_str.encode("utf-8")).digest() return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower() diff --git a/tests/test_application.py b/tests/test_application.py index 053d9ef0..a5d4ca8b 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1063,7 +1063,32 @@ def mock_post(url, headers=None, data=None, *args, **kwargs): "claims_challenge, capabilities, and forwarded_client_claims must all merge") self.assertNotIn("client_claims", captured_data) - def test_same_client_claims_returns_cached_token(self): + def test_forwarded_client_claims_win_on_leaf_conflict_with_challenge(self): + # If the server-issued claims_challenge and forwarded_client_claims set + # the SAME claim, the client-originated value wins (it is merged in last), + # while disjoint claims from the challenge are preserved. Documents the + # conflict-resolution behavior the other MSAL reviewers asked about. + app = self._build_app() + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "an AT", "expires_in": 3600})) + + challenge = ('{"access_token": {"acrs": {"values": ["server"]},' + ' "nbf": {"essential": true}}}') + client = '{"access_token": {"acrs": {"values": ["client"]}}}' + app.acquire_token_for_client( + ["scope"], claims_challenge=challenge, + forwarded_client_claims=client, post=mock_post) + merged = json.loads(captured_data["claims"])["access_token"] + self.assertEqual( + {"values": ["client"]}, merged["acrs"], + "forwarded_client_claims must win a direct leaf conflict") + self.assertEqual( + {"essential": True}, merged["nbf"], + "Disjoint claims from the challenge must be preserved") app = self._build_app() call_count = [0] diff --git a/tests/test_mi.py b/tests/test_mi.py index b75aaf20..f7082aa2 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -288,11 +288,18 @@ def test_no_claims_param_when_client_claims_absent(self): self.app.acquire_token_for_client(resource="R") self.assertNotIn("claims", mock_get.call_args.kwargs["params"]) - def test_msiv1_rejects_non_nwperimid_claim(self): - with self.assertRaises(ManagedIdentityError): - self.app.acquire_token_for_client( - resource="R", - forwarded_client_claims='{"some_other_claim": {"essential": true}}') + def test_non_nwperimid_claim_is_forwarded_not_rejected(self): + # MSAL no longer enforces a client-side allow-list (matching the other + # MSALs); any JSON-object claims value is forwarded as-is and IMDS decides + # which keys it accepts. + other = '{"some_other_claim": {"essential": true}}' + with self._mock_get() as mock_get: + result = self.app.acquire_token_for_client( + resource="R", forwarded_client_claims=other) + self.assertIn("access_token", result) + self.assertEqual( + other, mock_get.call_args.kwargs["params"].get("claims"), + "Non-nwperimid claims must be forwarded, not rejected") def test_invalid_json_claims_raises(self): for bad in ["not json", "[1, 2]", "null"]: diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 149d0aba..bd4ca840 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -392,6 +392,29 @@ def test_different_client_claims_produce_different_hashes(self): def test_empty_client_claims_value_is_ignored(self): self.assertEqual("", _compute_ext_cache_key({"client_claims": ""})) + def test_length_prefixed_encoding_avoids_boundary_collision(self): + # Mirrors Go's TestCacheKeyComponentHashNoBoundaryCollision. With a plain + # key+value concatenation (no separators) these two distinct component + # sets both render to "axbYbZ" (sorted keys "a","b") and would collide, + # returning the wrong cached token. The length-prefixed encoding must keep + # them distinct. This matters because client_claims is arbitrary caller + # JSON that can embed another component's key (e.g. "fmi_path"). + h1 = _compute_ext_cache_key({"a": "xbY", "b": "Z"}) + h2 = _compute_ext_cache_key({"a": "x", "b": "YbZ"}) + self.assertNotEqual( + h1, h2, + "distinct cache key components must not produce the same hash") + + def test_client_claims_and_fmi_path_do_not_collide_at_boundary(self): + # Realistic surface: client_claims and fmi_path co-occur in + # acquire_token_for_client. A claims value that happens to contain the + # other component's key+value at a boundary must not collide. + h1 = _compute_ext_cache_key( + {"client_claims": "Xfmi_pathY", "fmi_path": "Z"}) + h2 = _compute_ext_cache_key( + {"client_claims": "X", "fmi_path": "Yfmi_pathZ"}) + self.assertNotEqual(h1, h2) + class TestClaimsHelpers(unittest.TestCase): """Tests for the shared _parse_claims_or_raise / _merge_claims helpers.""" @@ -439,6 +462,26 @@ def test_merge_deep_merges_objects(self): {"xms_cc": {"values": ["cp1"]}, "xms_az_nwperimid": {"essential": True}}, merged["access_token"]) + def test_merge_leaf_conflict_second_arg_wins(self): + # When both sides set the SAME leaf to different values, the second + # argument (client-originated claims) wins. _acquire_token_for_client + # passes server-issued claims first and forwarded_client_claims second, + # so the caller's value takes precedence on a direct conflict. + merged = json.loads(_merge_claims( + '{"access_token": {"acrs": {"values": ["server"]}}}', + '{"access_token": {"acrs": {"values": ["client"]}}}')) + self.assertEqual({"values": ["client"]}, merged["access_token"]["acrs"]) + + def test_merge_nested_conflict_preserves_disjoint_siblings(self): + # A conflict on one nested key must not drop the other side's disjoint + # sibling keys at the same level. + merged = json.loads(_merge_claims( + '{"access_token": {"a": {"values": ["server"]}, "keep_server": 1}}', + '{"access_token": {"a": {"values": ["client"]}, "keep_client": 2}}')) + self.assertEqual( + {"a": {"values": ["client"]}, "keep_server": 1, "keep_client": 2}, + merged["access_token"]) + class TestExtCacheKeyIsolation(unittest.TestCase): """Tests that ext_cache_key provides proper cache isolation in TokenCache.""" @@ -522,40 +565,48 @@ def test_non_fmi_tokens_not_affected_by_fmi_cache(self): class TestCrossMsalCacheKeyCompatibility(unittest.TestCase): - """Verify that _compute_ext_cache_key produces hashes identical to MSAL Go - (CacheExtKeyGenerator) and MSAL .NET (CoreHelpers.ComputeAccessTokenExtCacheKey). + """Verify that _compute_ext_cache_key matches MSAL Go's CacheExtKeyGenerator + (post collision-fix, AzureAD/microsoft-authentication-library-for-go#629). - All three libraries use the same algorithm: + The algorithm: 1. Sort key-value pairs alphabetically by key (ordinal / case-sensitive) - 2. Concatenate them: "key1value1key2value2…" + 2. Concatenate length-prefixed pairs ("{len(k)}:{k}{len(v)}:{v}" per pair; + e.g. {"key1": "value1"} -> "4:key16:value1"). The length prefixes make + the encoding injective -- see TestClientClaimsCacheKey for the collision + guard. 3. SHA-256 hash 4. Base64url encode (no padding), lowercased - The expected hashes below are copied from: - - MSAL Go: authority_ext_cachekey_test.go (TestAppKeyWithCacheKeyComponent) - - MSAL .NET: CacheKeyExtensionTests.cs (RunHappyPathTest, CacheExtEnsurePopKeysFunctionAsync) + The expected hashes below are copied from MSAL Go's + authority_ext_cachekey_test.go (TestAppKeyWithCacheKeyComponent). + + NOTE: MSAL .NET's ComputeAccessTokenExtCacheKey still uses an *unprefixed* + concatenation, so these hashes are intentionally NOT byte-identical to current + .NET. The cache *key format* (the 'atext' segment layout, asserted below) still + matches both Go and .NET; only the trailing hash bytes differ. Caches are not + shared across languages, so within-process injectivity -- not cross-language + byte-parity -- is what matters for correctness. """ - def test_two_params_hash_matches_go_and_dotnet(self): - """Go/dotnet expected: bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi""" + def test_two_params_hash_matches_go(self): + """Go expected: latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike""" result = _compute_ext_cache_key({"key1": "value1", "key2": "value2"}) - self.assertEqual("bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi", result) + self.assertEqual("latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike", result) - def test_two_different_params_hash_matches_go_and_dotnet(self): - """Go/dotnet expected: 3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u""" + def test_two_different_params_hash_matches_go(self): + """Go expected: jjoe9jgfmdtnj0rzuetsqy7kzs2m1xfnjjxwsfxsrxq""" result = _compute_ext_cache_key({"key3": "value3", "key4": "value4"}) - self.assertEqual("3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u", result) + self.assertEqual("jjoe9jgfmdtnj0rzuetsqy7kzs2m1xfnjjxwsfxsrxq", result) - def test_five_params_hash_matches_go_and_dotnet(self): - """Go/dotnet expected (full hash): rn_gkpxxkkqjxcqnvnmr2duvxg66xanvkz6qfqpwp2e - Go test uses substring match 'gkpxxkkqjxcqnvnmr2duvxg66xanvkz6qfqpwp2e'.""" + def test_five_params_hash_matches_go(self): + """Go expected: prrdp31y37ufw3lo7hly0oimjjvg_34m9ji30ocu4tw""" result = _compute_ext_cache_key({ "key3": "value3", "key4": "value4", "key5": "value5", "key6": "value6", "key7": "value7", }) - self.assertEqual("rn_gkpxxkkqjxcqnvnmr2duvxg66xanvkz6qfqpwp2e", result) + self.assertEqual("prrdp31y37ufw3lo7hly0oimjjvg_34m9ji30ocu4tw", result) - def test_order_independence_matches_go_and_dotnet(self): + def test_order_independence_matches_go(self): """Same keys in different insertion order must produce the same hash (mirrors TestCacheKeyComponentHashConsistency in Go).""" h1 = _compute_ext_cache_key({"key3": "value3", "key4": "value4", @@ -576,9 +627,9 @@ def test_at_cache_key_uses_atext_credential_type(self): key = key_maker( home_account_id="hid", environment="env", client_id="cid", realm="realm", target="scope", - ext_cache_key="bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi") + ext_cache_key="latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike") self.assertEqual( - "hid-env-atext-cid-realm-scope-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi", + "hid-env-atext-cid-realm-scope-latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike", key) def test_at_cache_key_without_ext_uses_accesstoken(self): @@ -590,9 +641,10 @@ def test_at_cache_key_without_ext_uses_accesstoken(self): realm="realm", target="scope") self.assertEqual("hid-env-accesstoken-cid-realm-scope", key) - def test_dotnet_style_full_at_cache_key(self): - """Reproduce the exact cache key from MSAL .NET CacheKeyExtensionTests: - expectedCacheKey1 = '-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi' + def test_atext_full_at_cache_key_format(self): + """The AT cache key *format* matches MSAL .NET's CacheKeyExtensionTests + layout ('-{env}-atext-{clientId}-{tenant}-{scopes}-{hash}'); only the + trailing hash now follows Go's length-prefixed encoding (see class note). """ cache = TokenCache() key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] @@ -604,11 +656,11 @@ def test_dotnet_style_full_at_cache_key(self): realm="common", target="r1/scope1 r1/scope2", ext_cache_key=ext_hash) - expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi" + expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike" self.assertEqual(expected, key) - def test_dotnet_style_second_cache_key(self): - """Reproduce CacheKeyExtensionTests expectedCacheKey2.""" + def test_atext_second_full_at_cache_key_format(self): + """Second key-format vector (mirrors CacheKeyExtensionTests expectedCacheKey2).""" cache = TokenCache() key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] ext_hash = _compute_ext_cache_key({"key3": "value3", "key4": "value4"}) @@ -619,12 +671,12 @@ def test_dotnet_style_second_cache_key(self): realm="common", target="r1/scope1 r1/scope2", ext_cache_key=ext_hash) - expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u" + expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-jjoe9jgfmdtnj0rzuetsqy7kzs2m1xfnjjxwsfxsrxq" self.assertEqual(expected, key) def test_go_style_at_cache_key(self): - """Reproduce the Go AccessToken.Key() format: - Go test: 'testhid-env-atext-clientid-realm-user.read-{hash}' + """Reproduce the Go AccessToken.Key() format with Go's post-#629 hash: + 'testhid-env-atext-clientid-realm-user.read-{hash}'. """ cache = TokenCache() key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] @@ -636,5 +688,5 @@ def test_go_style_at_cache_key(self): realm="realm", target="user.read", ext_cache_key=ext_hash) - expected = "testhid-env-atext-clientid-realm-user.read-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi" + expected = "testhid-env-atext-clientid-realm-user.read-latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike" self.assertEqual(expected, key) From f8040e498596fe31ee7f1ce7498fdd07c9d4ce85 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 1 Jul 2026 12:41:12 -0400 Subject: [PATCH 6/9] Revert ext-cache-key to MSAL .NET byte-parity (undo hardening fix #1) Restore _compute_ext_cache_key to MSAL .NET's ComputeAccessTokenExtCacheKey encoding: sorted, separator-less key+value concatenation -> SHA-256 -> base64url. This makes the ext cache key byte-identical to MSAL .NET again. The earlier hardening commit (4fc3639) had switched to Go's post-#629 length-prefixed encoding to make the key injective. Per maintainer decision, msal-python should match MSAL .NET, not current Go, so that change is reverted: - token_cache.py: restore plain key+value concatenation; docstring now notes the .NET match and the deliberate divergence from Go's #629 length-prefixed form. - test_token_cache.py: restore the .NET parity hashes (bns2ytmx..., 3-rg6_wy..., rn_gkpxx...) and rename the parity tests to *_matches_dotnet; remove the two length-prefix boundary-collision regression tests (they asserted the injective property that .NET's encoding does not provide). Hardening fixes #2-#6 (MI allow-list removal, MI source pre-validation, merge conflict-precedence tests, generic docs, send-on-every-request docs) are unchanged. 202 tests pass across test_token_cache.py, test_mi.py, test_application.py. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- msal/token_cache.py | 31 ++++++------- tests/test_token_cache.py | 94 ++++++++++++++------------------------- 2 files changed, 47 insertions(+), 78 deletions(-) diff --git a/msal/token_cache.py b/msal/token_cache.py index c6b4a70a..1907d569 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -83,15 +83,15 @@ def _compute_ext_cache_key(data): Returns an empty string when *data* has no hashable fields. - The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator, - post-collision-fix): length-prefixed key/value pairs (sorted by key) are - concatenated and SHA256 hashed, then base64url encoded. The length prefixes - make the encoding injective, so two distinct component sets can never collide - onto the same cache key. (MSAL .NET's ``ComputeAccessTokenExtCacheKey`` still - uses an unprefixed concatenation, so the hash is intentionally not - byte-identical to current .NET; the cache *key format* still matches both. - Caches are not shared across languages, so this only affects within-process - isolation, where injectivity is what matters.) + The algorithm matches MSAL .NET's ``ComputeAccessTokenExtCacheKey``: sorted + key+value pairs are concatenated (no separators) and SHA256 hashed, then + base64url encoded. This keeps the hash byte-identical to MSAL .NET. + + MSAL Go's ``CacheExtKeyGenerator`` has since switched to a length-prefixed + encoding (AzureAD/microsoft-authentication-library-for-go#629) to make it + injective; Python deliberately tracks .NET instead, so these hashes are not + byte-identical to current Go. Caches are not shared across languages, so the + difference does not affect runtime correctness. """ if not data: return "" @@ -101,16 +101,11 @@ def _compute_ext_cache_key(data): } if not cache_components: return "" - # Concatenate length-prefixed key/value pairs so component boundaries are - # unambiguous (matches Go's CacheExtKeyGenerator). A plain key+value - # concatenation with no separators can collide when one value happens to - # contain another component's key or value -- and client_claims is arbitrary - # caller-supplied JSON that may embed e.g. "fmi_path" at a boundary -- mapping - # two distinct component sets onto the same hash and returning the wrong - # cached token. Length prefixes make the encoding injective. + # Sort keys, then concatenate key+value pairs with no separators. This + # matches MSAL .NET's ComputeAccessTokenExtCacheKey byte-for-byte. (See the + # docstring re: the Go #629 length-prefixed divergence.) key_str = "".join( - "{}:{}{}:{}".format(len(k), k, len(v), v) - for k, v in sorted(cache_components.items()) + k + cache_components[k] for k in sorted(cache_components.keys()) ) hash_bytes = hashlib.sha256(key_str.encode("utf-8")).digest() return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower() diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index bd4ca840..2a885e32 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -392,29 +392,6 @@ def test_different_client_claims_produce_different_hashes(self): def test_empty_client_claims_value_is_ignored(self): self.assertEqual("", _compute_ext_cache_key({"client_claims": ""})) - def test_length_prefixed_encoding_avoids_boundary_collision(self): - # Mirrors Go's TestCacheKeyComponentHashNoBoundaryCollision. With a plain - # key+value concatenation (no separators) these two distinct component - # sets both render to "axbYbZ" (sorted keys "a","b") and would collide, - # returning the wrong cached token. The length-prefixed encoding must keep - # them distinct. This matters because client_claims is arbitrary caller - # JSON that can embed another component's key (e.g. "fmi_path"). - h1 = _compute_ext_cache_key({"a": "xbY", "b": "Z"}) - h2 = _compute_ext_cache_key({"a": "x", "b": "YbZ"}) - self.assertNotEqual( - h1, h2, - "distinct cache key components must not produce the same hash") - - def test_client_claims_and_fmi_path_do_not_collide_at_boundary(self): - # Realistic surface: client_claims and fmi_path co-occur in - # acquire_token_for_client. A claims value that happens to contain the - # other component's key+value at a boundary must not collide. - h1 = _compute_ext_cache_key( - {"client_claims": "Xfmi_pathY", "fmi_path": "Z"}) - h2 = _compute_ext_cache_key( - {"client_claims": "X", "fmi_path": "Yfmi_pathZ"}) - self.assertNotEqual(h1, h2) - class TestClaimsHelpers(unittest.TestCase): """Tests for the shared _parse_claims_or_raise / _merge_claims helpers.""" @@ -565,48 +542,45 @@ def test_non_fmi_tokens_not_affected_by_fmi_cache(self): class TestCrossMsalCacheKeyCompatibility(unittest.TestCase): - """Verify that _compute_ext_cache_key matches MSAL Go's CacheExtKeyGenerator - (post collision-fix, AzureAD/microsoft-authentication-library-for-go#629). + """Verify that _compute_ext_cache_key produces hashes identical to MSAL .NET + (CoreHelpers.ComputeAccessTokenExtCacheKey). The algorithm: 1. Sort key-value pairs alphabetically by key (ordinal / case-sensitive) - 2. Concatenate length-prefixed pairs ("{len(k)}:{k}{len(v)}:{v}" per pair; - e.g. {"key1": "value1"} -> "4:key16:value1"). The length prefixes make - the encoding injective -- see TestClientClaimsCacheKey for the collision - guard. + 2. Concatenate them with no separators: "key1value1key2value2…" 3. SHA-256 hash 4. Base64url encode (no padding), lowercased - The expected hashes below are copied from MSAL Go's - authority_ext_cachekey_test.go (TestAppKeyWithCacheKeyComponent). + The expected hashes below are copied from MSAL .NET's CacheKeyExtensionTests.cs + (RunHappyPathTest, CacheExtEnsurePopKeysFunctionAsync). - NOTE: MSAL .NET's ComputeAccessTokenExtCacheKey still uses an *unprefixed* - concatenation, so these hashes are intentionally NOT byte-identical to current - .NET. The cache *key format* (the 'atext' segment layout, asserted below) still - matches both Go and .NET; only the trailing hash bytes differ. Caches are not - shared across languages, so within-process injectivity -- not cross-language - byte-parity -- is what matters for correctness. + NOTE: MSAL Go's CacheExtKeyGenerator has since switched to a *length-prefixed* + encoding (AzureAD/microsoft-authentication-library-for-go#629), so these hashes + are intentionally NOT byte-identical to current Go; Python deliberately tracks + .NET here. The cache *key format* (the 'atext' segment layout, asserted below) + still matches both Go and .NET. Caches are not shared across languages, so this + cross-language hash difference does not affect runtime correctness. """ - def test_two_params_hash_matches_go(self): - """Go expected: latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike""" + def test_two_params_hash_matches_dotnet(self): + """.NET expected: bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi""" result = _compute_ext_cache_key({"key1": "value1", "key2": "value2"}) - self.assertEqual("latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike", result) + self.assertEqual("bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi", result) - def test_two_different_params_hash_matches_go(self): - """Go expected: jjoe9jgfmdtnj0rzuetsqy7kzs2m1xfnjjxwsfxsrxq""" + def test_two_different_params_hash_matches_dotnet(self): + """.NET expected: 3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u""" result = _compute_ext_cache_key({"key3": "value3", "key4": "value4"}) - self.assertEqual("jjoe9jgfmdtnj0rzuetsqy7kzs2m1xfnjjxwsfxsrxq", result) + self.assertEqual("3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u", result) - def test_five_params_hash_matches_go(self): - """Go expected: prrdp31y37ufw3lo7hly0oimjjvg_34m9ji30ocu4tw""" + def test_five_params_hash_matches_dotnet(self): + """.NET expected (full hash): rn_gkpxxkkqjxcqnvnmr2duvxg66xanvkz6qfqpwp2e""" result = _compute_ext_cache_key({ "key3": "value3", "key4": "value4", "key5": "value5", "key6": "value6", "key7": "value7", }) - self.assertEqual("prrdp31y37ufw3lo7hly0oimjjvg_34m9ji30ocu4tw", result) + self.assertEqual("rn_gkpxxkkqjxcqnvnmr2duvxg66xanvkz6qfqpwp2e", result) - def test_order_independence_matches_go(self): + def test_order_independence_matches_dotnet(self): """Same keys in different insertion order must produce the same hash (mirrors TestCacheKeyComponentHashConsistency in Go).""" h1 = _compute_ext_cache_key({"key3": "value3", "key4": "value4", @@ -627,9 +601,9 @@ def test_at_cache_key_uses_atext_credential_type(self): key = key_maker( home_account_id="hid", environment="env", client_id="cid", realm="realm", target="scope", - ext_cache_key="latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike") + ext_cache_key="bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi") self.assertEqual( - "hid-env-atext-cid-realm-scope-latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike", + "hid-env-atext-cid-realm-scope-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi", key) def test_at_cache_key_without_ext_uses_accesstoken(self): @@ -641,10 +615,9 @@ def test_at_cache_key_without_ext_uses_accesstoken(self): realm="realm", target="scope") self.assertEqual("hid-env-accesstoken-cid-realm-scope", key) - def test_atext_full_at_cache_key_format(self): - """The AT cache key *format* matches MSAL .NET's CacheKeyExtensionTests - layout ('-{env}-atext-{clientId}-{tenant}-{scopes}-{hash}'); only the - trailing hash now follows Go's length-prefixed encoding (see class note). + def test_dotnet_style_full_at_cache_key(self): + """Reproduce the exact cache key from MSAL .NET CacheKeyExtensionTests: + expectedCacheKey1 = '-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi' """ cache = TokenCache() key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] @@ -656,11 +629,11 @@ def test_atext_full_at_cache_key_format(self): realm="common", target="r1/scope1 r1/scope2", ext_cache_key=ext_hash) - expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike" + expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi" self.assertEqual(expected, key) - def test_atext_second_full_at_cache_key_format(self): - """Second key-format vector (mirrors CacheKeyExtensionTests expectedCacheKey2).""" + def test_dotnet_style_second_cache_key(self): + """Reproduce CacheKeyExtensionTests expectedCacheKey2.""" cache = TokenCache() key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] ext_hash = _compute_ext_cache_key({"key3": "value3", "key4": "value4"}) @@ -671,12 +644,13 @@ def test_atext_second_full_at_cache_key_format(self): realm="common", target="r1/scope1 r1/scope2", ext_cache_key=ext_hash) - expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-jjoe9jgfmdtnj0rzuetsqy7kzs2m1xfnjjxwsfxsrxq" + expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u" self.assertEqual(expected, key) def test_go_style_at_cache_key(self): - """Reproduce the Go AccessToken.Key() format with Go's post-#629 hash: - 'testhid-env-atext-clientid-realm-user.read-{hash}'. + """Reproduce the Go AccessToken.Key() *format* (segment layout): + 'testhid-env-atext-clientid-realm-user.read-{hash}'. The hash follows our + .NET-matching encoding (see class note on the Go #629 divergence). """ cache = TokenCache() key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] @@ -688,5 +662,5 @@ def test_go_style_at_cache_key(self): realm="realm", target="user.read", ext_cache_key=ext_hash) - expected = "testhid-env-atext-clientid-realm-user.read-latlwkpewb_a0rcsmjvkecqt0_huumkw4sflzociike" + expected = "testhid-env-atext-clientid-realm-user.read-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi" self.assertEqual(expected, key) From d1648e8a639a8919ad8270f51836a8bd04133a1a Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 1 Jul 2026 12:53:41 -0400 Subject: [PATCH 7/9] Preserve caller-provided data mapping in silent client-claims stash Use kwargs.get("data", {}) instead of the truthiness form in acquire_token_silent and acquire_token_silent_with_error so a caller-provided empty mapping is not replaced with a fresh dict. Matches the fmi_path / for_client sibling sites. Per Copilot review. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- msal/application.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/msal/application.py b/msal/application.py index f8ed857a..1ca58e3e 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1546,7 +1546,7 @@ def acquire_token_silent( if not account: return None # A backward-compatible NO-OP to drop the account=None usage if forwarded_client_claims is not None: - kwargs["data"] = kwargs.get("data") or {} + kwargs["data"] = kwargs.get("data", {}) _stash_client_claims(forwarded_client_claims, kwargs["data"]) result = _clean_up(self._acquire_token_silent_with_error( scopes, account, authority=authority, force_refresh=force_refresh, @@ -1613,7 +1613,7 @@ def acquire_token_silent_with_error( if not account: return None # A backward-compatible NO-OP to drop the account=None usage if forwarded_client_claims is not None: - kwargs["data"] = kwargs.get("data") or {} + kwargs["data"] = kwargs.get("data", {}) _stash_client_claims(forwarded_client_claims, kwargs["data"]) return _clean_up(self._acquire_token_silent_with_error( scopes, account, authority=authority, force_refresh=force_refresh, From a7f73d38159216bfe5c3f2af124414035fc691eb Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 1 Jul 2026 13:06:28 -0400 Subject: [PATCH 8/9] Redact client_claims/claims in TokenCache.add() debug log The forwarded client_claims value (and the merged "claims") are kept in the cache-add event data only for ext_cache_key computation, but TokenCache.add() debug-logs event["data"] and previously masked only password/client_secret/refresh_token/assertion/user_fic. Add "client_claims" and "claims" to the masked fields so DEBUG logging cannot emit raw claim contents. Per Copilot review. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- msal/token_cache.py | 5 +++++ tests/test_token_cache.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/msal/token_cache.py b/msal/token_cache.py index 1907d569..4c6155ba 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -371,6 +371,11 @@ def make_clean_copy(dictionary, sensitive_fields): # Masks sensitive info data=make_clean_copy(event.get("data", {}), ( "password", "client_secret", "refresh_token", "assertion", "user_federated_identity_credential", + # Client-originated claims may carry sensitive values; they are + # kept in data only for ext_cache_key computation, so redact them + # from the debug log (both the cache-key pseudo-param and the + # merged wire parameter). + "client_claims", "claims", )), response=make_clean_copy(event.get("response", {}), ( "id_token_claims", # Provided by broker diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 2a885e32..5ff224d6 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -58,6 +58,25 @@ def setUp(self): self.at_key_maker = self.cache.key_makers[ TokenCache.CredentialType.ACCESS_TOKEN] + def test_add_redacts_client_claims_in_debug_log(self): + # forwarded_client_claims (and the merged "claims") are kept in + # event["data"] only so they contribute to ext_cache_key. They may carry + # sensitive values, so TokenCache.add()'s DEBUG log must redact them. + secret = '{"access_token": {"nbf": {"essential": "SENSITIVE"}}}' + with self.assertLogs("msal.token_cache", level="DEBUG") as cm: + self.cache.add({ + "client_id": "my_client_id", + "scope": ["s1"], + "data": {"client_claims": secret, "claims": secret}, + "token_endpoint": "https://login.example.com/contoso/v2/token", + "response": build_response( + uid="uid", utid="utid", + expires_in=3600, access_token="an access token"), + }, now=1000) + logged = "\n".join(cm.output) + self.assertNotIn("SENSITIVE", logged) + self.assertIn("********", logged) + def testAddByAad(self): client_id = "my_client_id" id_token = build_id_token( From d2887198c23fb5afd377faa60c2ab1aa3b1df8fd Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 1 Jul 2026 13:17:05 -0400 Subject: [PATCH 9/9] Document client_claims as a cache-key-only pseudo-parameter in _compute_ext_cache_key Clarify that pseudo-parameters like client_claims intentionally feed the extended cache key hash while being stripped from the wire, so different client-originated claims route to separate cache entries. Doc-only. Per Copilot review. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- msal/token_cache.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/msal/token_cache.py b/msal/token_cache.py index 4c6155ba..75a14904 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -81,6 +81,11 @@ def _compute_ext_cache_key(data): This ensures tokens acquired with different parameter values (e.g., different FMI paths) are cached separately. + The hash may also intentionally include cache-key-only pseudo-parameters + such as ``client_claims`` -- these are stripped from the wire body by the + oauth2 layer but are retained in *data* precisely so that different + client-originated claims route to separate cache entries. + Returns an empty string when *data* has no hashable fields. The algorithm matches MSAL .NET's ``ComputeAccessTokenExtCacheKey``: sorted