35 changes: 27 additions & 8 deletions skyplane/api/dataplane.py
@@ -22,6 +22,10 @@
from skyplane.utils.definitions import gateway_docker_image, tmp_log_dir
from skyplane.utils.fn import PathLike, do_parallel

from skyplane.compute.aws.aws_server import AWSServer
from skyplane.compute.gcp.gcp_server import GCPServer
from skyplane.compute.azure.azure_server import AzureServer

if TYPE_CHECKING:
from skyplane.api.provisioner import Provisioner

@@ -123,6 +127,7 @@ def _start_gateway(
use_bbr=self.transfer_config.use_bbr, # TODO: remove
use_compression=self.transfer_config.use_compression,
use_socket_tls=self.transfer_config.use_socket_tls,
instance_path=gateway_node.gateway_instance_path, # TODO: better way of mapping the path of VM src/dst
)

def provision(
@@ -163,13 +168,16 @@ def provision(
assert (
cloud_provider != "cloudflare"
), f"Cannot create VMs in certain cloud providers: check planner output {self.topology.to_dict()}"
self.provisioner.add_task(
cloud_provider=cloud_provider,
region=region,
vm_type=node.vm_type or getattr(self.transfer_config, f"{cloud_provider}_instance_class"),
spot=getattr(self.transfer_config, f"{cloud_provider}_use_spot_instances"),
autoterminate_minutes=self.transfer_config.autoterminate_minutes,
)

# Only provision if the node is not a user-provided VM source or destination
if node.gateway_instance_id is None:
self.provisioner.add_task(
cloud_provider=cloud_provider,
region=region,
vm_type=node.vm_type or getattr(self.transfer_config, f"{cloud_provider}_instance_class"),
spot=getattr(self.transfer_config, f"{cloud_provider}_use_spot_instances"),
autoterminate_minutes=self.transfer_config.autoterminate_minutes,
)

# initialize clouds
self.provisioner.init_global(aws=is_aws_used, azure=is_azure_used, gcp=is_gcp_used, ibmcloud=is_ibmcloud_used)
@@ -186,8 +194,19 @@ def provision(
servers_by_region = defaultdict(list)
for s in servers:
servers_by_region[s.region_tag].append(s)

for node in self.topology.get_gateways():
instance = servers_by_region[node.region_tag].pop()
if node.region_tag not in servers_by_region:
if node.region_tag.startswith("aws"):
instance = AWSServer(node.region_tag, node.gateway_instance_id, key_path=node.gateway_key_path)
elif node.region_tag.startswith("azure"):
instance = AzureServer(node.gateway_instance_id)
elif node.region_tag.startswith("gcp"):
instance = GCPServer(node.region_tag, node.gateway_instance_id)
else:
raise Exception(f"Invalid region tag: {node.region_tag}")
else:
instance = servers_by_region[node.region_tag].pop()
self.bound_nodes[node] = instance

# set ip addresses (for gateway program generation)
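Note on the binding change above: a node whose gateway_instance_id is set refers to a user-supplied VM rather than a Skyplane-provisioned gateway, so provision() skips the provisioner and wraps the instance in the matching Server subclass. A minimal sketch of that dispatch, reusing the imports added at the top of this file (the helper name itself is hypothetical):

def bind_existing_vm(node):
    # Wrap a pre-existing VM (not provisioned by Skyplane) in a Server object.
    if node.region_tag.startswith("aws"):
        return AWSServer(node.region_tag, node.gateway_instance_id, key_path=node.gateway_key_path)
    elif node.region_tag.startswith("azure"):
        return AzureServer(node.gateway_instance_id)
    elif node.region_tag.startswith("gcp"):
        return GCPServer(node.region_tag, node.gateway_instance_id)
    raise Exception(f"Invalid region tag: {node.region_tag}")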
15 changes: 14 additions & 1 deletion skyplane/api/pipeline.py
@@ -10,7 +10,14 @@
from skyplane.api.transfer_job import CopyJob, SyncJob, TransferJob
from skyplane.api.config import TransferConfig

from skyplane.planner.planner import MulticastDirectPlanner, DirectPlannerSourceOneSided, DirectPlannerDestOneSided
from skyplane.planner.planner import (
MulticastDirectPlanner,
DirectPlannerSourceOneSided,
DirectPlannerDestOneSided,
DirectPlannerVMSource,
DirectPlannerVMDest,
DirectPlannerVMSourceDest,
)
from skyplane.planner.topology import TopologyPlanGateway
from skyplane.utils import logger
from skyplane.utils.definitions import tmp_log_dir
@@ -67,6 +74,12 @@ def __init__(
self.planner = DirectPlannerSourceOneSided(self.max_instances, self.n_connections, self.transfer_config)
elif self.planning_algorithm == "dst_one_sided":
self.planner = DirectPlannerDestOneSided(self.max_instances, self.n_connections, self.transfer_config)
elif self.planning_algorithm == "vm_source":
self.planner = DirectPlannerVMSource(self.max_instances, 64, self.transfer_config)
elif self.planning_algorithm == "vm_dest":
self.planner = DirectPlannerVMDest(self.max_instances, 64, self.transfer_config)
elif self.planning_algorithm == "vm_to_vm":
self.planner = DirectPlannerVMSourceDest(self.max_instances, 64, self.transfer_config)
else:
raise ValueError(f"No such planning algorithm {planning_algorithm}")

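Note on the planner dispatch above: the three new plan names slot into the existing planning_algorithm string match, and each VM planner pins the connection count to 64 rather than using self.n_connections. A hedged usage sketch (only the planning_algorithm strings are confirmed by this diff; the remaining Pipeline constructor arguments are assumptions):

from skyplane.api.pipeline import Pipeline

# Sketch: requesting the new VM-aware plans by name.
pipeline = Pipeline(planning_algorithm="vm_to_vm")  # VM -> VM
# planning_algorithm="vm_source" would plan VM -> object store
# planning_algorithm="vm_dest" would plan object store -> VM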
86 changes: 64 additions & 22 deletions skyplane/api/transfer_job.py
@@ -14,7 +14,16 @@
from queue import Queue

from abc import ABC
from typing import TYPE_CHECKING, Callable, Generator, List, Optional, Tuple, TypeVar, Dict
from typing import (
TYPE_CHECKING,
Callable,
Generator,
List,
Optional,
Tuple,
TypeVar,
Dict,
)

from abc import ABC

@@ -27,6 +36,7 @@
from skyplane.chunk import Chunk
from skyplane.obj_store.storage_interface import StorageInterface
from skyplane.obj_store.object_store_interface import ObjectStoreObject, ObjectStoreInterface
from skyplane.obj_store.vm_interface import VMInterface
from skyplane.utils import logger
from skyplane.utils.definitions import MB
from skyplane.utils.fn import do_parallel
@@ -108,10 +118,12 @@ def _run_multipart_chunk_thread(
upload_id_mapping = {}
for dest_iface in self.dst_ifaces:
dest_object = dest_objects[dest_iface.region_tag()]

upload_id = dest_iface.initiate_multipart_upload(dest_object.key, mime_type=mime_type)
# print(f"Created upload id for key {dest_object.key} with upload id {upload_id} for bucket {dest_iface.bucket_name}")
# store mapping between key and upload id for each region
upload_id_mapping[dest_iface.region_tag()] = (src_object.key, upload_id)

out_queue_chunks.put(GatewayMessage(upload_id_mapping=upload_id_mapping)) # send to output queue

# get source and destination object and then compute number of chunks
@@ -164,7 +176,15 @@ def _run_multipart_chunk_thread(
metadata = (block_ids, mime_type)

self.multipart_upload_requests.append(
dict(upload_id=upload_id, key=dest_object.key, parts=parts, region=region, bucket=bucket, metadata=metadata)
dict(
upload_id=upload_id,
key=dest_object.key,
parts=parts,
region=region,
bucket=bucket,
metadata=metadata,
vm=True if dest_iface.provider == "vm" else False,
)
)
else:
mime_type = None
@@ -291,24 +311,34 @@ def transfer_pair_generator(
logger.fs.exception(e)
raise e from None

if dest_provider == "aws":
from skyplane.obj_store.s3_interface import S3Object

dest_obj = S3Object(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
elif dest_provider == "azure":
from skyplane.obj_store.azure_blob_interface import AzureBlobObject
if isinstance(dst_iface, VMInterface):
# VM destination
from skyplane.obj_store.vm_interface import VMFile

dest_obj = AzureBlobObject(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
elif dest_provider == "gcp":
from skyplane.obj_store.gcs_interface import GCSObject
host_ip = dst_iface.host_ip()

dest_obj = GCSObject(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
elif dest_provider == "cloudflare":
from skyplane.obj_store.r2_interface import R2Object
dest_obj = VMFile(provider=dest_provider, bucket=host_ip, key=dest_key)

dest_obj = R2Object(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
else:
raise ValueError(f"Invalid dest_region {dest_region}, unknown provider")
# Bucket destination
if dest_provider == "aws":
from skyplane.obj_store.s3_interface import S3Object

dest_obj = S3Object(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
elif dest_provider == "azure":
from skyplane.obj_store.azure_blob_interface import AzureBlobObject

dest_obj = AzureBlobObject(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
elif dest_provider == "gcp":
from skyplane.obj_store.gcs_interface import GCSObject

dest_obj = GCSObject(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
elif dest_provider == "cloudflare":
from skyplane.obj_store.r2_interface import R2Object

dest_obj = R2Object(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key)
else:
raise ValueError(f"Invalid dest_region {dest_region}, unknown provider")
dest_objs[dst_iface.region_tag()] = dest_obj

# assert that all destinations share the same post-fix key
@@ -332,7 +362,7 @@ def chunk(self, transfer_pair_generator: Generator[TransferPair, None, None]) ->
multipart_chunk_threads = []

# start chunking threads
if self.transfer_config.multipart_enabled:
if self.transfer_config.multipart_enabled: # and not isinstance(self.dst_ifaces[0], VMInterface):
for _ in range(self.concurrent_multipart_chunk_threads):
t = threading.Thread(
target=self._run_multipart_chunk_thread,
@@ -346,7 +376,11 @@ def chunk(self, transfer_pair_generator: Generator[TransferPair, None, None]) ->
for transfer_pair in transfer_pair_generator:
# print("transfer_pair", transfer_pair.src_obj.key, transfer_pair.dst_objs)
src_obj = transfer_pair.src_obj
if self.transfer_config.multipart_enabled and src_obj.size > self.transfer_config.multipart_threshold_mb * MB:
if (
self.transfer_config.multipart_enabled
# and not isinstance(self.dst_ifaces[0], VMInterface)
and src_obj.size > self.transfer_config.multipart_threshold_mb * MB
):
multipart_send_queue.put(transfer_pair)
else:
if transfer_pair.src_obj.size == 0:
@@ -362,12 +396,12 @@ def chunk(self, transfer_pair_generator: Generator[TransferPair, None, None]) ->
)
)

if self.transfer_config.multipart_enabled:
if self.transfer_config.multipart_enabled: # and not isinstance(self.dst_ifaces[0], VMInterface):
# drain multipart chunk queue and yield with updated chunk IDs
while not multipart_chunk_queue.empty():
yield multipart_chunk_queue.get()

if self.transfer_config.multipart_enabled:
if self.transfer_config.multipart_enabled: # and not isinstance(self.dst_ifaces[0], VMInterface):
# wait for processing multipart requests to finish
logger.fs.debug("Waiting for multipart threads to finish")
# while not multipart_send_queue.empty():
@@ -697,11 +731,14 @@ def finalize(self):
for req in self.multipart_transfer_list:
if "region" not in req or "bucket" not in req:
raise Exception(f"Invalid multipart upload request: {req}")
groups[(req["region"], req["bucket"])].append(req)
groups[(req["region"], req["bucket"], req["vm"])].append(req)
for key, group in groups.items():
region, bucket = key
region, bucket, vm = key
batch_len = max(1, len(group) // 128)
batches = [group[i : i + batch_len] for i in range(0, len(group), batch_len)]
print(f"region: {region}, bucket: {bucket}")
if vm:
region = "vm:" + region
obj_store_interface = StorageInterface.create(region, bucket)

def complete_fn(batch):
@@ -723,14 +760,19 @@ def verify(self):
def verify_region(i):
dst_iface = self.dst_ifaces[i]
dst_prefix = self.dst_prefixes[i]
print("Dst prefix: ", dst_prefix)

# gather destination key mapping for this region
dst_keys = {pair.dst_objs[dst_iface.region_tag()].key: pair.src_obj for pair in self.transfer_list}
print(f"Destination key mappings: {dst_keys}")

# list and check destination prefix
for obj in dst_iface.list_objects(dst_prefix):
print(f"Object listed: {obj.key}")
# check metadata (src.size == dst.size) && (src.modified <= dst.modified)
src_obj = dst_keys.get(obj.key)
print(f"src_obj: {src_obj}")
print(f"Object: {obj}")
if src_obj and src_obj.size == obj.size and src_obj.last_modified <= obj.last_modified:
del dst_keys[obj.key]

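Note on the finalize/verify changes above: multipart requests are now grouped on a (region, bucket, vm) triple, and VM destinations reach the right interface by prefixing the region with "vm:" before calling StorageInterface.create. A sketch of that dispatch, assuming StorageInterface.create resolves "vm:"-prefixed regions to a VMInterface:

from collections import defaultdict
from skyplane.obj_store.storage_interface import StorageInterface

multipart_transfer_list = []  # populated with the dicts built in _run_multipart_chunk_thread
groups = defaultdict(list)
for req in multipart_transfer_list:
    groups[(req["region"], req["bucket"], req["vm"])].append(req)

for (region, bucket, vm), group in groups.items():
    if vm:
        region = "vm:" + region  # route VM destinations to the VM interface
    iface = StorageInterface.create(region, bucket)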
4 changes: 2 additions & 2 deletions skyplane/cli/cli.py
@@ -8,7 +8,7 @@
import skyplane.cli.cli_cloud
import skyplane.cli.cli_config

# import skyplane.cli.experiments # disable experiments
import skyplane.cli.experiments  # re-enable experiments
from skyplane import compute

from skyplane.cli.cli_init import init
@@ -30,7 +30,7 @@
name="init",
help="Initialize the Skyplane CLI with your cloud credentials",
)(init)
# app.add_typer(skyplane.cli.experiments.app, name="experiments") # disable experiments
app.add_typer(skyplane.cli.experiments.app, name="experiments")  # re-enable experiments
app.add_typer(skyplane.cli.cli_cloud.app, name="cloud")
app.add_typer(skyplane.cli.cli_config.app, name="config")

32 changes: 21 additions & 11 deletions skyplane/cli/cli_transfer.py
@@ -313,19 +313,28 @@ def run_transfer(
register_exception_handler()
print_header()

provider_src, bucket_src, path_src = parse_path(src)
provider_dst, bucket_dst, path_dst = parse_path(dst)
provider_src, transfer_src, path_src = parse_path(src)
provider_dst, transfer_dst, path_dst = parse_path(dst)

# update planner for one-sided transfer
# same process for other cloud providers with no VM support
assert provider_src != "cloudflare" or provider_dst != "cloudflare", "Cannot transfer between two Cloudflare buckets"
if provider_src == "cloudflare":
solver = "dst_one_sided"
elif provider_dst == "cloudflare":
solver = "src_one_sided"

src_region_tag = StorageInterface.create(f"{provider_src}:infer", bucket_src).region_tag()
dst_region_tag = StorageInterface.create(f"{provider_dst}:infer", bucket_dst).region_tag()
if provider_src == "vm" and provider_dst == "vm":
solver = "vm_to_vm"
elif provider_src == "vm":
solver = "vm_source"
elif provider_dst == "vm":
solver = "vm_dest"
else:
# the previous handling for non-VM transfers
assert provider_src != "cloudflare" or provider_dst != "cloudflare", "Cannot transfer between two Cloudflare buckets"
if provider_src == "cloudflare":
solver = "dst_one_sided"
elif provider_dst == "cloudflare":
solver = "src_one_sided"

src_region_tag = StorageInterface.create(f"{provider_src}:infer", transfer_src).region_tag()
dst_region_tag = StorageInterface.create(f"{provider_dst}:infer", transfer_dst).region_tag()

args = {
"cmd": cmd,
"recursive": True,
@@ -371,7 +380,8 @@ def run_transfer(
# fallback option: transfer is too small
if cli.args["cmd"] == "cp":
job = CopyJob(src, [dst], recursive=recursive)  # TODO: revert to using pipeline
if cli.estimate_small_transfer(job, cloud_config.get_flag("native_cmd_threshold_gb") * GB):
if cli.estimate_small_transfer(job, 0.01 * GB): # Test small transfer
# if cli.estimate_small_transfer(job, cloud_config.get_flag("native_cmd_threshold_gb") * GB):
small_transfer_status = cli.transfer_cp_small(src, dst, recursive)
return 0 if small_transfer_status else 1
else:
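Note on the solver selection above: parse_path now yields a "vm" provider for VM endpoints, and the VM branches take priority over the Cloudflare one-sided fallbacks. A worked example of the branch logic (the VM endpoint values are hypothetical; only the provider string "vm" and the solver names are confirmed by this diff):

# Hypothetical inputs for a VM -> S3 transfer.
provider_src, transfer_src = "vm", "10.0.0.5"    # source is a raw VM
provider_dst, transfer_dst = "aws", "my-bucket"  # destination is an S3 bucket

if provider_src == "vm" and provider_dst == "vm":
    solver = "vm_to_vm"
elif provider_src == "vm":
    solver = "vm_source"  # chosen in this example
elif provider_dst == "vm":
    solver = "vm_dest"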
9 changes: 8 additions & 1 deletion skyplane/cli/experiments/__init__.py
@@ -1,7 +1,13 @@
import typer

from skyplane.cli.experiments.cli_profile import latency_grid, throughput_grid
from skyplane.cli.experiments.cli_query import get_max_throughput, util_grid_throughput, util_grid_cost, dump_full_util_cost_grid
from skyplane.cli.experiments.cli_query import (
get_max_throughput,
util_grid_throughput,
util_grid_cost,
dump_full_util_cost_grid,
)
from skyplane.cli.experiments.cli_create_instance import create_instance

app = typer.Typer(name="experiments")
app.command()(latency_grid)
@@ -10,3 +16,4 @@
app.command()(util_grid_throughput)
app.command()(util_grid_cost)
app.command()(dump_full_util_cost_grid)
app.command()(create_instance)
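For completeness, the registration pattern this file uses, in sketch form; create_instance's own parameters are not shown in this diff, so nothing beyond the registration itself should be read into it:

import typer
from skyplane.cli.experiments.cli_create_instance import create_instance

app = typer.Typer(name="experiments")
app.command()(create_instance)  # assuming Typer's default naming, exposed as `skyplane experiments create-instance`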