Commit e6ee397
Parent: 9a1aa5e
Authored by gurcangercek, Zhanghao Wu, and Michaelvll

[GCP] GCE DWS Support (#3574)

* [GCP] initial take for DWS support with MIGs
* fix lint errors
* dependency and format fix
* refactor MIG instance creation
* fix
* remove unnecessary instance creation code for MIG
* Fix deletion
* Fix instance template logic
* Restart
* format
* format
* move to REST APIs instead of Python APIs
* add multi-node back
* Fix multi-node
* Avoid spot
* format
* format
* fix scheduling
* fix cancel
* Add smoke test
* revert some changes
* fix smoke
* Fix
* fix
* Fix smoke
* [GCP] Changing the config name for DWS support and fix for resize request cancellation (#5)
* Fix config fields
* fix cancel
* Add loggings
* remove useless codes

---------

Co-authored-by: Zhanghao Wu <[email protected]>
Co-authored-by: Zhanghao Wu <[email protected]>

File tree: 10 files changed, +662 -71 lines changed

docs/source/reference/config.rst
Lines changed: 24 additions & 0 deletions

@@ -247,6 +247,30 @@ Available fields and semantics:
       - projects/my-project/reservations/my-reservation2
 
+  # Managed instance group / DWS (optional).
+  #
+  # SkyPilot supports launching instances in a managed instance group (MIG),
+  # which schedules GPU instance creation through DWS, offering better
+  # availability. This feature only applies when a resource request
+  # contains GPU instances.
+  managed_instance_group:
+    # Duration for a created instance to be kept alive (in seconds, required).
+    #
+    # This is required for DWS to work properly. After the
+    # specified duration, the instance will be terminated.
+    run_duration: 3600
+    # Timeout for provisioning an instance by DWS (in seconds, optional).
+    #
+    # This timeout determines how long SkyPilot will wait for a managed
+    # instance group to create the requested resources before giving up,
+    # deleting the MIG, and failing over to other locations. Larger timeouts
+    # may increase the chance of getting a resource, but will block failover
+    # to other zones/regions/clouds.
+    #
+    # Default: 900
+    provision_timeout: 900
+
   # Identity to use for all GCP instances (optional).
   #
   # LOCAL_CREDENTIALS: The user's local credential files will be uploaded to
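For context, a minimal sketch of how these new fields could look in a user's ~/.sky/config.yaml, assuming the surrounding gcp section documented above (the values are the documented example/default, not recommendations):

gcp:
  managed_instance_group:
    # Keep each created instance alive for one hour.
    run_duration: 3600
    # Give DWS up to 15 minutes before failing over elsewhere.
    provision_timeout: 900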

sky/clouds/gcp.py
Lines changed: 26 additions & 8 deletions

@@ -14,6 +14,7 @@
 from sky import clouds
 from sky import exceptions
 from sky import sky_logging
+from sky import skypilot_config
 from sky.adaptors import gcp
 from sky.clouds import service_catalog
 from sky.clouds.utils import gcp_utils
@@ -179,20 +180,31 @@ class GCP(clouds.Cloud):
     def _unsupported_features_for_resources(
         cls, resources: 'resources.Resources'
     ) -> Dict[clouds.CloudImplementationFeatures, str]:
+        unsupported = {}
         if gcp_utils.is_tpu_vm_pod(resources):
-            return {
+            unsupported = {
                 clouds.CloudImplementationFeatures.STOP: (
-                    'TPU VM pods cannot be stopped. Please refer to: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#stopping_your_resources'
+                    'TPU VM pods cannot be stopped. Please refer to: '
+                    'https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#stopping_your_resources'
                 )
             }
         if gcp_utils.is_tpu(resources) and not gcp_utils.is_tpu_vm(resources):
             # TPU node does not support multi-node.
-            return {
-                clouds.CloudImplementationFeatures.MULTI_NODE:
-                    ('TPU node does not support multi-node. Please set '
-                     'num_nodes to 1.')
-            }
-        return {}
+            unsupported[clouds.CloudImplementationFeatures.MULTI_NODE] = (
+                'TPU node does not support multi-node. Please set '
+                'num_nodes to 1.')
+        # TODO(zhwu): We probably need to store the MIG requirement in resources
+        # because `skypilot_config` may change for an existing cluster.
+        # Clusters created with MIG (only GPU clusters) cannot be stopped.
+        if (skypilot_config.get_nested(
+                ('gcp', 'managed_instance_group'), None) is not None and
+                resources.accelerators):
+            unsupported[clouds.CloudImplementationFeatures.STOP] = (
+                'Managed Instance Group (MIG) does not support stopping yet.')
+            unsupported[clouds.CloudImplementationFeatures.SPOT_INSTANCE] = (
+                'Managed Instance Group with DWS does not support '
+                'spot instances.')
+        return unsupported
 
     @classmethod
     def max_cluster_name_length(cls) -> Optional[int]:
@@ -493,6 +505,12 @@ def make_deploy_resources_variables(
 
         resources_vars['tpu_node_name'] = tpu_node_name
 
+        managed_instance_group_config = skypilot_config.get_nested(
+            ('gcp', 'managed_instance_group'), None)
+        use_mig = managed_instance_group_config is not None
+        resources_vars['gcp_use_managed_instance_group'] = use_mig
+        if use_mig:
+            resources_vars.update(managed_instance_group_config)
         return resources_vars
 
     def _get_feasible_launchable_resources(
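Both hunks hinge on skypilot_config.get_nested returning None when the ('gcp', 'managed_instance_group') path is absent, which is what makes use_mig a simple presence check. A minimal sketch of that lookup semantics over a plain dict (walk_nested below is a hypothetical stand-in, not the actual skypilot_config implementation):

from typing import Any, Dict, Optional, Tuple

def walk_nested(config: Dict[str, Any], keys: Tuple[str, ...],
                default: Optional[Any] = None) -> Optional[Any]:
    """Walk a nested dict along a key path; return `default` if any hop is missing."""
    current: Any = config
    for key in keys:
        if not isinstance(current, dict) or key not in current:
            return default
        current = current[key]
    return current

# With DWS configured, the MIG fields are found and can be merged into
# resources_vars; without it, the lookup falls back to None and MIG is off.
config = {'gcp': {'managed_instance_group': {'run_duration': 3600,
                                             'provision_timeout': 900}}}
assert walk_nested(config, ('gcp', 'managed_instance_group')) == {
    'run_duration': 3600, 'provision_timeout': 900}
assert walk_nested({}, ('gcp', 'managed_instance_group')) is None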

sky/provision/gcp/constants.py
Lines changed: 12 additions & 0 deletions

@@ -214,3 +214,15 @@
 MAX_POLLS = 60 // POLL_INTERVAL
 # Stopping instances can take several minutes, so we increase the timeout
 MAX_POLLS_STOP = MAX_POLLS * 8
+
+TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node'
+# Tag uniquely identifying all nodes of a cluster
+TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
+TAG_RAY_NODE_KIND = 'ray-node-type'
+TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name'
+
+# MIG constants
+MANAGED_INSTANCE_GROUP_CONFIG = 'managed-instance-group'
+DEFAULT_MANAGED_INSTANCE_GROUP_PROVISION_TIMEOUT = 900  # 15 minutes
+MIG_NAME_PREFIX = 'sky-mig-'
+INSTANCE_TEMPLATE_NAME_PREFIX = 'sky-it-'
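The name prefixes and default timeout suggest how per-cluster MIG resources and the provisioning deadline are likely derived; a hypothetical sketch under that assumption (derive_mig_name, derive_instance_template_name, and resolve_provision_timeout are illustrative helpers, not part of this diff):

from typing import Any, Dict

MIG_NAME_PREFIX = 'sky-mig-'
INSTANCE_TEMPLATE_NAME_PREFIX = 'sky-it-'
DEFAULT_MANAGED_INSTANCE_GROUP_PROVISION_TIMEOUT = 900  # 15 minutes

def derive_mig_name(cluster_name_on_cloud: str) -> str:
    # e.g. 'sky-mig-mycluster-ab12'
    return f'{MIG_NAME_PREFIX}{cluster_name_on_cloud}'

def derive_instance_template_name(cluster_name_on_cloud: str) -> str:
    # e.g. 'sky-it-mycluster-ab12'
    return f'{INSTANCE_TEMPLATE_NAME_PREFIX}{cluster_name_on_cloud}'

def resolve_provision_timeout(mig_config: Dict[str, Any]) -> int:
    # Fall back to the default when the user leaves provision_timeout unset.
    return mig_config.get('provision_timeout',
                          DEFAULT_MANAGED_INSTANCE_GROUP_PROVISION_TIMEOUT)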

sky/provision/gcp/instance.py
Lines changed: 39 additions & 26 deletions

@@ -16,11 +16,6 @@
 
 logger = sky_logging.init_logger(__name__)
 
-TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node'
-# Tag uniquely identifying all nodes of a cluster
-TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
-TAG_RAY_NODE_KIND = 'ray-node-type'
-
 _INSTANCE_RESOURCE_NOT_FOUND_PATTERN = re.compile(
     r'The resource \'projects/.*/zones/.*/instances/.*\' was not found')
 
@@ -66,7 +61,7 @@ def query_instances(
     assert provider_config is not None, (cluster_name_on_cloud, provider_config)
     zone = provider_config['availability_zone']
     project_id = provider_config['project_id']
-    label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+    label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
 
     handler: Type[
         instance_utils.GCPInstance] = instance_utils.GCPComputeInstance
@@ -124,15 +119,15 @@ def _wait_for_operations(
         logger.debug(
             f'wait_for_compute_{op_type}_operation: '
             f'Waiting for operation {operation["name"]} to finish...')
-        handler.wait_for_operation(operation, project_id, zone)
+        handler.wait_for_operation(operation, project_id, zone=zone)
 
 
 def _get_head_instance_id(instances: List) -> Optional[str]:
     head_instance_id = None
     for inst in instances:
         labels = inst.get('labels', {})
-        if (labels.get(TAG_RAY_NODE_KIND) == 'head' or
-                labels.get(TAG_SKYPILOT_HEAD_NODE) == '1'):
+        if (labels.get(constants.TAG_RAY_NODE_KIND) == 'head' or
+                labels.get(constants.TAG_SKYPILOT_HEAD_NODE) == '1'):
             head_instance_id = inst['name']
             break
     return head_instance_id
@@ -158,12 +153,14 @@ def _run_instances(region: str, cluster_name_on_cloud: str,
     resource: Type[instance_utils.GCPInstance]
     if node_type == instance_utils.GCPNodeType.COMPUTE:
         resource = instance_utils.GCPComputeInstance
+    elif node_type == instance_utils.GCPNodeType.MIG:
+        resource = instance_utils.GCPManagedInstanceGroup
     elif node_type == instance_utils.GCPNodeType.TPU:
         resource = instance_utils.GCPTPUVMInstance
     else:
         raise ValueError(f'Unknown node type {node_type}')
 
-    filter_labels = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+    filter_labels = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
 
     # wait until all stopping instances are stopped/terminated
     while True:
@@ -264,12 +261,16 @@ def get_order_key(node):
     if config.resume_stopped_nodes and to_start_count > 0 and stopped_instances:
         resumed_instance_ids = [n['name'] for n in stopped_instances]
         if resumed_instance_ids:
-            for instance_id in resumed_instance_ids:
-                resource.start_instance(instance_id, project_id,
-                                        availability_zone)
-                resource.set_labels(project_id, availability_zone, instance_id,
-                                    labels)
-            to_start_count -= len(resumed_instance_ids)
+            resumed_instance_ids = resource.start_instances(
+                cluster_name_on_cloud, project_id, availability_zone,
+                resumed_instance_ids, labels)
+            # In the MIG case, resumed_instance_ids will include the previously
+            # PENDING and RUNNING instances. To avoid double counting, we need
+            # to remove them from resumed_instance_ids.
+            ready_instances = set(resumed_instance_ids)
+            ready_instances |= set([n['name'] for n in running_instances])
+            ready_instances |= set([n['name'] for n in pending_instances])
+            to_start_count = config.count - len(ready_instances)
 
     if head_instance_id is None:
         head_instance_id = resource.create_node_tag(
@@ -281,9 +282,14 @@ def get_order_key(node):
 
     if to_start_count > 0:
         errors, created_instance_ids = resource.create_instances(
-            cluster_name_on_cloud, project_id, availability_zone,
-            config.node_config, labels, to_start_count,
-            head_instance_id is None)
+            cluster_name_on_cloud,
+            project_id,
+            availability_zone,
+            config.node_config,
+            labels,
+            to_start_count,
+            total_count=config.count,
+            include_head_node=head_instance_id is None)
         if errors:
             error = common.ProvisionerError('Failed to launch instances.')
             error.errors = errors
@@ -387,7 +393,7 @@ def get_cluster_info(
     assert provider_config is not None, cluster_name_on_cloud
     zone = provider_config['availability_zone']
     project_id = provider_config['project_id']
-    label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+    label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
 
     handlers: List[Type[instance_utils.GCPInstance]] = [
         instance_utils.GCPComputeInstance
@@ -415,7 +421,7 @@ def get_cluster_info(
         project_id,
         zone,
         {
-            **label_filters, TAG_RAY_NODE_KIND: 'head'
+            **label_filters, constants.TAG_RAY_NODE_KIND: 'head'
         },
         lambda h: [h.RUNNING_STATE],
     )
@@ -441,14 +447,14 @@ def stop_instances(
     assert provider_config is not None, cluster_name_on_cloud
     zone = provider_config['availability_zone']
     project_id = provider_config['project_id']
-    label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+    label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
 
     tpu_node = provider_config.get('tpu_node')
     if tpu_node is not None:
         instance_utils.delete_tpu_node(project_id, zone, tpu_node)
 
     if worker_only:
-        label_filters[TAG_RAY_NODE_KIND] = 'worker'
+        label_filters[constants.TAG_RAY_NODE_KIND] = 'worker'
 
     handlers: List[Type[instance_utils.GCPInstance]] = [
         instance_utils.GCPComputeInstance
@@ -510,9 +516,16 @@ def terminate_instances(
     if tpu_node is not None:
         instance_utils.delete_tpu_node(project_id, zone, tpu_node)
 
-    label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+    use_mig = provider_config.get('use_managed_instance_group', False)
+    if use_mig:
+        # Deleting the MIG will also delete the instances.
+        instance_utils.GCPManagedInstanceGroup.delete_mig(
+            project_id, zone, cluster_name_on_cloud)
+        return
+
+    label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
     if worker_only:
-        label_filters[TAG_RAY_NODE_KIND] = 'worker'
+        label_filters[constants.TAG_RAY_NODE_KIND] = 'worker'
 
     handlers: List[Type[instance_utils.GCPInstance]] = [
         instance_utils.GCPComputeInstance
@@ -555,7 +568,7 @@ def open_ports(
     project_id = provider_config['project_id']
     firewall_rule_name = provider_config['firewall_rule']
 
-    label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+    label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
     handlers: List[Type[instance_utils.GCPInstance]] = [
         instance_utils.GCPComputeInstance,
         instance_utils.GCPTPUVMInstance,
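The resume hunk above recomputes to_start_count from the union of resumed, running, and pending instance names, because a MIG resize can report instances that were already alive as "resumed". A self-contained sketch of that accounting (the helper name and instance names are illustrative):

from typing import List

def remaining_to_start(total_count: int, resumed: List[str],
                       running: List[str], pending: List[str]) -> int:
    """Count nodes still to create, without double counting instances
    that the MIG reports as resumed but were already RUNNING/PENDING."""
    ready = set(resumed) | set(running) | set(pending)
    return total_count - len(ready)

# A 4-node cluster: the MIG echoes back 'n1' (already running) among the
# resumed ids, so only one new instance still needs to be created.
assert remaining_to_start(4, resumed=['n1', 'n2'], running=['n1'],
                          pending=['n3']) == 1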
