Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 183 additions & 18 deletions pathwaysutils/experimental/gke/jobset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pathways JobSet generator and builder (Head Job Config)."""
"""Pathways JobSet generator and builder (with Worker Job Config)."""

import json
import logging
import math
from typing import Any, Mapping
from kubernetes import client

Expand All @@ -33,6 +34,7 @@

PATHWAYS_PROXY_PORT = 29000
PATHWAYS_RM_PORT = 29001
PATHWAYS_WORKER_PORT = 29005

MACHINE_TYPE_TO_TPU_VERSION_MAP = {
"tpu7x-standard-4t": "tpu7x",
Expand Down Expand Up @@ -77,7 +79,7 @@ def __init__(self, data):


class PathwaysJobSet:
"""Generates JobSet configuration for Pathways (with Head Job Config)."""
"""JobSet configuration generator for Pathways."""

def __init__(
self,
Expand All @@ -90,6 +92,8 @@ def __init__(
user_pod_template: Mapping[str, Any] | None = None,
main_container_name: str = "main",
max_restarts: int = 0,
max_slice_restarts: int = 0,
termination_grace_period_seconds: int | None = None,
pathways_version: str = "latest",
jobset_api_version: str = "v1alpha2",
elastic_slices: int = 0,
Expand All @@ -108,6 +112,8 @@ def __init__(
user_pod_template: Optional user pod template for the head job.
main_container_name: Name of the main container in user_pod_template.
max_restarts: Maximum number of restarts for the JobSet.
max_slice_restarts: Maximum number of slice restarts.
termination_grace_period_seconds: Optional termination grace period.
pathways_version: Version tag for Pathways images.
jobset_api_version: API version of JobSet.
elastic_slices: Number of elastic slices.
Expand All @@ -126,6 +132,19 @@ def __init__(
if not tpu_version:
raise ValueError(f"Unsupported TPU type: {tpu_type}")

gke_accel_type = MACHINE_TYPE_TO_GKE_ACCELERATOR_TYPE_MAP.get(
tpu_type.lower()
)

# Calculate VMs.
dims = [int(x) for x in topology.split("x")]
total_chips = math.prod(dims)
chips_per_vm = 8 if tpu_type.lower().endswith("8t") else 4
if total_chips < chips_per_vm:
num_vms = 1
else:
num_vms = total_chips // chips_per_vm

instance_type = f"{tpu_version}:{topology}"
image_tag = pathways_version

Expand All @@ -140,8 +159,17 @@ def __init__(
elastic_slices=elastic_slices,
)

# Build minimal worker template (placeholder)
self._worker_job_template = self._build_minimal_job_template("worker")
# Build worker template.
self._worker_job_template = self._build_worker_job_template(
pathways_dir=pathways_dir,
num_vms=num_vms,
chips_per_vm=chips_per_vm,
gke_accel_type=gke_accel_type,
topology=topology,
image_tag=image_tag,
max_slice_restarts=max_slice_restarts,
termination_grace_period_seconds=termination_grace_period_seconds,
)

self._success_policy = None
if user_pod_template:
Expand All @@ -150,20 +178,6 @@ def __init__(
"targetReplicatedJobs": [PATHWAYS_HEAD_JOB_NAME],
}

def _build_minimal_job_template(self, role: str) -> client.V1JobTemplateSpec:
  """Returns a placeholder job template whose single pod is tagged by role.

  Args:
    role: Used as the pod's "role" label value and as the suffix of the
      placeholder container's name.

  Returns:
    A V1JobTemplateSpec wrapping one container that uses the "ubuntu" image.
  """
  placeholder = client.V1Container(name=f"placeholder-{role}", image="ubuntu")
  placeholder_pod = client.V1PodSpec(containers=[placeholder])
  pod_template = client.V1PodTemplateSpec(
      metadata=client.V1ObjectMeta(labels={"role": role}),
      spec=placeholder_pod,
  )
  return client.V1JobTemplateSpec(
      spec=client.V1JobSpec(template=pod_template)
  )

def _build_head_job_template(
self,
pathways_dir: str,
Expand Down Expand Up @@ -365,6 +379,157 @@ def _build_head_job_template(
)
return head_job_template

def _build_worker_job_template(
    self,
    pathways_dir: str,
    num_vms: int,
    chips_per_vm: int,
    gke_accel_type: str,
    topology: str,
    image_tag: str,
    max_slice_restarts: int,
    termination_grace_period_seconds: int | None,
) -> client.V1JobTemplateSpec:
  """Builds the Job template for the Pathways worker replicated job.

  The resulting Indexed Job runs `num_vms` pods in parallel. Each pod has one
  "pathways-worker" container that is limited to `chips_per_vm` TPU chips and
  is scheduled onto nodes matching the given accelerator type and topology.

  Args:
    pathways_dir: Value passed to the worker as --gcs_scratch_location.
    num_vms: Number of worker pods per job; used for both `completions` and
      `parallelism`, and as the multiplier for the backoff limit.
    chips_per_vm: Per-container "google.com/tpu" resource limit.
    gke_accel_type: Value of the "cloud.google.com/gke-tpu-accelerator" node
      selector.
    topology: Value of the "cloud.google.com/gke-tpu-topology" node selector.
    image_tag: Tag appended to DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE.
    max_slice_restarts: When positive, the backoff limit is
      `num_vms * max_slice_restarts`; otherwise it is `num_vms * 4`.
    termination_grace_period_seconds: Pod termination grace period; left
      unset on the pod spec when None.

  Returns:
    The V1JobTemplateSpec for the worker job.

  Raises:
    ValueError: If `gke_accel_type` is empty or None.
  """
  # The value is written verbatim into the node selector, so a missing
  # accelerator type would otherwise be emitted as a null selector value.
  if not gke_accel_type:
    raise ValueError(
        "A GKE accelerator type is required to build the worker node"
        " selector."
    )

  def _pod_field_env(name: str, field_path: str) -> client.V1EnvVar:
    """Returns an env var populated from the pod's own metadata field."""
    return client.V1EnvVar(
        name=name,
        value_from=client.V1EnvVarSource(
            field_ref=client.V1ObjectFieldSelector(field_path=field_path)
        ),
    )

  worker_image = f"{DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE}:{image_tag}"

  # $(PATHWAYS_HEAD) refers to the PATHWAYS_HEAD env var declared below.
  args = [
      f"--resource_manager_address=$(PATHWAYS_HEAD):{PATHWAYS_RM_PORT}",
      f"--server_port={PATHWAYS_WORKER_PORT}",
      f"--gcs_scratch_location={pathways_dir}",
  ]

  # Both the head address and the coordinator address are read from the same
  # pod label.
  coordinator_field = "metadata.labels['jobset.sigs.k8s.io/coordinator']"
  worker_env = [
      client.V1EnvVar(name="TPU_MIN_LOG_LEVEL", value="0"),
      client.V1EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="0"),
      client.V1EnvVar(name="XCLOUD_ENVIRONMENT", value="GCP"),
      client.V1EnvVar(name="MEGASCALE_GRPC_ENABLE_XOR_TRACER", value="false"),
      _pod_field_env(
          "MEGASCALE_NUM_SLICES",
          "metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas']",
      ),
      _pod_field_env(
          "JOBSET_NAME",
          "metadata.annotations['jobset.sigs.k8s.io/jobset-name']",
      ),
      _pod_field_env(
          "REPLICATED_JOB_NAME",
          "metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']",
      ),
      _pod_field_env(
          "MEGASCALE_SLICE_ID",
          "metadata.labels['jobset.sigs.k8s.io/job-index']",
      ),
      _pod_field_env("PATHWAYS_HEAD", coordinator_field),
      _pod_field_env("MEGASCALE_COORDINATOR_ADDRESS", coordinator_field),
  ]

  # NOTE(review): the roles of ports 29006, 8471 and 8080 are not visible in
  # this file; they are kept exactly as before.
  worker_ports = [
      client.V1ContainerPort(
          container_port=PATHWAYS_WORKER_PORT, protocol="TCP"
      ),
      client.V1ContainerPort(container_port=29006, protocol="TCP"),
      client.V1ContainerPort(container_port=8471, protocol="TCP"),
      client.V1ContainerPort(container_port=8080, protocol="TCP"),
  ]

  worker_container = client.V1Container(
      name="pathways-worker",
      image=worker_image,
      image_pull_policy="Always",
      args=args,
      env=worker_env,
      ports=worker_ports,
      volume_mounts=[
          client.V1VolumeMount(name="shared-tmp", mount_path="/tmp")
      ],
      resources=client.V1ResourceRequirements(
          limits={"google.com/tpu": str(chips_per_vm)}
      ),
  )

  node_selector = {
      "cloud.google.com/gke-tpu-accelerator": gke_accel_type,
      "cloud.google.com/gke-tpu-topology": topology,
  }

  # Non-positive max_slice_restarts falls back to 4 retries per VM.
  retries_per_vm = max_slice_restarts if max_slice_restarts > 0 else 4
  backoff_limit = num_vms * retries_per_vm

  worker_pod_spec = client.V1PodSpec(
      containers=[worker_container],
      node_selector=node_selector,
      volumes=[
          client.V1Volume(
              name="shared-tmp",
              host_path=client.V1HostPathVolumeSource(
                  path="/tmp", type="DirectoryOrCreate"
              ),
          )
      ],
      host_network=True,
      dns_policy="ClusterFirstWithHostNet",
      restart_policy="OnFailure",
  )
  if termination_grace_period_seconds is not None:
    worker_pod_spec.termination_grace_period_seconds = (
        termination_grace_period_seconds
    )

  pod_template = client.V1PodTemplateSpec(
      metadata=client.V1ObjectMeta(
          annotations={
              "alpha.jobset.sigs.k8s.io/exclusive-topology": (
                  "cloud.google.com/gke-nodepool"
              )
          }
      ),
      spec=worker_pod_spec,
  )
  return client.V1JobTemplateSpec(
      metadata=client.V1ObjectMeta(),
      spec=client.V1JobSpec(
          backoff_limit=backoff_limit,
          completion_mode="Indexed",
          completions=num_vms,
          parallelism=num_vms,
          template=pod_template,
      ),
  )

def _compile_config(self) -> dict[str, Any]:
"""Compiles the JobSet configuration into a dictionary."""
with client.ApiClient() as api_client:
Expand Down
75 changes: 74 additions & 1 deletion pathwaysutils/test/experimental/gke/jobset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,83 @@ def test_non_headless_head_job(self):
def test_monkeypatch_restart_policy(self):
  """Checks that V1Container accepts and stores a `restart_policy` keyword."""
  # Construct V1Container with restart_policy to test monkeypatch.
  # Each keyword is passed exactly once; repeating a keyword argument in a
  # single call is a SyntaxError.
  c = client.V1Container(
      name="test",
      restart_policy="Always",
  )  # pytype: disable=wrong-keyword-args
  self.assertEqual(getattr(c, "restart_policy"), "Always")

def test_worker_job(self):
  """Worker job reflects slice count, VM sizing, restarts and pod settings."""
  config = jobset.PathwaysJobSet(
      name="test-jobset",
      namespace="default",
      pathways_dir="gs://test-bucket",
      tpu_type="v5e",
      topology="4x8",
      num_slices=2,
      max_slice_restarts=3,
      termination_grace_period_seconds=60,
  ).to_dict()

  worker = next(
      rj
      for rj in config["spec"]["replicatedJobs"]
      if rj["name"] == "pathways-worker"
  )
  # The replica count is expected to equal num_slices.
  self.assertEqual(worker["replicas"], 2)

  # 4x8 v5e topology has 32 chips. v5e has 4 chips per VM.
  # Total VMs = 32 / 4 = 8 VMs.
  job_spec = worker["template"]["spec"]
  for sizing_field in ("completions", "parallelism"):
    self.assertEqual(job_spec[sizing_field], 8)
  # backoffLimit = num_vms * max_slice_restarts = 8 * 3 = 24
  self.assertEqual(job_spec["backoffLimit"], 24)

  pod_spec = job_spec["template"]["spec"]
  self.assertTrue(pod_spec["hostNetwork"])
  for pod_field, expected in (
      ("dnsPolicy", "ClusterFirstWithHostNet"),
      ("restartPolicy", "OnFailure"),
      ("terminationGracePeriodSeconds", 60),
  ):
    self.assertEqual(pod_spec[pod_field], expected)

  # Node selector
  selector = pod_spec["nodeSelector"]
  self.assertEqual(
      selector["cloud.google.com/gke-tpu-accelerator"],
      "tpu-v5-lite-podslice",
  )
  self.assertEqual(selector["cloud.google.com/gke-tpu-topology"], "4x8")

  # Container limits
  first_container = pod_spec["containers"][0]
  self.assertEqual(first_container["name"], "pathways-worker")
  tpu_limit = first_container["resources"]["limits"]["google.com/tpu"]
  self.assertEqual(tpu_limit, "4")

def test_worker_job_small_topology(self):
  """A topology smaller than one VM's chip count still yields one VM."""
  config = jobset.PathwaysJobSet(
      name="test-jobset",
      namespace="default",
      pathways_dir="gs://test-bucket",
      tpu_type="v5e",
      topology="1x1",
      num_slices=1,
  ).to_dict()

  worker = next(
      rj
      for rj in config["spec"]["replicatedJobs"]
      if rj["name"] == "pathways-worker"
  )
  job_spec = worker["template"]["spec"]

  # 1x1 v5e topology has 1 chip. v5e has 4 chips per VM.
  # Since total_chips (1) < chips_per_vm (4), num_vms should be 1.
  for sizing_field in ("completions", "parallelism"):
    self.assertEqual(job_spec[sizing_field], 1)
  # default backoffLimit = num_vms * 4 = 1 * 4 = 4
  self.assertEqual(job_spec["backoffLimit"], 4)


if __name__ == "__main__":
absltest.main()
Loading