Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 133 additions & 64 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,23 @@
# limitations under the License.
"""Profiling Utilities."""

import asyncio
from collections.abc import Mapping
import dataclasses
import concurrent.futures
import datetime
import json
import logging
import os
import threading
import time
from typing import Any
import urllib.parse

import fastapi
import grpc
import jax
from jax import numpy as jnp
from jax.extend import backend
from pathwaysutils import plugin_executable
import requests
import uvicorn
from pathwaysutils.proto import pathways_profiler_pb2
from pathwaysutils.proto import pathways_profiler_pb2_grpc


_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -297,19 +296,27 @@ def start_trace(

_start_pathways_trace_from_profile_request(profile_request)

if jax.version.__version_info__ >= (0, 9, 2):
_original_start_trace(
log_dir=log_dir,
create_perfetto_link=create_perfetto_link,
create_perfetto_trace=create_perfetto_trace,
profiler_options=profiler_options,
)
else:
_original_start_trace(
log_dir=log_dir,
create_perfetto_link=create_perfetto_link,
create_perfetto_trace=create_perfetto_trace,
try:
if jax.version.__version_info__ >= (0, 9, 2):
_original_start_trace(
log_dir=log_dir,
create_perfetto_link=create_perfetto_link,
create_perfetto_trace=create_perfetto_trace,
profiler_options=profiler_options,
)
else:
_original_start_trace(
log_dir=log_dir,
create_perfetto_link=create_perfetto_link,
create_perfetto_trace=create_perfetto_trace,
)
except Exception:
_logger.exception(
"Failed to start JAX local trace, resetting pathways trace state"
)
with _profile_state.lock:
_profile_state.reset()
raise


def stop_trace() -> None:
Expand All @@ -326,54 +333,83 @@ def stop_trace() -> None:
_original_stop_trace()


_profiler_thread: threading.Thread | None = None
_profiler_server: grpc.Server | None = None
_profiler_server_lock = threading.Lock()


def start_server(port: int) -> None:
"""Starts the profiling server on port `port`.
class PathwaysProfilerServicer(
pathways_profiler_pb2_grpc.PathwaysProfilerServiceServicer
):
"""gRPC servicer for Pathways Profiler Service."""

The signature is slightly different from `jax.profiler.start_server`
because no handle to the server is returned because there is no
`xla_client.profiler.ProfilerServer` to return.
def Profile(
self,
request: pathways_profiler_pb2.ProfileRequest,
context: grpc.ServicerContext,
) -> pathways_profiler_pb2.ProfileResponse:
_logger.info("Received gRPC profile request for %s ms", request.duration_ms)
_logger.info("Writing profiling data to %s", request.repository_path)

Args:
port : The port to start the server on.
"""
def server_loop(port: int):
_logger.debug("Starting JAX profiler server on port %s", port)
app = fastapi.FastAPI()

@dataclasses.dataclass
class ProfilingConfig:
duration_ms: int
repository_path: str
try:
# jax.profiler.start_trace is monkey-patched to start pathways trace
jax.profiler.start_trace(request.repository_path)

elapsed = 0.0
duration_secs = request.duration_ms / 1000.0
while elapsed < duration_secs:
if not context.is_active():
_logger.warning("Client disconnected, aborting profile.")
raise RuntimeError("Client disconnected")
time.sleep(0.1)
elapsed += 0.1

except Exception as e: # pylint: disable=broad-exception-caught
_logger.exception("Error during profiling")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(e))
return pathways_profiler_pb2.ProfileResponse(status=f"Failed: {e}")
finally:
_logger.info("Stopping trace")
try:
jax.profiler.stop_trace()
except Exception: # pylint: disable=broad-exception-caught
_logger.exception("Failed to stop trace")

@app.post("/profiling")
async def profiling(pc: ProfilingConfig) -> Mapping[str, str]:
_logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
_logger.debug("Writing profiling data to %s", pc.repository_path)
await asyncio.to_thread(jax.profiler.start_trace, pc.repository_path)
await asyncio.sleep(pc.duration_ms / 1e3)
await asyncio.to_thread(jax.profiler.stop_trace)
return {"response": "profiling completed"}
return pathways_profiler_pb2.ProfileResponse(status="Completed")

uvicorn.run(app, host="0.0.0.0", port=port, log_level="debug")

global _profiler_thread
if _profiler_thread is not None:
raise RuntimeError("Only one profiler server can be active at a time.")
def start_server(port: int) -> None:
"""Starts the profiling server on port `port`.

_profiler_thread = threading.Thread(target=server_loop, args=(port,))
_profiler_thread.start()
Args:
port: The port to start the server on.
"""
global _profiler_server
with _profiler_server_lock:
if _profiler_server is not None:
raise RuntimeError("Only one profiler server can be active at a time.")

_logger.info("Starting JAX pathways profiler gRPC server on port %s", port)
server = grpc.server(concurrent.futures.ThreadPoolExecutor(max_workers=2))
pathways_profiler_pb2_grpc.add_PathwaysProfilerServiceServicer_to_server(
PathwaysProfilerServicer(), server
)
# Use ALTS credentials for secure communication inside Google
server_creds = grpc.alts_server_credentials()
server.add_secure_port(f"[::]:{port}", server_creds)
server.start()
_profiler_server = server


def stop_server() -> None:
"""Raises an error if there is no active profiler server.

Pathways profiling servers are not stoppable at this time.
"""
if _profiler_thread is None:
raise RuntimeError("No active profiler server.")
"""Stops the active profiler server."""
global _profiler_server
with _profiler_server_lock:
if _profiler_server is None:
raise RuntimeError("No active profiler server.")
_logger.info("Stopping JAX pathways profiler gRPC server")
_profiler_server.stop(grace=5.0)
_profiler_server = None


def collect_profile(
Expand All @@ -399,16 +435,49 @@ def collect_profile(
if not str(log_dir).startswith("gs://"):
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")

request_json = {
"duration_ms": duration_ms,
"repository_path": log_dir,
}
address = urllib.parse.urljoin(f"http://{host}:{port}", "profiling")
# Use ALTS credentials for secure client connection
creds = grpc.alts_channel_credentials()
target = f"{host}:{port}"
_logger.info("Connecting to profiling server at %s using ALTS", target)
try:
response = requests.post(address, json=request_json)
response.raise_for_status()
except requests.exceptions.RequestException:
_logger.exception("Failed to collect profiling data")
with grpc.secure_channel(target, creds) as channel:
stub = pathways_profiler_pb2_grpc.PathwaysProfilerServiceStub(channel)
request = pathways_profiler_pb2.ProfileRequest(
duration_ms=duration_ms,
repository_path=str(log_dir),
)
timeout = (duration_ms / 1000.0) + 10.0
_logger.info("Triggering profile for %s ms", duration_ms)
response = stub.Profile(request, timeout=timeout)
_logger.info("Profiling response: %s", response.status)
if "Failed" in response.status:
return False
except grpc.RpcError as e:
e_call: Any = e
if e_call.code() == grpc.StatusCode.UNAVAILABLE:
_logger.error(
"Failed to connect to the profiling server at %s. "
"Please verify that the server is running on this port. "
"Note: If the server is running an older version of pathwaysutils, "
"it may be expecting HTTP (FastAPI) connections instead of gRPC, "
"which is incompatible with this client. "
"Error details: %s",
target,
e_call,
)
elif e_call.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
_logger.error(
"Profiling request timed out. The server might be unresponsive. "
"Error details: %s",
e_call,
)
else:
_logger.error(
"gRPC error occurred while collecting profile. "
"Error code: %s, details: %s",
e_call.code(),
e_call.details(),
)
return False

return True
Expand Down
46 changes: 46 additions & 0 deletions pathwaysutils/proto/pathways_profiler.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package pathwaysutils.profiler;

option java_multiple_files = true;
option java_package = "com.google.pathwaysutils.profiler";
option java_outer_classname = "PathwaysProfilerProto";

// PathwaysProfilerService provides an API to trigger and stop distributed
// profiling traces across Pathways workers.
service PathwaysProfilerService {
// Profiles the Pathways execution for the requested duration.
// The profiling data will be dumped to the specified repository path on GCS.
rpc Profile(ProfileRequest) returns (ProfileResponse) {}
}

message ProfileRequest {
// Duration in milliseconds to collect the profile.
int64 duration_ms = 1;

// The GCS repository path (e.g., gs://my-bucket/profiles) to save the
// profiling data.
string repository_path = 2;

// Optional session ID to group traces.
string session_id = 3;
}

message ProfileResponse {
// Status message indicating the result of the profiling operation.
string status = 1;
}
Loading
Loading