@@ -32,11 +32,6 @@ output "nodeset" {
EOD
}

precondition {
condition = var.accelerator_topology == null || var.enable_placement
error_message = "accelerator_topology requires enable_placement to be set to true."
}

precondition {
condition = (var.accelerator_topology == null) || try(tonumber(split("x", var.accelerator_topology)[1]) % local.guest_accelerator[0].count == 0, false)
error_message = "accelerator_topology must be divisible by number of gpus in machine."
@@ -62,6 +57,16 @@ output "nodeset" {
error_message = "Cannot use DWS Flex with `enable_placement`."
}

precondition {
condition = var.accelerator_topology == null || (var.enable_placement || var.dws_flex.enabled)
error_message = "accelerator_topology requires either enable_placement to be set to true or DWS Flex to be enabled."
}

precondition {
    condition     = !startswith(var.machine_type, "a4x-") || !var.dws_flex.enabled || var.accelerator_topology == "1x64"
    error_message = "For A4X, cannot use DWS Flex with an accelerator_topology other than 1x64."
}

precondition {
condition = length(var.zones) == 0 || !var.dws_flex.enabled
error_message = <<-EOD
@@ -76,7 +81,12 @@ output "nodeset" {

precondition {
condition = !var.enable_spot_vm || !var.dws_flex.enabled
error_message = "Cannot use both Flex-Start and Spot VMs for provisioning."
error_message = "Cannot use both DWS Flex and Spot VMs for provisioning."
}

precondition {
    condition     = !startswith(var.machine_type, "a4x-") || !var.dws_flex.enabled || (var.node_count_dynamic_max + var.node_count_static) % 16 == 0
    error_message = "For A4X with DWS Flex enabled, the sum of `node_count_dynamic_max` and `node_count_static` must be a multiple of 16."
}
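  # A 1x64 accelerator_topology is a single slice of 64 GPUs; with 4 GPUs per
  # a4x node a slice spans 16 nodes, which is why the total of
  # `node_count_dynamic_max` and `node_count_static` above must be a multiple of 16.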

precondition {
@@ -0,0 +1,294 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, List, Dict

from dataclasses import dataclass
from functools import lru_cache
from collections import defaultdict
import googleapiclient.discovery  # type: ignore
import googleapiclient.errors  # type: ignore
import logging
import json

import util
import resume

log = logging.getLogger()

@dataclass(frozen=True)
class MIG:
name: str
target_size: int
versions: List[str]
zone: str

@classmethod
def from_json(cls, jo: object) -> "MIG":
return cls(
name=jo["name"], # type: ignore
target_size=jo["targetSize"], # type: ignore
versions=[v["instanceTemplate"] for v in jo.get("versions", [])], # type: ignore
zone=util.trim_self_link(jo["zone"]), # type: ignore
)

@lru_cache
def migs(lkp: util.Lookup, zone: str) -> Dict[str, MIG]:
resp = lkp.compute.instanceGroupManagers().list(project=lkp.project, zone=zone).execute()
return {m.name: m for m in [MIG.from_json(o) for o in resp.get('items', [])]}


@lru_cache
def get_mig(lkp: util.Lookup, zone: str, mig_name: str) -> Optional[MIG]:
try:
resp = lkp.compute.instanceGroupManagers().get(
project=lkp.project, zone=zone, instanceGroupManager=mig_name
).execute()
return MIG.from_json(resp)
except googleapiclient.errors.HttpError as e:
if e.resp.status == 404:
return None
else:
raise

def create_workload_policy_request(lkp: util.Lookup, nodeset: Dict, topology: str):
name = f"{lkp.cfg.slurm_cluster_name}-{nodeset['nodeset_name']}"
zone = nodeset["zone_policy_allow"][0]
region = '-'.join(zone.split('-')[:2])
body = {
"name": name,
"region": region,
"workloadPolicy": {
"type": "HIGH_THROUGHPUT",
"accelerator_topology": topology,
},
}

workload_req = lkp.compute.resourcePolicies().insert(
project=lkp.project, region=region, body=body
)

return workload_req

def create_mig_request(lkp: util.Lookup, mig: MIG):
assert len(mig.versions) == 1
region = '-'.join(mig.zone.split('-')[:2])
    workload_policy_name = '-'.join(mig.name.split('-')[:2])

    mig_req = lkp.compute.instanceGroupManagers().insert(
        project=lkp.project,
        zone=mig.zone,
        body=dict(
            name=mig.name,
            versions=[dict(instanceTemplate=mig.versions[0])],
            targetSize=mig.target_size,
            # Sensible defaults, allow for changes when needed
            instanceLifecyclePolicy={"defaultActionOnFailure": "DO_NOTHING"},
            resourcePolicies={
                "workloadPolicy": f"projects/{lkp.project}/regions/{region}/resourcePolicies/{workload_policy_name}"
            },
        )
    )

return mig_req


def _allocate_node_to_mig(lkp: util.Lookup, nodes: List[str]) -> Dict[str, List[str]]:
def slice_id(node: str) -> int:
accelerator_topology = lkp.node_accelerator_topology(node)
topo = int(accelerator_topology.split("x")[1]) // lkp.node_template_info(node).gpu.count
return lkp.node_index(node) // topo

    res: Dict[str, List[str]] = defaultdict(list)
    for _, nodeset_nodes in util.groupby_unsorted(nodes, lkp.node_nodeset_name):
        nodeset_nodes = list(nodeset_nodes)
        ns = lkp.node_nodeset(nodeset_nodes[0])
        for sid, slice_nodes in util.groupby_unsorted(nodeset_nodes, slice_id):
            mig_name = f"{lkp.cfg.slurm_cluster_name}-{ns.nodeset_name}-{sid}"
            res[mig_name] = list(slice_nodes)
    return res
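    # Worked example (hypothetical names): for cluster "foo", nodeset "a4x",
    # accelerator_topology "1x64" and 4 GPUs per node, each slice holds
    # 64 / 4 = 16 nodes, so nodes foo-a4x-[0-15] map to MIG "foo-a4x-0" and
    # foo-a4x-[16-31] map to MIG "foo-a4x-1".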

def submit_batch_request(requests, resume_data):
done, failed = util.batch_execute(requests, log_err=util.swallow_err)

def ignore_err(e) -> bool:
return "resourceNotReady" in str(e) or "alreadyExists" in str(e)

    failed = [(n, req, e) for n, (req, e) in failed.items() if not ignore_err(e)]
if failed:
for request_id, request, error in failed:
            log.warning(f"Error raised when attempting: {request_id}. Error: {error}")
request_body_dict = json.loads(request.body)
failed_nodes_in_mig = [instance['name'] for instance in request_body_dict.get('instances', [])]
resume.down_nodes_notify_jobs(failed_nodes_in_mig, f"{error}", resume_data)

for operation_id, operation in done.items():
try:
done[operation_id] = util.wait_for_operation(operation)
except Exception as e:
log.error(f"Unexpected error waiting for operation {operation_id}: {e}")
            failed.append((operation_id, operation, e))

def resume_slice_nodes(lkp: util.Lookup, nodes: List[str], resume_data):
mig_requests = {}
workload_requests = {} # type: ignore

for mig_name, nodes in _allocate_node_to_mig(lkp, nodes).items():
mig_req, workload_req = _resume_slice_nodes_requests(lkp, mig_name, nodes)

if mig_req:
mig_requests[mig_name] = mig_req
if workload_req not in workload_requests.values(): # type: ignore
workload_requests[mig_name] = workload_req

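    # The workload policy requests are submitted (and their operations awaited)
    # before the MIG requests, because each MIG body references its workload
    # policy by name under resourcePolicies.workloadPolicy.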
if workload_requests:
submit_batch_request(workload_requests, resume_data)

if mig_requests:
submit_batch_request(mig_requests, resume_data)

def _resume_slice_nodes_requests(lkp: util.Lookup, mig_name: str, nodes: List[str]):
assert nodes
model = nodes[0]
ns = lkp.node_nodeset(model)
zone = ns["zone_policy_allow"][0]
mig = migs(lkp, zone).get(mig_name)
mig_req = None
workload_req = None

if not mig:
mig = MIG(
name=mig_name,
target_size=len(nodes),
zone=zone,
versions=[ns.instance_template])
mig_req = create_mig_request(lkp, mig)
workload_req = create_workload_policy_request(lkp, ns, ns["accelerator_topology"])

return mig_req, workload_req


def suspend_slice_nodes(lkp: util.Lookup, nodes: List[str]):
requests = {}
for mig_name, nodes in _allocate_node_to_mig(lkp, nodes).items():
request = _suspend_slice_nodes_request(lkp, mig_name, nodes)
if request:
requests[mig_name] = request

done, failed = util.batch_execute(requests, log_err=util.swallow_err)
if failed:
failures = [f"{n}: {e}" for n, (_, e) in failed.items()]
if failures:
log.error(f"some mig nodes failed to delete: {failures}")

def _suspend_slice_nodes_request(lkp: util.Lookup, mig_name: str, nodes: List[str]):
assert nodes
model = nodes[0]
ns = lkp.node_nodeset(model)
zone = ns["zone_policy_allow"][0]

migs_in_zone = migs(lkp, zone)
mig_obj = migs_in_zone.get(mig_name)

if mig_obj is None:
log.info(f"MIG {mig_name} not found (likely already deleted). Skipping suspend.")
return None

# Check if the suspend request is for all nodes defined in the MIG's size.
if mig_obj.target_size != len(nodes):
log.warning(
f"Holding off suspension for MIG '{mig_name}'. "
f"Suspend request for {len(nodes)} nodes does not match the "
f"MIG target size of {mig_obj.target_size}. "
f"Waiting for all nodes in the slice to be idle."
)
return None

# If we are here, we are deleting the entire MIG atomically.
log.info(f"Deleting entire MIG '{mig_name}' with {mig_obj.target_size} nodes.")
op = lkp.compute.instanceGroupManagers().delete(
project=lkp.project,
zone=mig_obj.zone,
instanceGroupManager=mig_obj.name
)
return op


def is_slice_node(node: str) -> bool:
return util.lookup().node_accelerator_topology(node) is not None

def delete_workload_policies(lkp: util.Lookup, migs: List[MIG]):
    requests = {
        mig.name: lkp.compute.resourcePolicies().delete(
            project=lkp.project,
            region='-'.join(mig.zone.split('-')[:2]),
            resourcePolicy='-'.join(mig.name.split('-')[:2]))
        for mig in migs
    }

done, failed = util.batch_execute(requests, log_err=util.swallow_err)
if failed:
def ignore_err(e) -> bool:
return "resourceInUseByAnotherResource" in str(e)

failures = [f"{n}: {e}" for n, (_, e) in failed.items() if not ignore_err(e)]
if failures:
log.error(f"some workload policies failed to delete: {failures}")
log.info(
f"deleted {len(done)} of {len(migs)} workload policies ({util.to_hostlist(done.keys())})"
)

def delete_migs(lkp: util.Lookup, migs: List[MIG]):
requests = {
mig.name: lkp.compute.instanceGroupManagers().delete(
project=lkp.project,
zone=mig.zone,
instanceGroupManager=mig.name)
for mig in migs
}

done, failed = util.batch_execute(requests, log_err=util.swallow_err)
if failed:
def ignore_err(e) -> bool:
return "resourceInUseByAnotherResource" in str(e)

failures = [f"{n}: {e}" for n, (_, e) in failed.items() if not ignore_err(e)]
if failures:
log.error(f"some mig groups failed to delete: {failures}")
log.info(
f"deleted {len(done)} of {len(migs)} mig groups ({util.to_hostlist(done.keys())})"
)

def mig_details(lkp: util.Lookup, mig: MIG):
result = lkp.compute.instanceGroupManagers().get(
project=lkp.project,
zone=mig.zone,
instanceGroupManager=mig.name
).execute()

return result

def list_instances_in_mig(project_id: str, zone: str, mig_name: str) -> List[str]:
instance_names = []
result = util.lookup().compute.instanceGroupManagers().listManagedInstances(
project=project_id,
zone=zone,
instanceGroupManager=mig_name
).execute()

for item in result.get('managedInstances', []):
instance_names.append(item['instance'].split('/')[-1])
return instance_names
@@ -44,6 +44,7 @@
from util import lookup, ReservationDetails
import tpu
import mig_flex
import mig_a4

log = logging.getLogger()

@@ -283,24 +284,49 @@ def chunk_nodes(nodes: List[str]):
excl_job_id = job_id,
placement_group=pn.placement,
chunk_idx=i)

for job_id, placements in groups.items()
for pn in placements if pn.nodes
for i, nodes_chunk in enumerate(chunk_nodes(pn.nodes))
]
return {chunk.name: chunk for chunk in chunks}

def _filter_out_and_handle_slice_nodes(nodes: List[str], resume_data: Optional[ResumeData]) -> List[str]:
"""
Separates slice nodes from the list of nodes to be resumed.

- Slice nodes with DWS Flex enabled are resumed via MIGs.
- All other nodes (non-slice, and slice without DWS Flex) are returned
to be resumed via bulk instance creation.
"""
lkp = lookup()
other_nodes, slice_nodes = util.separate(mig_a4.is_slice_node, nodes)

if not slice_nodes:
return other_nodes

a4x_dws_nodes, a4x_bulk_nodes = util.separate(
lambda node: lkp.node_nodeset(node).dws_flex.enabled, slice_nodes
)

if a4x_dws_nodes:
log.info(f"Resuming A4X DWS Flex nodes via MIGs: {to_hostlist(a4x_dws_nodes)}")
mig_a4.resume_slice_nodes(lkp, a4x_dws_nodes, resume_data)

return other_nodes + a4x_bulk_nodes


def resume_nodes(nodes: List[str], resume_data: Optional[ResumeData]):
"""resume nodes in nodelist"""
lkp = lookup()
# Prevent dormant nodes associated with a reservation from being resumed
nodes, dormant_res_nodes = util.separate(lkp.is_dormant_res_node, nodes)
nodes = _filter_out_and_handle_slice_nodes(nodes, resume_data)

if dormant_res_nodes:
log.warning(f"Resume was unable to resume reservation nodes={dormant_res_nodes}")
down_nodes_notify_jobs(dormant_res_nodes, "Reservation is not active, nodes cannot be resumed", resume_data)


nodes, flex_managed = util.separate(lkp.is_provisioning_flex_node, nodes)
if flex_managed:
log.warning(f"Resume was unable to resume nodes={flex_managed} already managed by MIGs")