1616
1717logger = sky_logging .init_logger (__name__ )
1818
19- TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node'
20- # Tag uniquely identifying all nodes of a cluster
21- TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
22- TAG_RAY_NODE_KIND = 'ray-node-type'
23-
2419_INSTANCE_RESOURCE_NOT_FOUND_PATTERN = re .compile (
2520 r'The resource \'projects/.*/zones/.*/instances/.*\' was not found' )
2621
@@ -66,7 +61,7 @@ def query_instances(
6661 assert provider_config is not None , (cluster_name_on_cloud , provider_config )
6762 zone = provider_config ['availability_zone' ]
6863 project_id = provider_config ['project_id' ]
69- label_filters = {TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
64+ label_filters = {constants . TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
7065
7166 handler : Type [
7267 instance_utils .GCPInstance ] = instance_utils .GCPComputeInstance
@@ -124,15 +119,15 @@ def _wait_for_operations(
124119 logger .debug (
125120 f'wait_for_compute_{ op_type } _operation: '
126121 f'Waiting for operation { operation ["name" ]} to finish...' )
127- handler .wait_for_operation (operation , project_id , zone )
122+ handler .wait_for_operation (operation , project_id , zone = zone )
128123
129124
130125def _get_head_instance_id (instances : List ) -> Optional [str ]:
131126 head_instance_id = None
132127 for inst in instances :
133128 labels = inst .get ('labels' , {})
134- if (labels .get (TAG_RAY_NODE_KIND ) == 'head' or
135- labels .get (TAG_SKYPILOT_HEAD_NODE ) == '1' ):
129+ if (labels .get (constants . TAG_RAY_NODE_KIND ) == 'head' or
130+ labels .get (constants . TAG_SKYPILOT_HEAD_NODE ) == '1' ):
136131 head_instance_id = inst ['name' ]
137132 break
138133 return head_instance_id
@@ -158,12 +153,14 @@ def _run_instances(region: str, cluster_name_on_cloud: str,
158153 resource : Type [instance_utils .GCPInstance ]
159154 if node_type == instance_utils .GCPNodeType .COMPUTE :
160155 resource = instance_utils .GCPComputeInstance
156+ elif node_type == instance_utils .GCPNodeType .MIG :
157+ resource = instance_utils .GCPManagedInstanceGroup
161158 elif node_type == instance_utils .GCPNodeType .TPU :
162159 resource = instance_utils .GCPTPUVMInstance
163160 else :
164161 raise ValueError (f'Unknown node type { node_type } ' )
165162
166- filter_labels = {TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
163+ filter_labels = {constants . TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
167164
168165 # wait until all stopping instances are stopped/terminated
169166 while True :
@@ -264,12 +261,16 @@ def get_order_key(node):
264261 if config .resume_stopped_nodes and to_start_count > 0 and stopped_instances :
265262 resumed_instance_ids = [n ['name' ] for n in stopped_instances ]
266263 if resumed_instance_ids :
267- for instance_id in resumed_instance_ids :
268- resource .start_instance (instance_id , project_id ,
269- availability_zone )
270- resource .set_labels (project_id , availability_zone , instance_id ,
271- labels )
272- to_start_count -= len (resumed_instance_ids )
264+ resumed_instance_ids = resource .start_instances (
265+ cluster_name_on_cloud , project_id , availability_zone ,
266+ resumed_instance_ids , labels )
267+ # In MIG case, the resumed_instance_ids will include the previously
268+ # PENDING and RUNNING instances. To avoid double counting, we need to
269+ # remove them from the resumed_instance_ids.
270+ ready_instances = set (resumed_instance_ids )
271+ ready_instances |= set ([n ['name' ] for n in running_instances ])
272+ ready_instances |= set ([n ['name' ] for n in pending_instances ])
273+ to_start_count = config .count - len (ready_instances )
273274
274275 if head_instance_id is None :
275276 head_instance_id = resource .create_node_tag (
@@ -281,9 +282,14 @@ def get_order_key(node):
281282
282283 if to_start_count > 0 :
283284 errors , created_instance_ids = resource .create_instances (
284- cluster_name_on_cloud , project_id , availability_zone ,
285- config .node_config , labels , to_start_count ,
286- head_instance_id is None )
285+ cluster_name_on_cloud ,
286+ project_id ,
287+ availability_zone ,
288+ config .node_config ,
289+ labels ,
290+ to_start_count ,
291+ total_count = config .count ,
292+ include_head_node = head_instance_id is None )
287293 if errors :
288294 error = common .ProvisionerError ('Failed to launch instances.' )
289295 error .errors = errors
@@ -387,7 +393,7 @@ def get_cluster_info(
387393 assert provider_config is not None , cluster_name_on_cloud
388394 zone = provider_config ['availability_zone' ]
389395 project_id = provider_config ['project_id' ]
390- label_filters = {TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
396+ label_filters = {constants . TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
391397
392398 handlers : List [Type [instance_utils .GCPInstance ]] = [
393399 instance_utils .GCPComputeInstance
@@ -415,7 +421,7 @@ def get_cluster_info(
415421 project_id ,
416422 zone ,
417423 {
418- ** label_filters , TAG_RAY_NODE_KIND : 'head'
424+ ** label_filters , constants . TAG_RAY_NODE_KIND : 'head'
419425 },
420426 lambda h : [h .RUNNING_STATE ],
421427 )
@@ -441,14 +447,14 @@ def stop_instances(
441447 assert provider_config is not None , cluster_name_on_cloud
442448 zone = provider_config ['availability_zone' ]
443449 project_id = provider_config ['project_id' ]
444- label_filters = {TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
450+ label_filters = {constants . TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
445451
446452 tpu_node = provider_config .get ('tpu_node' )
447453 if tpu_node is not None :
448454 instance_utils .delete_tpu_node (project_id , zone , tpu_node )
449455
450456 if worker_only :
451- label_filters [TAG_RAY_NODE_KIND ] = 'worker'
457+ label_filters [constants . TAG_RAY_NODE_KIND ] = 'worker'
452458
453459 handlers : List [Type [instance_utils .GCPInstance ]] = [
454460 instance_utils .GCPComputeInstance
@@ -510,9 +516,16 @@ def terminate_instances(
510516 if tpu_node is not None :
511517 instance_utils .delete_tpu_node (project_id , zone , tpu_node )
512518
513- label_filters = {TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
519+ use_mig = provider_config .get ('use_managed_instance_group' , False )
520+ if use_mig :
521+ # Deleting the MIG will also delete the instances.
522+ instance_utils .GCPManagedInstanceGroup .delete_mig (
523+ project_id , zone , cluster_name_on_cloud )
524+ return
525+
526+ label_filters = {constants .TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
514527 if worker_only :
515- label_filters [TAG_RAY_NODE_KIND ] = 'worker'
528+ label_filters [constants . TAG_RAY_NODE_KIND ] = 'worker'
516529
517530 handlers : List [Type [instance_utils .GCPInstance ]] = [
518531 instance_utils .GCPComputeInstance
@@ -555,7 +568,7 @@ def open_ports(
555568 project_id = provider_config ['project_id' ]
556569 firewall_rule_name = provider_config ['firewall_rule' ]
557570
558- label_filters = {TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
571+ label_filters = {constants . TAG_RAY_CLUSTER_NAME : cluster_name_on_cloud }
559572 handlers : List [Type [instance_utils .GCPInstance ]] = [
560573 instance_utils .GCPComputeInstance ,
561574 instance_utils .GCPTPUVMInstance ,
0 commit comments