Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,16 +329,18 @@ def stop_trace() -> None:
_profiler_thread: threading.Thread | None = None


def start_server(port: int) -> None:
def start_server(port: int, requires_backend: bool = True) -> None:
"""Starts the profiling server on port `port`.

The signature is slightly different from `jax.profiler.start_server`
because no handle to the server is returned because there is no
The signature matches `jax.profiler.start_server`, though no handle
to the server is returned because there is no
`xla_client.profiler.ProfilerServer` to return.

Args:
port : The port to start the server on.
port: The port to start the server on.
requires_backend: Unused in Pathways; accepted for parameter parity.
"""
del requires_backend
def server_loop(port: int):
_logger.debug("Starting JAX profiler server on port %s", port)
app = fastapi.FastAPI()
Expand Down Expand Up @@ -455,11 +457,11 @@ def stop_trace_patch() -> None:
jax.profiler.stop_trace = stop_trace_patch
jax._src.profiler.stop_trace = stop_trace_patch # pylint: disable=protected-access

def start_server_patch(port: int) -> None:
def start_server_patch(port: int, requires_backend: bool = True) -> None:
_logger.debug(
"jax.profile.start_server patched with pathways' start_server"
)
start_server(port)
start_server(port, requires_backend=requires_backend)

jax.profiler.start_server = start_server_patch
jax._src.profiler.start_server = start_server_patch # pylint: disable=protected-access
Expand Down
4 changes: 2 additions & 2 deletions pathwaysutils/test/profiling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,9 @@ def test_monkey_patched_stop_trace(self, profiler_module):
def test_monkey_patched_start_server(self, profiler_module):
mocks = self._setup_monkey_patch()

profiler_module.start_server(1234)
profiler_module.start_server(1234, requires_backend=False)

mocks["start_server"].assert_called_once_with(1234)
mocks["start_server"].assert_called_once_with(1234, requires_backend=False)

@parameterized.named_parameters(
dict(testcase_name="jax_profiler", profiler_module=jax.profiler),
Expand Down
Loading