Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions temporalio/client/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,17 @@ class ActivityExecutionDescription(ActivityExecution):
long_poll_token: bytes | None
"""Token for follow-on long-poll requests. None if the activity is complete."""

raw_callbacks: Sequence[temporalio.api.activity.v1.CallbackInfo]
"""Underlying protobuf callbacks"""

@classmethod
async def _from_execution_info(
cls,
info: temporalio.api.activity.v1.ActivityExecutionInfo,
long_poll_token: bytes | None,
namespace: str,
data_converter: temporalio.converter.DataConverter,
callbacks: Sequence[temporalio.api.activity.v1.CallbackInfo],
) -> Self:
"""Create from raw proto activity execution info."""
# Decode heartbeat details if present
Expand Down Expand Up @@ -409,6 +413,7 @@ async def _from_execution_info(
typed_search_attributes=temporalio.converter.decode_typed_search_attributes(
info.search_attributes
),
raw_callbacks=callbacks,
)


Expand Down Expand Up @@ -691,6 +696,8 @@ def __init__(
*,
run_id: str | None = None,
result_type: type | None = None,
start_activity_response: None
| temporalio.api.workflowservice.v1.StartActivityExecutionResponse = None,
) -> None:
"""Create activity handle."""
self._client = client
Expand All @@ -700,6 +707,7 @@ def __init__(
self._known_outcome: (
temporalio.api.activity.v1.ActivityExecutionOutcome | None
) = None
self._start_activity_response = start_activity_response

@functools.cached_property
def _data_converter(self) -> temporalio.converter.DataConverter:
Expand Down
9 changes: 9 additions & 0 deletions temporalio/client/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,6 +1484,12 @@ async def start_activity(
start_delay: timedelta | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
# The following options should not be considered part of the public API. They
# are deliberately not exposed in overloads, and are not subject to any
# backwards compatibility guarantees.
callbacks: Sequence[Callback] = [],
links: Sequence[temporalio.api.common.v1.Link] = [],
request_id: str | None = None,
) -> ActivityHandle[ReturnType]:
"""Start an activity and return its handle.

Expand Down Expand Up @@ -1542,6 +1548,9 @@ async def start_activity(
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
priority=priority,
callbacks=callbacks,
links=links,
request_id=request_id,
)
)

Expand Down
23 changes: 22 additions & 1 deletion temporalio/client/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ async def _build_start_workflow_execution_request(
# Links are duplicated on request for compatibility with older server versions.
req.links.extend(links)

if temporalio.nexus._operation_context._in_nexus_backing_workflow_start_context():
if temporalio.nexus._operation_context._in_nexus_backing_start_context():
req.on_conflict_options.attach_request_id = True
req.on_conflict_options.attach_completion_callbacks = True
req.on_conflict_options.attach_links = True
Expand Down Expand Up @@ -567,6 +567,7 @@ async def start_activity(self, input: StartActivityInput) -> ActivityHandle[Any]
input.id,
run_id=resp.run_id,
result_type=input.result_type,
start_activity_response=resp,
)

async def _build_start_activity_execution_request(
Expand Down Expand Up @@ -610,6 +611,8 @@ async def _build_start_activity_execution_request(
),
)

if input.request_id:
req.request_id = input.request_id
if input.schedule_to_close_timeout is not None:
req.schedule_to_close_timeout.FromTimedelta(input.schedule_to_close_timeout)
if input.start_to_close_timeout is not None:
Expand Down Expand Up @@ -645,6 +648,23 @@ async def _build_start_activity_execution_request(
# Set priority
req.priority.CopyFrom(input.priority._to_proto())

req.completion_callbacks.extend(
temporalio.api.common.v1.Callback(
nexus=temporalio.api.common.v1.Callback.Nexus(
url=callback.url,
header=callback.headers,
),
links=input.links,
)
for callback in input.callbacks
)
req.links.extend(input.links)

if temporalio.nexus._operation_context._in_nexus_backing_start_context():
req.on_conflict_options.attach_request_id = True
req.on_conflict_options.attach_completion_callbacks = True
req.on_conflict_options.attach_links = True

return req

async def cancel_activity(self, input: CancelActivityInput) -> None:
Expand Down Expand Up @@ -708,6 +728,7 @@ async def describe_activity(
is_local=False,
)
),
callbacks=resp.callbacks,
)

def list_activities(
Expand Down
4 changes: 4 additions & 0 deletions temporalio/client/_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ class StartActivityInput:
headers: Mapping[str, temporalio.api.common.v1.Payload]
rpc_metadata: Mapping[str, str | bytes]
rpc_timeout: timedelta | None
# The following options are experimental and unstable.
callbacks: Sequence[Callback]
links: Sequence[temporalio.api.common.v1.Link]
request_id: str | None


@dataclass
Expand Down
2 changes: 2 additions & 0 deletions temporalio/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
wait_for_worker_shutdown_sync,
)
from ._operation_handlers import (
CancelActivityOptions,
CancelWorkflowRunOptions,
TemporalOperationHandler,
)
Expand All @@ -33,6 +34,7 @@

__all__ = (
"workflow_run_operation",
"CancelActivityOptions",
"CancelWorkflowRunOptions",
"Info",
"LoggerAdapter",
Expand Down
85 changes: 56 additions & 29 deletions temporalio/nexus/_link_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
logger = logging.getLogger(__name__)

_NEXUS_OPERATION_LINK_URL_PATH_REGEX = re.compile(
r"^/namespaces/(?P<namespace>[^/]+)/nexus-operations/(?P<operation_id>[^/]+)$"
r"^/namespaces/(?P<namespace>[^/]+)/nexus-operations/(?P<operation_id>[^/]+)/(?P<run_id>[^/]+)/details$"
)

_ACTIVITY_LINK_URL_PATH_REGEX = re.compile(
r"^/namespaces/(?P<namespace>[^/]+)/activities/(?P<activity_id>[^/]+)/(?P<run_id>[^/]+)/details$"
)

_WORFKLOW_LINK_URL_PATH_REGEX = re.compile(
Expand All @@ -31,13 +35,13 @@
class _LinkType(str, Enum):
WORKFLOW = temporalio.api.common.v1.Link.WorkflowEvent.DESCRIPTOR.full_name
NEXUS_OPERATION = temporalio.api.common.v1.Link.NexusOperation.DESCRIPTOR.full_name
ACTIVITY = temporalio.api.common.v1.Link.Activity.DESCRIPTOR.full_name


LINK_EVENT_ID_PARAM_NAME = "eventID"
LINK_EVENT_TYPE_PARAM_NAME = "eventType"
LINK_REQUEST_ID_PARAM_NAME = "requestID"
LINK_REFERENCE_TYPE_PARAM_NAME = "referenceType"
LINK_RUN_ID_PARAM_NAME = "runID"

EVENT_REFERENCE_TYPE = "EventReference"
REQUEST_ID_REFERENCE_TYPE = "RequestIdReference"
Expand Down Expand Up @@ -84,6 +88,9 @@ def nexus_link_to_temporal_link(
case _LinkType.NEXUS_OPERATION:
return nexus_link_to_nexus_operation_link(nexus_link)

case _LinkType.ACTIVITY:
return nexus_link_to_activity_link(nexus_link)


def temporal_link_to_nexus_link(
temporal_link: temporalio.api.common.v1.Link,
Expand All @@ -92,16 +99,20 @@ def temporal_link_to_nexus_link(

Returns None when the Temporal link variant is missing.
"""
match temporal_link.WhichOneof("variant"):
variant = temporal_link.WhichOneof("variant")
match variant:
case "workflow_event":
return workflow_event_to_nexus_link(temporal_link.workflow_event)

case "nexus_operation":
return nexus_operation_to_nexus_link(temporal_link.nexus_operation)

case "activity" | "batch_job" | "workflow":
case "activity":
return activity_link_to_nexus_link(temporal_link.activity)

case "batch_job" | "workflow":

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just future proofing - if this comes through with a unknown value, it will be unhandled - should this be case _ and come after the none instead of specifying two values?

Also you might put the case value in the error message, it would be nice to know when debugging what value was actually recieved.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern matches against the generated proto files and is future proof in that our linting (and CI) will fail if a new unhandled variant is added.

Will add the case value to the error message for an unsupported type though!

raise NotImplementedError(
"only workflow_event and nexus operation links are supported"
f"only workflow_event, activity and nexus_operation links are supported, got {variant}"
)

case None:
Expand Down Expand Up @@ -151,22 +162,30 @@ def nexus_operation_to_nexus_link(
scheme = "temporal"
namespace = urllib.parse.quote(op_link.namespace, safe="")
operation_id = urllib.parse.quote(op_link.operation_id, safe="")
path = f"/namespaces/{namespace}/nexus-operations/{operation_id}"

query_params = ""
if op_link.run_id:
query_params = urllib.parse.urlencode(
{
LINK_RUN_ID_PARAM_NAME: op_link.run_id,
},
)
run_id = urllib.parse.quote(op_link.run_id, safe="")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be an error check for namespace, operation_id, and run_id to make sure they got values? Or are we certain these will always be present and so no need to check? Or is is just fine and we'll get a bad URL which will give us a reasonable error later?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These fields are always populated by the server when they come via proto so safe to assume they're there.

If a bug ever existed where they were not populated, they'd wind up as empty strings. If they wind up as empty the server validation will reject appropriately when they're sent along.

path = f"/namespaces/{namespace}/nexus-operations/{operation_id}/{run_id}/details"

# urllib will omit '//' from the url if netloc is empty so we add the scheme manually
url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', query_params, ''))}"
url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', '', ''))}"

return nexusrpc.Link(url=url, type=_LinkType.NEXUS_OPERATION.value)


def activity_link_to_nexus_link(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any failure handling in this - if urlunparse fails for example, is everything logged and handled in that call?

In looking for failure handling I looked for some unit tests on this method and didn't see any (though I did for nexus_link_to_activity_link). Should there be some tests?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

urlunparse only raises when invalid arguments are supplied so in this case where we constructing the input tuple from known string components we're safe. In terms of logging, should the conversion ever raise, the usage in _operation_context.py catches the conversion error and logs a warning that includes the error. This function is used indirectly via temporal_link_to_nexus_link.

This function is exercised in test_link_conversion_nexus_link_to_activity_link where there are a few conversions between nexusrpc.Link <--> proto links for the activity flavor.

activity: temporalio.api.common.v1.Link.Activity,
) -> nexusrpc.Link:
"""Convert an Activity link into a nexusrpc link."""
scheme = "temporal"
namespace = urllib.parse.quote(activity.namespace, safe="")
activity_id = urllib.parse.quote(activity.activity_id, safe="")
run_id = urllib.parse.quote(activity.run_id, safe="")
path = f"/namespaces/{namespace}/activities/{activity_id}/{run_id}/details"

url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', '', ''))}"

return nexusrpc.Link(url=url, type=_LinkType.ACTIVITY.value)


def nexus_link_to_workflow_event_link(
link: nexusrpc.Link,
) -> temporalio.api.common.v1.Link | None:
Expand Down Expand Up @@ -230,28 +249,36 @@ def nexus_link_to_nexus_operation_link(
)
return None

query_params = urllib.parse.parse_qs(url.query)

match query_params.get(LINK_RUN_ID_PARAM_NAME):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this rolls out, is there any chance links previously serialized with runId as a parameter will be hit? Should this be able to deserialize those to avoid a migration issue? Otherwise things might go out of sync and runId would get lost here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fortunately there are no versions of the server that accept nexus-operation links but use the runId as a parameter.

case [run_id_param]:
run_id = run_id_param
case [] | None:
run_id = ""
case _:
logger.warning(
f"Invalid Nexus link: {nexus_link}. Expected {LINK_RUN_ID_PARAM_NAME} to have at most 1 value"
)
return None

groups = match.groupdict()
nexus_op_link = temporalio.api.common.v1.Link.NexusOperation(
namespace=urllib.parse.unquote(groups["namespace"]),
operation_id=urllib.parse.unquote(groups["operation_id"]),
run_id=run_id,
run_id=urllib.parse.unquote(groups["run_id"]),
)
return temporalio.api.common.v1.Link(nexus_operation=nexus_op_link)


def nexus_link_to_activity_link(
nexus_link: nexusrpc.Link,
) -> temporalio.api.common.v1.Link | None:
"""Convert a nexus link into a Temporal Activity link."""
url = urllib.parse.urlparse(nexus_link.url)
match = _ACTIVITY_LINK_URL_PATH_REGEX.match(url.path)
if not match:
logger.warning(
f"Invalid Nexus link: {nexus_link}. Expected path to match {_ACTIVITY_LINK_URL_PATH_REGEX.pattern}"
)
return None

groups = match.groupdict()
activity_link = temporalio.api.common.v1.Link.Activity(
namespace=urllib.parse.unquote(groups["namespace"]),
activity_id=urllib.parse.unquote(groups["activity_id"]),
run_id=urllib.parse.unquote(groups["run_id"]),
)
return temporalio.api.common.v1.Link(activity=activity_link)


def _event_reference_to_query_params(
event_ref: temporalio.api.common.v1.Link.WorkflowEvent.EventReference,
) -> str:
Expand Down
Loading
Loading