diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index 5a54586..0946472 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -36,6 +36,9 @@ _logger = logging.getLogger(__name__) +ProfileOptions = jax.profiler.ProfileOptions + + class _ProfileState: """Holds the state of an ongoing profiling session. diff --git a/pathwaysutils/test/profiling_test.py b/pathwaysutils/test/profiling_test.py index d6d926b..6b8238f 100644 --- a/pathwaysutils/test/profiling_test.py +++ b/pathwaysutils/test/profiling_test.py @@ -14,8 +14,8 @@ import json import logging -from unittest import mock from typing import Any +from unittest import mock from absl.testing import absltest from absl.testing import parameterized @@ -705,6 +705,9 @@ def test_start_trace_compatibility_error(self): "gs://test_bucket/test_dir", profiler_options=options ) + def test_export_profile_options(self): + self.assertEqual(profiling.ProfileOptions, jax.profiler.ProfileOptions) + if __name__ == "__main__": absltest.main()