diff --git a/community/modules/compute/schedmd-slurm-gcp-v6-nodeset/outputs.tf b/community/modules/compute/schedmd-slurm-gcp-v6-nodeset/outputs.tf index 18ed74e2d5..409ae4f0cf 100644 --- a/community/modules/compute/schedmd-slurm-gcp-v6-nodeset/outputs.tf +++ b/community/modules/compute/schedmd-slurm-gcp-v6-nodeset/outputs.tf @@ -32,11 +32,6 @@ output "nodeset" { EOD } - precondition { - condition = var.accelerator_topology == null || var.enable_placement - error_message = "accelerator_topology requires enable_placement to be set to true." - } - precondition { condition = (var.accelerator_topology == null) || try(tonumber(split("x", var.accelerator_topology)[1]) % local.guest_accelerator[0].count == 0, false) error_message = "accelerator_topology must be divisible by number of gpus in machine." @@ -62,6 +57,16 @@ output "nodeset" { error_message = "Cannot use DWS Flex with `enable_placement`." } + precondition { + condition = var.accelerator_topology == null || (var.enable_placement || var.dws_flex.enabled) + error_message = "accelerator_topology requires either enable_placement to be set to true or DWS Flex to be enabled." + } + + precondition { + condition = !startswith(var.machine_type, "a4x-") || (!var.dws_flex.enabled || var.accelerator_topology == "1x64") + error_message = "For A4X, cannot use DWS Flex with accelerator_topology other than 1x64." + } + precondition { condition = length(var.zones) == 0 || !var.dws_flex.enabled error_message = <<-EOD @@ -76,7 +81,12 @@ output "nodeset" { precondition { condition = !var.enable_spot_vm || !var.dws_flex.enabled - error_message = "Cannot use both Flex-Start and Spot VMs for provisioning." + error_message = "Cannot use both DWS Flex and Spot VMs for provisioning." + } + + precondition { + condition = !startswith(var.machine_type, "a4x-") || (!var.dws_flex.enabled || (var.node_count_dynamic_max + var.node_count_static) % 16 == 0) + error_message = "For A4X, if DWS Flex is enabled, the sum of `node_count_dynamic_max` and `node_count_static` must be a multiple of 16." + } precondition { diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/mig_a4.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/mig_a4.py new file mode 100644 index 0000000000..03539decf1 --- /dev/null +++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/mig_a4.py @@ -0,0 +1,294 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +#     http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
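# Illustrative sketch (assumptions, not part of the patch): the nodeset preconditions
# above, and the slice-size math used throughout mig_a4.py below, can be read as the
# checks in this helper. The helper name and the 4-GPUs-per-VM figure are illustrative
# only; 4 GPUs per a4x VM is what the "1x64" topology and the multiple-of-16 node
# count rule imply (64 GPUs / 4 GPUs per VM = 16-VM slices).
from typing import Optional

def _check_a4x_nodeset(machine_type: str, accelerator_topology: Optional[str],
                       gpus_per_machine: int, enable_placement: bool,
                       dws_flex_enabled: bool, node_count_dynamic_max: int,
                       node_count_static: int) -> None:
    if accelerator_topology is not None:
        # A topology needs a grouping mechanism: compact placement or DWS Flex.
        assert enable_placement or dws_flex_enabled
        # "AxB": B must divide evenly into whole machines.
        assert int(accelerator_topology.split("x")[1]) % gpus_per_machine == 0
    if machine_type.startswith("a4x-") and dws_flex_enabled:
        # A4X with DWS Flex only supports 64-GPU (16-VM) slices.
        assert accelerator_topology == "1x64"
        assert (node_count_dynamic_max + node_count_static) % 16 == 0

# e.g. a 32-node A4X nodeset using DWS Flex passes all checks:
_check_a4x_nodeset("a4x-highgpu-4g", "1x64", 4, False, True, 32, 0)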
+ + +from typing import Optional, List, Dict, Any + +from dataclasses import dataclass +from functools import lru_cache +from collections import defaultdict +import googleapiclient.discovery # type: ignore +import logging +import subprocess +import json + +import util +import resume + +log = logging.getLogger() + +@dataclass(frozen=True) +class MIG: + name: str + target_size: int + versions: List[str] + zone: str + + @classmethod + def from_json(cls, jo: object) -> "MIG": + return cls( + name=jo["name"], # type: ignore + target_size=jo["targetSize"], # type: ignore + versions=[v["instanceTemplate"] for v in jo.get("versions", [])], # type: ignore + zone=util.trim_self_link(jo["zone"]), # type: ignore + ) + +@lru_cache +def migs(lkp: util.Lookup, zone: str) -> Dict[str, MIG]: + resp = lkp.compute.instanceGroupManagers().list(project=lkp.project, zone=zone).execute() + return {m.name: m for m in [MIG.from_json(o) for o in resp.get('items', [])]} + + +@lru_cache +def get_mig(lkp: util.Lookup, zone: str, mig_name: str) -> Optional[MIG]: + try: + resp = lkp.compute.instanceGroupManagers().get( + project=lkp.project, zone=zone, instanceGroupManager=mig_name + ).execute() + return MIG.from_json(resp) + except googleapiclient.errors.HttpError as e: + if e.resp.status == 404: + return None + else: + raise + +def create_workload_policy_request(lkp: util.Lookup, nodeset: Dict, topology: str): + name = f"{lkp.cfg.slurm_cluster_name}-{nodeset['nodeset_name']}" + zone = nodeset["zone_policy_allow"][0] + region = '-'.join(zone.split('-')[:2]) + body = { + "name": name, + "region": region, + "workloadPolicy": { + "type": "HIGH_THROUGHPUT", + "accelerator_topology": topology, + }, + } + + workload_req = lkp.compute.resourcePolicies().insert( + project=lkp.project, region=region, body=body + ) + + return workload_req + +def create_mig_request(lkp: util.Lookup, mig: MIG): + assert len(mig.versions) == 1 + region = '-'.join(mig.zone.split('-')[:2]) + workload_policy_name = f"{'-'.join(mig.name.split('-')[:2])}" + + mig_req = lkp.compute.instanceGroupManagers().insert( + project=lkp.project, + zone=mig.zone, + body = dict( + name=mig.name, + versions=[dict( + instanceTemplate=mig.versions[0])], + targetSize=mig.target_size, + # Sensible defaults, allow for changes when needed + instanceLifecyclePolicy= { "defaultActionOnFailure": "DO_NOTHING" }, + resourcePolicies = { + "workloadPolicy": f"projects/{lkp.project}/regions/{region}/resourcePolicies/{workload_policy_name}" + }, + ) + ) + + return mig_req + + +def _allocate_node_to_mig(lkp: util.Lookup, nodes: List[str]) -> Dict[str, List[str]]: + def slice_id(node: str) -> int: + accelerator_topology = lkp.node_accelerator_topology(node) + topo = int(accelerator_topology.split("x")[1]) // lkp.node_template_info(node).gpu.count + return lkp.node_index(node) // topo + + res : Dict[str, List[str]] = defaultdict(list) + for _, nodes in util.groupby_unsorted(nodes, lkp.node_nodeset_name): + nodes = list(nodes) + ns = lkp.node_nodeset(nodes[0]) + for sid, nodes in util.groupby_unsorted(nodes, slice_id): + mig_name = f"{lkp.cfg.slurm_cluster_name}-{ns.nodeset_name}-{sid}" + res[mig_name] = list(nodes) + return res + +def submit_batch_request(requests, resume_data): + done, failed = util.batch_execute(requests, log_err=util.swallow_err) + + def ignore_err(e) -> bool: + return "resourceNotReady" in str(e) or "alreadyExists" in str(e) + + failed = [(n, _, e) for n, (_, e) in failed.items() if not ignore_err(e)] + if failed: + for request_id, request, error in failed: + 
log.warn(f"Error raised when attempting: {request_id}. Error: {error}") + request_body_dict = json.loads(request.body) + failed_nodes_in_mig = [instance['name'] for instance in request_body_dict.get('instances', [])] + resume.down_nodes_notify_jobs(failed_nodes_in_mig, f"{error}", resume_data) + + for operation_id, operation in done.items(): + try: + done[operation_id] = util.wait_for_operation(operation) + except Exception as e: + log.error(f"Unexpected error waiting for operation {operation_id}: {e}") + failed[operation_id] = (operation, e) + +def resume_slice_nodes(lkp: util.Lookup, nodes: List[str], resume_data): + mig_requests = {} + workload_requests = {} # type: ignore + + for mig_name, nodes in _allocate_node_to_mig(lkp, nodes).items(): + mig_req, workload_req = _resume_slice_nodes_requests(lkp, mig_name, nodes) + + if mig_req: + mig_requests[mig_name] = mig_req + if workload_req not in workload_requests.values(): # type: ignore + workload_requests[mig_name] = workload_req + + if workload_requests: + submit_batch_request(workload_requests, resume_data) + + if mig_requests: + submit_batch_request(mig_requests, resume_data) + +def _resume_slice_nodes_requests(lkp: util.Lookup, mig_name: str, nodes: List[str]): + assert nodes + model = nodes[0] + ns = lkp.node_nodeset(model) + zone = ns["zone_policy_allow"][0] + mig = migs(lkp, zone).get(mig_name) + mig_req = None + workload_req = None + + if not mig: + mig = MIG( + name=mig_name, + target_size=len(nodes), + zone=zone, + versions=[ns.instance_template]) + mig_req = create_mig_request(lkp, mig) + workload_req = create_workload_policy_request(lkp, ns, ns["accelerator_topology"]) + + return mig_req, workload_req + + +def suspend_slice_nodes(lkp: util.Lookup, nodes: List[str]): + requests = {} + for mig_name, nodes in _allocate_node_to_mig(lkp, nodes).items(): + request = _suspend_slice_nodes_request(lkp, mig_name, nodes) + if request: + requests[mig_name] = request + + done, failed = util.batch_execute(requests, log_err=util.swallow_err) + if failed: + failures = [f"{n}: {e}" for n, (_, e) in failed.items()] + if failures: + log.error(f"some mig nodes failed to delete: {failures}") + +def _suspend_slice_nodes_request(lkp: util.Lookup, mig_name: str, nodes: List[str]): + assert nodes + model = nodes[0] + ns = lkp.node_nodeset(model) + zone = ns["zone_policy_allow"][0] + + migs_in_zone = migs(lkp, zone) + mig_obj = migs_in_zone.get(mig_name) + + if mig_obj is None: + log.info(f"MIG {mig_name} not found (likely already deleted). Skipping suspend.") + return None + + # Check if the suspend request is for all nodes defined in the MIG's size. + if mig_obj.target_size != len(nodes): + log.warning( + f"Holding off suspension for MIG '{mig_name}'. " + f"Suspend request for {len(nodes)} nodes does not match the " + f"MIG target size of {mig_obj.target_size}. " + f"Waiting for all nodes in the slice to be idle." + ) + return None + + # If we are here, we are deleting the entire MIG atomically. 
+ log.info(f"Deleting entire MIG '{mig_name}' with {mig_obj.target_size} nodes.") + op = lkp.compute.instanceGroupManagers().delete( + project=lkp.project, + zone=mig_obj.zone, + instanceGroupManager=mig_obj.name + ) + return op + + +def is_slice_node(node: str) -> bool: + return util.lookup().node_accelerator_topology(node) is not None + +def delete_workload_policies(lkp: util.Lookup, migs: List[MIG]): + requests = { + f"{mig.name}": lkp.compute.resourcePolicies().delete( + project=lkp.project, + region='-'.join(mig.zone.split('-')[:2]), + resourcePolicy=f"{'-'.join(mig.name.split('-')[:2])}") + for mig in migs + } + + done, failed = util.batch_execute(requests, log_err=util.swallow_err) + if failed: + def ignore_err(e) -> bool: + return "resourceInUseByAnotherResource" in str(e) + + failures = [f"{n}: {e}" for n, (_, e) in failed.items() if not ignore_err(e)] + if failures: + log.error(f"some workload policies failed to delete: {failures}") + log.info( + f"deleted {len(done)} of {len(migs)} workload policies ({util.to_hostlist(done.keys())})" + ) + +def delete_migs(lkp: util.Lookup, migs: List[MIG]): + requests = { + mig.name: lkp.compute.instanceGroupManagers().delete( + project=lkp.project, + zone=mig.zone, + instanceGroupManager=mig.name) + for mig in migs + } + + done, failed = util.batch_execute(requests, log_err=util.swallow_err) + if failed: + def ignore_err(e) -> bool: + return "resourceInUseByAnotherResource" in str(e) + + failures = [f"{n}: {e}" for n, (_, e) in failed.items() if not ignore_err(e)] + if failures: + log.error(f"some mig groups failed to delete: {failures}") + log.info( + f"deleted {len(done)} of {len(migs)} mig groups ({util.to_hostlist(done.keys())})" + ) + +def mig_details(lkp: util.Lookup, mig: MIG): + result = lkp.compute.instanceGroupManagers().get( + project=lkp.project, + zone=mig.zone, + instanceGroupManager=mig.name + ).execute() + + return result + +def list_instances_in_mig(project_id: str, zone: str, mig_name: str) -> List[str]: + instance_names = [] + result = util.lookup().compute.instanceGroupManagers().listManagedInstances( + project=project_id, + zone=zone, + instanceGroupManager=mig_name + ).execute() + + for item in result.get('managedInstances', []): + instance_names.append(item['instance'].split('/')[-1]) + return instance_names diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py index ea0012a0b1..7bb4912816 100755 --- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py +++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py @@ -44,6 +44,7 @@ from util import lookup, ReservationDetails import tpu import mig_flex +import mig_a4 log = logging.getLogger() @@ -283,24 +284,49 @@ def chunk_nodes(nodes: List[str]): excl_job_id = job_id, placement_group=pn.placement, chunk_idx=i) - for job_id, placements in groups.items() for pn in placements if pn.nodes for i, nodes_chunk in enumerate(chunk_nodes(pn.nodes)) ] return {chunk.name: chunk for chunk in chunks} +def _filter_out_and_handle_slice_nodes(nodes: List[str], resume_data: Optional[ResumeData]) -> List[str]: + """ + Separates slice nodes from the list of nodes to be resumed. + + - Slice nodes with DWS Flex enabled are resumed via MIGs. 
+ - All other nodes (non-slice, and slice without DWS Flex) are returned + to be resumed via bulk instance creation. + """ + lkp = lookup() + other_nodes, slice_nodes = util.separate(mig_a4.is_slice_node, nodes) + + if not slice_nodes: + return other_nodes + + a4x_dws_nodes, a4x_bulk_nodes = util.separate( + lambda node: lkp.node_nodeset(node).dws_flex.enabled, slice_nodes + ) + + if a4x_dws_nodes: + log.info(f"Resuming A4X DWS Flex nodes via MIGs: {to_hostlist(a4x_dws_nodes)}") + mig_a4.resume_slice_nodes(lkp, a4x_dws_nodes, resume_data) + + return other_nodes + a4x_bulk_nodes + def resume_nodes(nodes: List[str], resume_data: Optional[ResumeData]): """resume nodes in nodelist""" lkp = lookup() # Prevent dormant nodes associated with a reservation from being resumed nodes, dormant_res_nodes = util.separate(lkp.is_dormant_res_node, nodes) + nodes = _filter_out_and_handle_slice_nodes(nodes, resume_data) if dormant_res_nodes: log.warning(f"Resume was unable to resume reservation nodes={dormant_res_nodes}") down_nodes_notify_jobs(dormant_res_nodes, "Reservation is not active, nodes cannot be resumed", resume_data) + nodes, flex_managed = util.separate(lkp.is_provisioning_flex_node, nodes) if flex_managed: log.warning(f"Resume was unable to resume nodes={flex_managed} already managed by MIGs") diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurmsync.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurmsync.py index 1bfdd5acce..be8f332e49 100755 --- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurmsync.py +++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/slurmsync.py @@ -45,6 +45,7 @@ from util import lookup from suspend import delete_instances import tpu +import mig_a4 import conf import watch_delete_vm_op @@ -156,12 +157,43 @@ def start_instances(node_list): execute_with_futures(tpu.start_tpu, tpu_start_data) +def get_mig_from_node(nodename: str): + accelerator_topo = util.lookup().node_accelerator_topology(nodename) + topo = int(accelerator_topo.split("x")[1]) // lookup().node_template_info(nodename).gpu.count + sid = lookup().node_index(nodename) // topo + + node_prefix = lookup().node_prefix(nodename) + zone = lookup().zone.split("/")[-1] + + expected_mig_name = f"{node_prefix}-{sid}" + migs_in_zone = mig_a4.migs(lookup(), zone) + for mig_name, mig_obj in migs_in_zone.items(): + if mig_name == expected_mig_name: + return mig_obj + return None + + +def _find_mig_node_action(nodename) -> NodeAction: + mig_obj = get_mig_from_node(nodename) + + inst = lookup().instance(nodename.split(".")[0]) + + if lookup().is_static_node(nodename): + if mig_obj == None: + return NodeActionPowerUp() + if inst == None: + return NodeActionPowerUp() + + return NodeActionUnchanged() + + def _find_dynamic_node_status() -> NodeAction: # TODO: cover more cases: # * delete dead dynamic nodes # * delete orhpaned instances return NodeActionUnchanged() # don't touch dynamic nodes + def get_fr_action(fr: FutureReservation, state:Optional[NodeState]) -> Optional[NodeAction]: now = util.now() if state is None: @@ -177,6 +209,7 @@ def get_fr_action(fr: FutureReservation, state:Optional[NodeState]) -> Optional[ msg = f"Reservation:{fr.name} is after its end-time" return NodeActionDown(reason=msg) + def _find_tpu_node_action(nodename, state) -> NodeAction: lkp = lookup() tpuobj = tpu.TPU.make(lkp.node_nodeset_name(nodename), lkp) @@ -242,6 
+275,9 @@ def get_node_action(nodename: str) -> NodeAction: lkp = lookup() state = lkp.node_state(nodename) + if mig_a4.is_slice_node(nodename): + return _find_mig_node_action(nodename) + if lkp.node_is_gke(nodename): return NodeActionUnchanged() @@ -302,6 +338,8 @@ def get_node_action(nodename: str) -> NodeAction: if age < threshold: log.info(f"{nodename} not marked as orphan, it started less than {threshold.seconds}s ago ({age.seconds}s)") return NodeActionUnchanged() + if mig_a4.is_slice_node(nodename): + return NodeActionDown(reason="Orphaned slice node, awaiting group action") return NodeActionDelete() elif state is None: # if state is None here, the instance exists but it's not in Slurm @@ -338,7 +376,6 @@ def ignore_err(e) -> bool: ) - @lru_cache def _get_resource_policies_in_region(lkp: util.Lookup, region: str) -> list[Any]: res = [] @@ -359,6 +396,78 @@ def _get_resource_policies(lkp: util.Lookup) -> list[Any]: res.extend(_get_resource_policies_in_region(lkp, region)) return res + +def sync_migs(): + lkp = lookup() + + compute_instances = { + name for name, inst in lkp.instances().items() if inst.role == "compute" + } + slurm_nodes = set(lkp.slurm_nodes().keys()) + + nodesets = [] + zones = set() + for node in slurm_nodes: + ns_name = lkp.node_nodeset_name(node) + ns = lkp.node_nodeset(node) + if ns_name not in nodesets: + nodesets.append(ns_name) + zones.update(ns["zone_policy_allow"]) + + all_migs = {} + for zone in zones: + migs = mig_a4.migs(lkp, zone) + all_migs.update(migs) + + migs_to_delete = [] + for mig_name, mig_obj in all_migs.items(): + if not lkp.is_a4_dws_flex_mig(mig_obj): + continue + + cluster, nodeset, mig_index = mig_obj.name.split('-') + + result = mig_a4.mig_details(lookup(), mig_obj) + status = result.get("status", {}) + if not status.get("isStable", False): + log.info(f"Mig {mig_name} isn't stable yet") + continue + + mig_nodes = [] + for node in slurm_nodes: + # check if it's in the right nodeset + if (lkp.node_prefix(node) == f"{cluster}-{nodeset}") & mig_a4.is_slice_node(node): + accelerator_topology = lkp.node_accelerator_topology(node) + topo = int(accelerator_topology.split("x")[1]) // lkp.node_template_info(node).gpu.count + # check if it's in the right slice id + if int(mig_index) == (lkp.node_index(node) // topo): + mig_nodes.append(node) + + if len(mig_nodes) == 0: + # Delete MIG that is attributed to the cluster (by name) but not existing nodeset (result of nodeset deletion) + creationTimestamp = result.get("creationTimestamp", "") + parsed = util.parse_gcp_timestamp(creationTimestamp) + threshold = timedelta(seconds=300) + age = util.now() - parsed + if age > threshold: + migs_to_delete.append(mig_obj) + continue + + model = mig_nodes[0] + node_nodeset = lkp.node_nodeset(model) + if (mig_obj.versions[0] != node_nodeset.instance_template) and len(mig_nodes) == 0: + migs_to_delete.append(mig_obj) + continue + + max_index = node_nodeset["node_count_dynamic_max"] + node_nodeset["node_count_static"] - 1 + if all(lkp.node_index(node) > max_index for node in mig_nodes): + migs_to_delete.append(mig_obj) + continue + + if len(migs_to_delete) > 0: + mig_a4.delete_migs(lkp, migs_to_delete) + mig_a4.delete_workload_policies(lkp, migs_to_delete) + + def sync_placement_groups(): """Delete placement policies that are for jobs that have completed/terminated""" keep_states = frozenset( @@ -636,6 +745,11 @@ def main(): except Exception: log.exception("failed to sync DWS Flex MIGs") + try: + sync_migs() + except Exception: + log.exception("failed to sync migs") + 
     try:
         sync_placement_groups()
     except Exception:
diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/suspend.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/suspend.py
index ecef70f1cc..d54f1bdfe4 100755
--- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/suspend.py
+++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/suspend.py
@@ -29,6 +29,7 @@ from util import lookup
 import tpu
 import mig_flex
+import mig_a4
 import watch_delete_vm_op
 
 log = logging.getLogger()
@@ -82,16 +83,17 @@ def delete_instances(instances):
         topic.publish(op, node)
 
 
-
-
 def suspend_nodes(nodes: List[str]) -> None:
     lkp = lookup()
+
     other_nodes, tpu_nodes = util.separate(lkp.node_is_tpu, nodes)
     bulk_nodes, flex_nodes = util.separate(lkp.is_flex_node, other_nodes)
+    a4x_flex_nodes, other_flex_nodes = util.separate(util.is_a4x_node, flex_nodes)
 
-    mig_flex.suspend_flex_nodes(flex_nodes, lkp)
+    mig_flex.suspend_flex_nodes(other_flex_nodes, lkp)
     delete_instances(bulk_nodes)
     tpu.delete_tpu_instances(tpu_nodes)
+    mig_a4.suspend_slice_nodes(lkp, a4x_flex_nodes)
 
 
 def main(nodelist):
diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/test_util.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/test_util.py
index 69617d0301..d90251f299 100644
--- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/test_util.py
+++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tests/test_util.py
@@ -666,3 +666,12 @@ def test_future_reservation_inactive(_):
 
     lkp._get_future_reservation.assert_called_once_with("manhattan", "danger", "zebra")
     lkp._get_reservation.assert_not_called()
+
+def test_node_accelerator_topology():
+    cfg = TstCfg(
+        nodeset={
+            "n": TstNodeset(accelerator_topology="2x2"),
+        }
+    )
+    lkp = util.Lookup(cfg)
+    assert lkp.node_accelerator_topology("c-n-0") == "2x2"
diff --git a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/util.py b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/util.py
index 35c478c1a5..710c02e6aa 100755
--- a/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/util.py
+++ b/community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/util.py
@@ -1316,6 +1316,9 @@ def to_hostnames(nodelist: str) -> List[str]:
     return hostnames
 
 
+def swallow_err(_: str) -> None:
+    pass
+
 def retry_exception(exc) -> bool:
     """return true for exceptions that should always be retried"""
     msg = str(exc)
@@ -1687,7 +1690,7 @@ def is_dormant_res_node(self, node_name:str) -> bool:
         if res.delete_at_time is not None and res.assured_count <= 0:
             log.debug(f"DWS calendar reservation {res.bulk_insert_name} is not active yet, skipping resume.")
             return True
-        
+
         return False
 
     def node_is_dyn(self, node_name=None) -> bool:
@@ -1698,6 +1701,37 @@ def node_is_gke(self, node_name=None) -> bool:
         template_info = self.node_template_info(node_name)
         return self.template_is_gke(template_info)
 
+    def is_a4_dws_flex_mig(self, mig_obj: Any) -> bool:
+        """
+        Checks if a given MIG object is for an A4 slice AND was provisioned
+        for DWS Flex.
+ """ + # mig_obj could be a dict from the API or the MIG dataclass from mig_a4.py + versions = getattr(mig_obj, 'versions', mig_obj.get('versions')) + mig_name = getattr(mig_obj, 'name', mig_obj.get('name')) + if not versions or not mig_name: + return False + try: + # Check 1: Verify it's an A4 MIG by checking the instance template's machine type. + template_url = versions[0] if isinstance(versions[0], str) else versions[0].get('instanceTemplate') + if not template_url: + return False + template_details = self.template_info(template_url) + machine_type = template_details.machineType + if not machine_type.startswith("a4x-"): + return False + # Check 2: Find the source nodeset by its name and check its DWS Flex flag. + for ns_name, nodeset in self.cfg.nodeset.items(): + mig_prefix = f"{self.cfg.slurm_cluster_name}-{ns_name}-" + if mig_name.startswith(mig_prefix): + # Found the matching nodeset. Check its DWS flag. + return nodeset.dws_flex.enabled + # If no matching nodeset was found in the loop. + return False + except Exception as e: + log.error(f"Could not determine DWS Flex status for MIG '{mig_name}': {e}") + return False + def nodeset_is_gke(self, nodeset=None) -> bool: template_info = self.template_info(nodeset.instance_template) return self.template_is_gke(template_info) @@ -1712,6 +1746,10 @@ def node_template(self, node_name=None) -> str: def node_template_info(self, node_name=None): return self.template_info(self.node_template(node_name)) + def node_accelerator_topology(self, node_name=None): + return self.node_nodeset(node_name).accelerator_topology + + def node_region(self, node_name=None): nodeset = self.node_nodeset(node_name) return parse_self_link(nodeset.subnetwork).region @@ -2227,3 +2265,6 @@ def scontrol_reconfigure(lkp: Lookup) -> None: run("sudo systemctl restart slurmctld.service", timeout=30) log.info("Running scontrol reconfigure") run(f"{lkp.scontrol} reconfigure") + +def is_a4x_node(node: str) -> bool: + return lookup().node_nodeset(node).machine_type.startswith("a4x-") diff --git a/examples/machine-learning/a4x-highgpu-4g/a4xhigh-slurm-blueprint.yaml b/examples/machine-learning/a4x-highgpu-4g/a4xhigh-slurm-blueprint.yaml index 4a0b06326a..6a7e075e27 100644 --- a/examples/machine-learning/a4x-highgpu-4g/a4xhigh-slurm-blueprint.yaml +++ b/examples/machine-learning/a4x-highgpu-4g/a4xhigh-slurm-blueprint.yaml @@ -43,6 +43,11 @@ vars: instance_image: project: $(vars.project_id) family: slurm-ubuntu2404-accelerator-arm64-64k + accelerator_topology: # supply accelerator topology + a4x_dws_flex_max_run_duration: 604800 # Maximum duration (in seconds) for which DWS Flex is needed. 
Should be between 600 (10 minutes) and 604800 (7 days)
+  # Provisioning model: select at most one; if none is selected, on-demand capacity is used
+  a4x_dws_flex_enabled: false
+  a4x_enable_spot_vm: false
   a4x_reservation_name: "" # supply reservation name
   benchmark_dir: $(ghpc_stage("system_benchmarks"))
@@ -591,11 +596,15 @@ deployment_groups:
       node_count_static: $(vars.a4x_cluster_size)
       node_count_dynamic_max: 0
       enable_placement: true
-      accelerator_topology: 1x72
+      accelerator_topology: $(vars.accelerator_topology)
       disk_type: hyperdisk-balanced
       instance_image_custom: true
       on_host_maintenance: TERMINATE
       reservation_name: $(vars.a4x_reservation_name)
+      enable_spot_vm: $(vars.a4x_enable_spot_vm)
+      dws_flex:
+        enabled: $(vars.a4x_dws_flex_enabled)
+        max_run_duration: $(vars.a4x_dws_flex_max_run_duration)
       additional_networks:
         $(concat(
           [{
diff --git a/examples/machine-learning/a4x-highgpu-4g/a4xhigh-slurm-deployment.yaml b/examples/machine-learning/a4x-highgpu-4g/a4xhigh-slurm-deployment.yaml
index d309aef717..7694ffdec6 100644
--- a/examples/machine-learning/a4x-highgpu-4g/a4xhigh-slurm-deployment.yaml
+++ b/examples/machine-learning/a4x-highgpu-4g/a4xhigh-slurm-deployment.yaml
@@ -24,4 +24,9 @@ vars:
   region: # supply region with a4x-highgpu-4g capacity in reservation
   zone: # supply zone with a4x-highgpu-4g capacity in reservation
   a4x_cluster_size: # supply a4x-highgpu-4g reservation size
+  # Provisioning model: select at most one; if none is selected, on-demand capacity is used
+  a4x_dws_flex_enabled: false
+  a4x_enable_spot_vm: false
   a4x_reservation_name: # supply a4x-highgpu-4g reservation name
+  accelerator_topology: # supply accelerator topology
+  a4x_dws_flex_max_run_duration: 604800 # Maximum duration (in seconds) for which DWS Flex is needed. Should be between 600s (10 minutes) and 604800s (7 days)
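A quick way to read the new deployment variables: DWS Flex and Spot are mutually exclusive provisioning models (the nodeset preconditions earlier in this change enforce the same rule), and the Flex run duration must stay within the 600 to 604800 second window stated in the comments. The sketch below is illustrative only; the variable names mirror the YAML, while the helper itself is hypothetical.

def check_a4x_provisioning_vars(a4x_dws_flex_enabled: bool, a4x_enable_spot_vm: bool,
                                a4x_dws_flex_max_run_duration: int) -> None:
    # At most one provisioning model; with neither selected, on-demand capacity is used.
    assert not (a4x_dws_flex_enabled and a4x_enable_spot_vm), \
        "Cannot use both DWS Flex and Spot VMs for provisioning."
    if a4x_dws_flex_enabled:
        # DWS Flex run duration window: 10 minutes to 7 days.
        assert 600 <= a4x_dws_flex_max_run_duration <= 604800

check_a4x_provisioning_vars(a4x_dws_flex_enabled=True, a4x_enable_spot_vm=False,
                            a4x_dws_flex_max_run_duration=604800)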