Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 154 additions & 20 deletions msal/application.py

Large diffs are not rendered by default.

88 changes: 84 additions & 4 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from urllib.parse import urlparse # Python 3+
from collections import UserDict # Python 3+
from typing import List, Optional, Union # Needed in Python 3.7 & 3.8
from .token_cache import TokenCache
from .token_cache import TokenCache, _compute_ext_cache_key, _parse_claims_or_raise
from .individual_cache import _IndividualCache as IndividualCache
from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser
from .cloudshell import _is_running_in_cloud_shell
Expand All @@ -26,6 +26,12 @@ class ManagedIdentityError(ValueError):
pass


_CLIENT_CLAIMS_UNSUPPORTED_SOURCE = (
"forwarded_client_claims is only supported for the IMDS (Azure VM) managed identity "
"source. The detected source ({source}) does not support forwarding "
"client-originated claims.")


class ManagedIdentity(UserDict):
"""Feed an instance of this class to :class:`msal.ManagedIdentityClient`
to acquire token for the specified managed identity.
Expand Down Expand Up @@ -261,6 +267,7 @@ def acquire_token_for_client(
*,
resource: str, # If/when we support scope, resource will become optional
claims_challenge: Optional[str] = None,
forwarded_client_claims: Optional[str] = None,
):
"""Acquire token for the managed identity.

Expand All @@ -280,6 +287,21 @@ def acquire_token_for_client(
even if the app developer did not opt in for the "CP1" client capability.
Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token.

:param forwarded_client_claims:
Optional.
A string representation of a JSON object containing
*client-originated* claims to forward to the identity endpoint.

Unlike ``claims_challenge`` (server-issued, which bypasses the cache),
tokens acquired with ``forwarded_client_claims`` **are cached**, and the cache
entry is keyed on the claims value. Send the *same* value on every
request that should share the cached token; different values produce
separate cache entries, so use stable, non-dynamic values to avoid
unbounded cache growth.

Only the IMDS (Azure VM) managed identity source supports this
parameter; other sources raise an error.

.. note::

Known issue: When an Azure VM has only one user-assigned managed identity,
Expand All @@ -294,6 +316,20 @@ def acquire_token_for_client(
client_id_in_cache = self._managed_identity.get(
ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY")
now = time.time()
if forwarded_client_claims is not None:
if not isinstance(forwarded_client_claims, str):
raise ValueError(
"forwarded_client_claims must be a string, got {}".format(
type(forwarded_client_claims).__name__))
_parse_claims_or_raise(forwarded_client_claims) # Fail fast on malformed JSON
# Reject unsupported sources before any cache read, so an unsupported
# source never returns a cached client-claims token.
_raise_if_claims_unsupported_source()
# Client-originated claims isolate the cache: a distinct claims value gets
# a distinct cache entry. (Server-issued claims_challenge, by contrast,
# bypasses the cache and is keyed normally.)
ext_cache_key = _compute_ext_cache_key(
{"client_claims": forwarded_client_claims}) if forwarded_client_claims else None
if True: # Attempt cache search even if receiving claims_challenge,
# because we want to locate the existing token (if any) and refresh it
matches = self._token_cache.search(
Expand All @@ -304,6 +340,7 @@ def acquire_token_for_client(
environment=self.__instance,
realm=self._tenant,
home_account_id=None,
**({"ext_cache_key": ext_cache_key} if ext_cache_key else {}),
),
)
for entry in matches:
Expand Down Expand Up @@ -334,6 +371,7 @@ def acquire_token_for_client(
access_token_to_refresh.encode("utf-8")).hexdigest()
if access_token_to_refresh else None,
client_capabilities=self._client_capabilities,
client_claims=forwarded_client_claims,
)
if "access_token" in result:
expires_in = result.get("expires_in", 3600)
Expand All @@ -346,7 +384,7 @@ def acquire_token_for_client(
self.__instance, self._tenant),
response=result,
params={},
data={},
data={"client_claims": forwarded_client_claims} if forwarded_client_claims else {},
))
if "refresh_in" in result:
result["refresh_on"] = int(now + result["refresh_in"])
Expand Down Expand Up @@ -409,15 +447,42 @@ def get_managed_identity_source():
return DEFAULT_TO_VM


# Managed-identity sources that cannot forward client-originated claims. Keep in
# sync with the per-source guards inside _obtain_token (the backstop). Cloud Shell
# is intentionally absent: it falls through to the Azure VM / IMDS path, which
# does support claims.
_CLIENT_CLAIMS_UNSUPPORTED_SOURCES = {
SERVICE_FABRIC: "Service Fabric",
APP_SERVICE: "App Service",
MACHINE_LEARNING: "Machine Learning",
AZURE_ARC: "Azure Arc",
}


def _raise_if_claims_unsupported_source():
"""Fail fast -- before any cache read -- when the detected managed-identity
source cannot forward client-originated claims. ``_obtain_token`` enforces the
same rule per source as a backstop, but validating up front avoids a cache
lookup (and returning a cached token) for an unsupported source."""
name = _CLIENT_CLAIMS_UNSUPPORTED_SOURCES.get(get_managed_identity_source())
if name:
raise ManagedIdentityError(
_CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source=name))


def _obtain_token(
http_client, managed_identity, resource,
*,
access_token_sha256_to_refresh: Optional[str] = None,
client_capabilities: Optional[List[str]] = None,
client_claims: Optional[str] = None,
):
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
):
if client_claims:
raise ManagedIdentityError(
_CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Service Fabric"))
if managed_identity:
logger.debug(
"Ignoring managed_identity parameter. "
Expand All @@ -434,6 +499,9 @@ def _obtain_token(
client_capabilities=client_capabilities,
)
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
if client_claims:
raise ManagedIdentityError(
_CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="App Service"))
return _obtain_token_on_app_service(
http_client,
os.environ["IDENTITY_ENDPOINT"],
Expand All @@ -442,6 +510,9 @@ def _obtain_token(
resource,
)
if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ:
if client_claims:
raise ManagedIdentityError(
_CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Machine Learning"))
# Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py
return _obtain_token_on_machine_learning(
http_client,
Expand All @@ -452,14 +523,18 @@ def _obtain_token(
)
arc_endpoint = _get_arc_endpoint()
if arc_endpoint:
if client_claims:
raise ManagedIdentityError(
_CLIENT_CLAIMS_UNSUPPORTED_SOURCE.format(source="Azure Arc"))
if ManagedIdentity.is_user_assigned(managed_identity):
raise ManagedIdentityError( # Note: Azure Identity for Python raised exception too
"Invalid managed_identity parameter. "
"Azure Arc supports only system-assigned managed identity, "
"See also "
"https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service")
return _obtain_token_on_arc(http_client, arc_endpoint, resource)
return _obtain_token_on_azure_vm(http_client, managed_identity, resource)
return _obtain_token_on_azure_vm(
http_client, managed_identity, resource, client_claims=client_claims)


def _adjust_param(params, managed_identity, types_mapping=None):
Expand All @@ -469,14 +544,19 @@ def _adjust_param(params, managed_identity, types_mapping=None):
if id_name:
params[id_name] = managed_identity[ManagedIdentity.ID]

def _obtain_token_on_azure_vm(http_client, managed_identity, resource):

def _obtain_token_on_azure_vm(http_client, managed_identity, resource, client_claims=None):
# Based on https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
logger.debug("Obtaining token via managed identity on Azure VM")
params = {
"api-version": "2018-02-01",
"resource": resource,
}
_adjust_param(params, managed_identity)
if client_claims:
# Forward client-originated claims as-is; IMDS decides which keys it
# accepts (no client-side allow-list, matching the other MSALs).
params["claims"] = client_claims # http_client.get url-encodes query params
resp = http_client.get(
os.getenv(
"AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.254"
Expand Down
8 changes: 8 additions & 0 deletions msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,14 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
_data.update(data or {}) # So the content in data param prevails
_data = {k: v for k, v in _data.items() if v} # Clean up None values

# "client_claims" is a cache-key-only pseudo-parameter: callers merge its
# value into the standard "claims" body parameter upstream, and it is kept
# in the request data solely so it contributes to the extended cache key.
# It must not be sent on the wire. Popping it here (from this method's own
# local copy) keeps the wire body clean while the caller's data dict — used
# for the cache-add event — still carries it.
_data.pop("client_claims", None)

if _data.get('scope'):
_data['scope'] = self._stringify(_data['scope'])

Expand Down
77 changes: 73 additions & 4 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,15 @@ def _compute_ext_cache_key(data):

Returns an empty string when *data* has no hashable fields.

The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator):
sorted key+value pairs are concatenated and SHA256 hashed, then base64url encoded.
The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator,
post-collision-fix): length-prefixed key/value pairs (sorted by key) are
concatenated and SHA256 hashed, then base64url encoded. The length prefixes
make the encoding injective, so two distinct component sets can never collide
onto the same cache key. (MSAL .NET's ``ComputeAccessTokenExtCacheKey`` still
uses an unprefixed concatenation, so the hash is intentionally not
byte-identical to current .NET; the cache *key format* still matches both.
Caches are not shared across languages, so this only affects within-process
isolation, where injectivity is what matters.)
"""
if not data:
return ""
Expand All @@ -94,14 +101,76 @@ def _compute_ext_cache_key(data):
}
if not cache_components:
return ""
# Sort keys for consistent hashing (matches Go implementation)
# Concatenate length-prefixed key/value pairs so component boundaries are
# unambiguous (matches Go's CacheExtKeyGenerator). A plain key+value
# concatenation with no separators can collide when one value happens to
# contain another component's key or value -- and client_claims is arbitrary
# caller-supplied JSON that may embed e.g. "fmi_path" at a boundary -- mapping
# two distinct component sets onto the same hash and returning the wrong
# cached token. Length prefixes make the encoding injective.
key_str = "".join(
k + cache_components[k] for k in sorted(cache_components.keys())
"{}:{}{}:{}".format(len(k), k, len(v), v)
for k, v in sorted(cache_components.items())
)
hash_bytes = hashlib.sha256(key_str.encode("utf-8")).digest()
return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower()


def _parse_claims_or_raise(claims):
"""Parse a claims JSON string into a dict, or raise a friendly ``ValueError``.

The raw claims value is never included in the error message because it may
contain sensitive data. Mirrors MSAL .NET's ``ClaimsHelper.ParseClaimsOrThrow``.
"""
try:
parsed = json.loads(claims)
except (ValueError, TypeError) as ex:
# json.JSONDecodeError (malformed JSON) is a subclass of ValueError;
# TypeError is raised when *claims* is not a str/bytes/bytearray. Both
# are surfaced as the same friendly ValueError so every caller behaves
# consistently regardless of the bad input's type.
raise ValueError(
"The claims value is not valid JSON. "
"See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter."
) from ex
if not isinstance(parsed, dict):
# A valid JSON array, scalar, or the literal "null" is not a claims object.
raise ValueError(
"The claims value is not a valid JSON object. "
"See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter.")
return parsed


def _deep_merge_dict(base, overlay):
"""Recursively merge ``overlay`` into ``base``, returning a new dict.

Nested dicts are merged; for any other value type, ``overlay`` wins.
"""
result = dict(base)
for key, value in overlay.items():
if (key in result
and isinstance(result[key], dict) and isinstance(value, dict)):
result[key] = _deep_merge_dict(result[key], value)
else:
result[key] = value
return result


def _merge_claims(claims_a, claims_b):
"""Merge two claims JSON strings into a single JSON string.

If either side is empty/None, the other is returned as-is. Mirrors MSAL
.NET's ``ClaimsHelper.MergeClaimsObjects``.
"""
if not claims_a:
return claims_b
if not claims_b:
return claims_a
merged = _deep_merge_dict(
_parse_claims_or_raise(claims_a), _parse_claims_or_raise(claims_b))
return json.dumps(merged)


def is_subdict_of(small, big):
return dict(big, **small) == big

Expand Down
Loading
Loading