From 329c5259d2727ef7af0435d00d613ed26682a30f Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Mon, 12 May 2025 19:04:28 -0700 Subject: [PATCH 01/21] Initial commit for spin steps --- metaflow/cli.py | 53 ++- metaflow/cli_components/run_cmds.py | 160 ++++++++- metaflow/cli_components/step_cmd.py | 167 ++++++++- metaflow/client/core.py | 238 ++++++++----- metaflow/client/filecache.py | 90 +++-- metaflow/datastore/__init__.py | 1 + metaflow/datastore/content_addressed_store.py | 42 ++- metaflow/datastore/datastore_set.py | 11 +- metaflow/datastore/flow_datastore.py | 125 ++++++- metaflow/datastore/spin_datastore.py | 91 +++++ metaflow/datastore/task_datastore.py | 86 ++++- metaflow/metaflow_config.py | 12 + metaflow/metaflow_profile.py | 18 + metaflow/plugins/cards/card_decorator.py | 1 + metaflow/runner/metaflow_runner.py | 206 ++++++++++- metaflow/runtime.py | 336 ++++++++++++++++-- metaflow/task.py | 60 +++- metaflow/util.py | 3 +- .../unit/spin/artifacts/complex_dag_step_a.py | 1 + .../unit/spin/artifacts/complex_dag_step_d.py | 11 + test/unit/spin/complex_dag_flow.py | 116 ++++++ test/unit/spin/merge_artifacts_flow.py | 63 ++++ test/unit/spin/test_spin.py | 138 +++++++ 23 files changed, 1813 insertions(+), 216 deletions(-) create mode 100644 metaflow/datastore/spin_datastore.py create mode 100644 test/unit/spin/artifacts/complex_dag_step_a.py create mode 100644 test/unit/spin/artifacts/complex_dag_step_d.py create mode 100644 test/unit/spin/complex_dag_flow.py create mode 100644 test/unit/spin/merge_artifacts_flow.py create mode 100644 test/unit/spin/test_spin.py diff --git a/metaflow/cli.py b/metaflow/cli.py index cb9a0bc1ac9..5e5f95a0ab2 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -1,3 +1,4 @@ +import os import functools import inspect import os @@ -7,7 +8,6 @@ import metaflow.tracing as tracing from metaflow._vendor import click -from metaflow.system import _system_logger, _system_monitor from . 
import decorators, lint, metaflow_version, parameters, plugins from .cli_args import cli_args @@ -25,8 +25,12 @@ DEFAULT_METADATA, DEFAULT_MONITOR, DEFAULT_PACKAGE_SUFFIXES, + DATASTORE_SYSROOT_SPIN, + DATASTORE_LOCAL_DIR, ) from .metaflow_current import current +from .metaflow_profile import from_start +from metaflow.system import _system_monitor, _system_logger from .metaflow_environment import MetaflowEnvironment from .packaging_sys import MetaflowCodeContent from .plugins import ( @@ -38,6 +42,7 @@ ) from .pylint_wrapper import PyLint from .R import metaflow_r_version, use_r +from .util import get_latest_run_id, resolve_identity from .user_configs.config_options import LocalFileInput, config_options from .user_configs.config_parameters import ConfigValue from .util import get_latest_run_id, resolve_identity @@ -125,6 +130,8 @@ def logger(body="", system_msg=False, head="", bad=False, timestamp=True, nl=Tru "step": "metaflow.cli_components.step_cmd.step", "run": "metaflow.cli_components.run_cmds.run", "resume": "metaflow.cli_components.run_cmds.resume", + "spin": "metaflow.cli_components.run_cmds.spin", + "spin-step": "metaflow.cli_components.step_cmd.spin_step", }, ) def cli(ctx): @@ -347,6 +354,7 @@ def start( if use_r(): version = metaflow_r_version() + from_start("MetaflowCLI: Starting") echo("Metaflow %s" % version, fg="magenta", bold=True, nl=False) echo(" executing *%s*" % ctx.obj.flow.name, fg="magenta", nl=False) echo(" for *%s*" % resolve_identity(), fg="magenta") @@ -498,6 +506,45 @@ def start( ) ctx.obj.config_options = config_options + ctx.obj.is_spin = False + + # Override values for spin + if hasattr(ctx, "saved_args") and ctx.saved_args and "spin" in ctx.saved_args[0]: + # To minimize side-effects for spin, we will only use the following: + # - local metadata provider, + # - local datastore, + # - local environment, + # - null event logger, + # - null monitor + ctx.obj.is_spin = True + ctx.obj.event_logger = LOGGING_SIDECARS["nullSidecarLogger"]( + flow=ctx.obj.flow, env=ctx.obj.environment + ) + ctx.obj.monitor = MONITOR_SIDECARS["nullSidecarMonitor"]( + flow=ctx.obj.flow, env=ctx.obj.environment + ) + ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "local"][0]( + ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor + ) + ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] + # Set datastore_root to be DATASTORE_SYSROOT_SPIN if not provided + datastore_root = os.path.join(DATASTORE_SYSROOT_SPIN, DATASTORE_LOCAL_DIR) + ctx.obj.datastore_impl.datastore_root = datastore_root + ctx.obj.flow_datastore = FlowDataStore( + ctx.obj.flow.name, + ctx.obj.environment, # Same environment as run/resume + ctx.obj.metadata, # local metadata + ctx.obj.event_logger, # null event logger + ctx.obj.monitor, # null monitor + storage_impl=ctx.obj.datastore_impl, + ) + + # Start event logger and monitor + ctx.obj.event_logger.start() + _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger) + + ctx.obj.monitor.start() + _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) decorators._init(ctx.obj.flow) @@ -514,7 +561,7 @@ def start( deco_options, ) - # In the case of run/resume, we will want to apply the TL decospecs + # In the case of run/resume/spin, we will want to apply the TL decospecs # *after* the run decospecs so that they don't take precedence. 
In other # words, for the same decorator, we want `myflow.py run --with foo` to # take precedence over any other `foo` decospec @@ -542,7 +589,7 @@ def start( if ( hasattr(ctx, "saved_args") and ctx.saved_args - and ctx.saved_args[0] not in ("run", "resume") + and ctx.saved_args[0] not in ("run", "resume", "spin") ): # run/resume are special cases because they can add more decorators with --with, # so they have to take care of themselves. diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index 159e2764303..91a06e0137a 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -9,8 +9,9 @@ from ..graph import FlowGraph from ..metaflow_current import current from ..metaflow_config import DEFAULT_DECOSPECS, FEAT_ALWAYS_UPLOAD_CODE_PACKAGE +from ..metaflow_profile import from_start from ..package import MetaflowPackage -from ..runtime import NativeRuntime +from ..runtime import NativeRuntime, SpinRuntime from ..system import _system_logger # from ..client.core import Run @@ -22,7 +23,7 @@ def before_run(obj, tags, decospecs): validate_tags(tags) - # There's a --with option both at the top-level and for the run + # There's a --with option both at the top-level and for the run/resume/spin # subcommand. Why? # # "run --with shoes" looks so much better than "--with shoes run". @@ -42,12 +43,12 @@ def before_run(obj, tags, decospecs): + list(obj.environment.decospecs() or []) ) if all_decospecs: - # These decospecs are the ones from run/resume PLUS the ones from the + # These decospecs are the ones from run/resume/spin PLUS the ones from the # environment (for example the @conda) decorators._attach_decorators(obj.flow, all_decospecs) decorators._init(obj.flow) # Regenerate graph if we attached more decorators - obj.flow.__class__._init_graph() + obj.flow.__class__._init_attrs() obj.graph = obj.flow._graph obj.check(obj.graph, obj.flow, obj.environment, pylint=obj.pylint) @@ -73,6 +74,29 @@ def before_run(obj, tags, decospecs): ) +def common_runner_options(func): + @click.option( + "--run-id-file", + default=None, + show_default=True, + type=str, + help="Write the ID of this run to the file specified.", + ) + @click.option( + "--runner-attribute-file", + default=None, + show_default=True, + type=str, + help="Write the metadata and pathspec of this run to the file specified. Used internally " + "for Metaflow's Runner API.", + ) + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + def write_file(file_path, content): if file_path is not None: with open(file_path, "w", encoding="utf-8") as f: @@ -137,20 +161,6 @@ def common_run_options(func): "in steps.", callback=config_callback, ) - @click.option( - "--run-id-file", - default=None, - show_default=True, - type=str, - help="Write the ID of this run to the file specified.", - ) - @click.option( - "--runner-attribute-file", - default=None, - show_default=True, - type=str, - help="Write the metadata and pathspec of this run to the file specified. 
Used internally for Metaflow's Runner API.", - ) @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) @@ -195,6 +205,7 @@ def wrapper(*args, **kwargs): @click.command(help="Resume execution of a previous run of this flow.") @tracing.cli("cli/resume") @common_run_options +@common_runner_options @click.pass_obj def resume( obj, @@ -326,6 +337,7 @@ def resume( @click.command(help="Run the workflow locally.") @tracing.cli("cli/run") @common_run_options +@common_runner_options @click.option( "--namespace", "user_namespace", @@ -348,7 +360,7 @@ def run( run_id_file=None, runner_attribute_file=None, user_namespace=None, - **kwargs + **kwargs, ): if user_namespace is not None: namespace(user_namespace or None) @@ -401,3 +413,113 @@ def run( ) with runtime.run_heartbeat(): runtime.execute() + + +@parameters.add_custom_parameters(deploy_mode=True) +@click.command(help="Spins up a task for a given step from a previous run locally.") +@click.argument("step-name") +@click.option( + "--spin-pathspec", + default=None, + type=str, + help="Use specified task pathspec from a previous run to spin up the step.", +) +@click.option( + "--skip-decorators/--no-skip-decorators", + is_flag=True, + default=False, + show_default=True, + help="Skip decorators attached to the step.", +) +@click.option( + "--artifacts-module", + default=None, + show_default=True, + help="Path to a module that contains artifacts to be used in the spun step. " + "The artifacts should be defined as a dictionary called ARTIFACTS with keys as " + "the artifact names and values as the artifact values. The artifact values will " + "overwrite the default values of the artifacts used in the spun step.", +) +@click.option( + "--persist/--no-persist", + "persist", + default=True, + show_default=True, + help="Whether to persist the artifacts in the spun step. If set to False, " + "the artifacts will notbe persisted and will not be available in the spun step's " + "datastore.", +) +@click.option( + "--max-log-size", + default=10, + show_default=True, + help="Maximum size of stdout and stderr captured in " + "megabytes. If a step outputs more than this to " + "stdout/stderr, its output will be truncated.", +) +@common_runner_options +@click.pass_obj +def spin( + obj, + step_name, + spin_pathspec=None, + persist=True, + artifacts_module=None, + skip_decorators=False, + max_log_size=None, + run_id_file=None, + runner_attribute_file=None, + **kwargs, +): + before_run(obj, [], []) + obj.echo(f"Spinning up step *{step_name}* locally for flow *{obj.flow.name}*") + obj.flow._set_constants(obj.graph, kwargs, obj.config_options) + step_func = getattr(obj.flow, step_name, None) + if step_func is None: + raise CommandException( + f"Step '{step_name}' not found in flow '{obj.flow.name}'. " + "Please provide a valid step name." + ) + from_start("Spin: before spin runtime init") + spin_runtime = SpinRuntime( + obj.flow, + obj.graph, + obj.flow_datastore, + obj.metadata, + obj.environment, + obj.package, + obj.logger, + obj.entrypoint, + obj.event_logger, + obj.monitor, + step_func, + step_name, + spin_pathspec, + skip_decorators, + artifacts_module, + persist, + max_log_size * 1024 * 1024, + ) + write_latest_run_id(obj, spin_runtime.run_id) + write_file(run_id_file, spin_runtime.run_id) + # datastore_root is os.path.join(DATASTORE_SYSROOT_SPIN, DATASTORE_LOCAL_DIR) + # We only need the root for the metadata, i.e. 
the portion before DATASTORE_LOCAL_DIR + datastore_root = spin_runtime._flow_datastore._storage_impl.datastore_root + orig_task_metadata_root = datastore_root.rsplit("/", 1)[0] + from_start("Spin: going to execute") + spin_runtime.execute() + from_start("Spin: after spin runtime execute") + + if runner_attribute_file: + with open(runner_attribute_file, "w") as f: + json.dump( + { + "task_id": spin_runtime.task.task_id, + "step_name": step_name, + "run_id": spin_runtime.run_id, + "flow_name": obj.flow.name, + # Store metadata in a format that can be used by the Runner API + "metadata": f"{obj.metadata.__class__.TYPE}@{orig_task_metadata_root}", + }, + f, + ) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index f4bef099e42..309a082ee68 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -1,12 +1,16 @@ from metaflow._vendor import click -from .. import decorators, namespace +from .. import namespace from ..cli import echo_always, echo_dev_null from ..cli_args import cli_args +from ..datastore.flow_datastore import FlowDataStore from ..exception import CommandException +from ..client.filecache import FileCache, FileBlobCache, TaskMetadataCache +from ..metaflow_profile import from_start +from ..plugins import DATASTORES from ..task import MetaflowTask from ..unbounded_foreach import UBF_CONTROL, UBF_TASK -from ..util import decompress_list +from ..util import decompress_list, read_artifacts_module import metaflow.tracing as tracing @@ -109,7 +113,6 @@ def step( ubf_context="none", num_parallel=None, ): - if ctx.obj.is_quiet: echo = echo_dev_null else: @@ -176,3 +179,161 @@ def step( ) echo("Success", fg="green", bold=True, indent=True) + + +@click.command(help="Internal command to spin a single task.", hidden=True) +@click.argument("step-name") +@click.option( + "--run-id", + default=None, + required=True, + help="Run ID for the step that's about to be spun", +) +@click.option( + "--task-id", + default=None, + required=True, + help="Task ID for the step that's about to be spun", +) +@click.option( + "--orig-flow-datastore", + default=None, + show_default=True, + help="Original datastore for the flow from which a task is being spun", +) +@click.option( + "--spin-pathspec", + default=None, + show_default=True, + help="Task Pathspec to be used in the spun step.", +) +@click.option( + "--input-paths", + help="A comma-separated list of pathspecs specifying inputs for this step.", +) +@click.option( + "--split-index", + type=int, + default=None, + show_default=True, + help="Index of this foreach split.", +) +@click.option( + "--retry-count", + default=0, + help="How many times we have attempted to run this task.", +) +@click.option( + "--max-user-code-retries", + default=0, + help="How many times we should attempt running the user code.", +) +@click.option( + "--namespace", + "opt_namespace", + default=None, + help="Change namespace from the default (your username) to the specified tag.", +) +@click.option( + "--whitelist-decorators", + help="A comma-separated list of whitelisted decorators to use for the spin step", +) +@click.option( + "--persist/--no-persist", + "persist", + default=True, + show_default=True, + help="Whether to persist the artifacts in the spun step. 
If set to false, the artifacts will not" + " be persisted and will not be available in the spun step's datastore.", +) +@click.option( + "--artifacts-module", + default=None, + show_default=True, + help="Path to a module that contains artifacts to be used in the spun step. The artifacts should " + "be defined as a dictionary called ARTIFACTS with keys as the artifact names and values as the " + "artifact values. The artifact values will overwrite the default values of the artifacts used in " + "the spun step.", +) +@click.pass_context +def spin_step( + ctx, + step_name, + run_id=None, + task_id=None, + orig_flow_datastore=None, + spin_pathspec=None, + input_paths=None, + split_index=None, + retry_count=None, + max_user_code_retries=None, + opt_namespace=None, + whitelist_decorators=None, + artifacts_module=None, + persist=True, +): + import time + + if ctx.obj.is_quiet: + echo = echo_dev_null + else: + echo = echo_always + + if opt_namespace is not None: + namespace(opt_namespace or None) + + input_paths = decompress_list(input_paths) if input_paths else [] + + whitelist_decorators = ( + decompress_list(whitelist_decorators) if whitelist_decorators else [] + ) + from_start("SpinStep: initialized decorators") + spin_artifacts = read_artifacts_module(artifacts_module) if artifacts_module else {} + from_start("SpinStep: read artifacts module") + + ds_type, ds_root = orig_flow_datastore.split("@") + orig_datastore_impl = [d for d in DATASTORES if d.TYPE == ds_type][0] + orig_datastore_impl.datastore_root = ds_root + orig_flow_datastore = FlowDataStore( + ctx.obj.flow.name, + environment=None, + storage_impl=orig_datastore_impl, + ds_root=ds_root, + ) + + filecache = FileCache() + orig_flow_datastore.set_metadata_cache( + TaskMetadataCache(filecache, ds_type, ds_root, ctx.obj.flow.name) + ) + orig_flow_datastore.ca_store.set_blob_cache( + FileBlobCache( + filecache, FileCache.flow_ds_id(ds_type, ds_root, ctx.obj.flow.name) + ) + ) + + task = MetaflowTask( + ctx.obj.flow, + ctx.obj.flow_datastore, + ctx.obj.metadata, + ctx.obj.environment, + echo, + ctx.obj.event_logger, + ctx.obj.monitor, + None, # no unbounded foreach context + orig_flow_datastore=orig_flow_datastore, + spin_artifacts=spin_artifacts, + ) + from_start("SpinStep: initialized task") + task.run_step( + step_name, + run_id, + task_id, + None, + input_paths, + split_index, + retry_count, + max_user_code_retries, + whitelist_decorators, + persist, + ) + from_start("SpinStep: ran step") diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 7fa9dcb4018..6a8f6f3af19 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1189,149 +1189,191 @@ class Task(MetaflowObject): _PARENT_CLASS = "step" _CHILD_CLASS = "artifact" - def __init__(self, *args, **kwargs): - super(Task, self).__init__(*args, **kwargs) - def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" - def _iter_matching_tasks(self, steps, metadata_key, metadata_pattern): + def _get_matching_pathspecs(self, steps, metadata_key, metadata_pattern): """ - Yield tasks from specified steps matching a foreach path pattern. + Yield pathspecs of tasks from specified steps that match a given metadata pattern. Parameters ---------- steps : List[str] - List of step names to search for tasks - pattern : str - Regex pattern to match foreach-indices metadata + List of Step objects to search for tasks. + metadata_key : str + Metadata key to filter tasks on (e.g., 'foreach-execution-path'). 
+ metadata_pattern : str + Regular expression pattern to match against the metadata value. - Returns - ------- - Iterator[Task] - Tasks matching the foreach path pattern + Yields + ------ + str + Pathspec of each task whose metadata value for the specified key matches the pattern. """ flow_id, run_id, _, _ = self.path_components - for step in steps: task_pathspecs = self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step.id, metadata_key, metadata_pattern + flow_id, run_id, step, metadata_key, metadata_pattern ) for task_pathspec in task_pathspecs: - yield Task(pathspec=task_pathspec, _namespace_check=False) + yield task_pathspec + + @staticmethod + def _get_previous_steps(graph_info, step_name): + # Get the parent steps + steps = [] + for node_name, attributes in graph_info["steps"].items(): + if step_name in attributes["next"]: + steps.append(node_name) + return steps @property - def parent_tasks(self) -> Iterator["Task"]: + def parent_task_pathspecs(self) -> Iterator[str]: """ - Yields all parent tasks of the current task if one exists. + Yields pathspecs of all parent tasks of the current task. Yields ------ - Task - Parent task of the current task - + str + Pathspec of the parent task of the current task """ - flow_id, run_id, _, _ = self.path_components - - steps = list(self.parent.parent_steps) - if not steps: - return [] + _, _, step_name, _ = self.path_components + metadata_dict = self.metadata_dict + graph_info = self["_graph_info"].data - current_path = self.metadata_dict.get("foreach-execution-path", "") + # Get the parent steps + steps = self._get_previous_steps(graph_info, step_name) + node_type = graph_info["steps"][step_name]["type"] + current_path = metadata_dict.get("foreach-execution-path") if len(steps) > 1: # Static join - use exact path matching pattern = current_path or ".*" - yield from self._iter_matching_tasks( - steps, "foreach-execution-path", pattern - ) - return - - # Handle single step case - target_task = Step( - f"{flow_id}/{run_id}/{steps[0].id}", _namespace_check=False - ).task - target_path = target_task.metadata_dict.get("foreach-execution-path") - - if not target_path or not current_path: - # (Current task, "A:10") and (Parent task, "") - # Pattern: ".*" - pattern = ".*" else: - current_depth = len(current_path.split(",")) - target_depth = len(target_path.split(",")) - - if current_depth < target_depth: - # Foreach join - # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13,C:21") - # Pattern: "A:10,B:13,.*" - pattern = f"{current_path},.*" + if not steps: + return # No parent steps, yield nothing + + if not current_path: + # Current task is not part of a foreach + # Pattern: ".*" + pattern = ".*" else: - # Foreach split or linear step - # Option 1: - # (Current task, "A:10,B:13,C:21") and (Parent task, "A:10,B:13") - # Option 2: - # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13") - # Pattern: "A:10,B:13" - pattern = ",".join(current_path.split(",")[:target_depth]) + current_depth = len(current_path.split(",")) + if node_type == "join": + # Foreach join + # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13,C:21") + # Pattern: "A:10,B:13,.*" + pattern = f"{current_path},.*" + else: + # Foreach split or linear step + # Pattern: "A:10,B:13" + parent_step_type = graph_info["steps"][steps[0]]["type"] + target_depth = current_depth + if parent_step_type == "split-foreach" and current_depth == 1: + # (Current task, "A:10") and (Parent task, "") + pattern = ".*" + else: + # (Current task, "A:10,B:13,C:21") and (Parent 
task, "A:10,B:13") + # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13") + if parent_step_type == "split-foreach": + target_depth = current_depth - 1 + pattern = ",".join(current_path.split(",")[:target_depth]) - yield from self._iter_matching_tasks(steps, "foreach-execution-path", pattern) + metadata_key = "foreach-execution-path" + for pathspec in self._get_matching_pathspecs(steps, metadata_key, pattern): + yield pathspec @property - def child_tasks(self) -> Iterator["Task"]: + def child_task_pathspecs(self) -> Iterator[str]: """ - Yield all child tasks of the current task if one exists. + Yields pathspecs of all child tasks of the current task. Yields ------ - Task - Child task of the current task + str + Pathspec of the child task of the current task """ - flow_id, run_id, _, _ = self.path_components - steps = list(self.parent.child_steps) - if not steps: - return [] + flow_id, run_id, step_name, _ = self.path_components + metadata_dict = self.metadata_dict + graph_info = self["_graph_info"].data + + # Get the child steps + steps = graph_info["steps"][step_name]["next"] - current_path = self.metadata_dict.get("foreach-execution-path", "") + node_type = graph_info["steps"][step_name]["type"] + current_path = self.metadata_dict.get("foreach-execution-path") if len(steps) > 1: # Static split - use exact path matching pattern = current_path or ".*" - yield from self._iter_matching_tasks( - steps, "foreach-execution-path", pattern - ) - return - - # Handle single step case - target_task = Step( - f"{flow_id}/{run_id}/{steps[0].id}", _namespace_check=False - ).task - target_path = target_task.metadata_dict.get("foreach-execution-path") - - if not target_path or not current_path: - # (Current task, "A:10") and (Child task, "") - # Pattern: ".*" - pattern = ".*" else: - current_depth = len(current_path.split(",")) - target_depth = len(target_path.split(",")) - - if current_depth < target_depth: - # Foreach split - # (Current task, "A:10,B:13") and (Child task, "A:10,B:13,C:21") - # Pattern: "A:10,B:13,.*" - pattern = f"{current_path},.*" + if not steps: + return # No child steps, yield nothing + + if not current_path: + # Current task is not part of a foreach + # Pattern: ".*" + pattern = ".*" else: - # Foreach join or linear step - # Option 1: - # (Current task, "A:10,B:13,C:21") and (Child task, "A:10,B:13") - # Option 2: - # (Current task, "A:10,B:13") and (Child task, "A:10,B:13") - # Pattern: "A:10,B:13" - pattern = ",".join(current_path.split(",")[:target_depth]) - - yield from self._iter_matching_tasks(steps, "foreach-execution-path", pattern) + current_depth = len(current_path.split(",")) + if node_type == "split-foreach": + # Foreach split + # (Current task, "A:10,B:13") and (Child task, "A:10,B:13,C:21") + # Pattern: "A:10,B:13,.*" + pattern = f"{current_path},.*" + else: + # Foreach join or linear step + # Pattern: "A:10,B:13" + child_step_type = graph_info["steps"][steps[0]]["type"] + + # We need to know if the child step is a foreach join or a static join + child_step_prev_steps = self._get_previous_steps( + graph_info, steps[0] + ) + if len(child_step_prev_steps) > 1: + child_step_type = "static-join" + target_depth = current_depth + if child_step_type == "join" and current_depth == 1: + # (Current task, "A:10") and (Child task, "") + pattern = ".*" + else: + # (Current task, "A:10,B:13,C:21") and (Child task, "A:10,B:13") + # (Current task, "A:10,B:13") and (Child task, "A:10,B:13") + if child_step_type == "join": + target_depth = current_depth - 1 + pattern = 
",".join(current_path.split(",")[:target_depth]) + + metadata_key = "foreach-execution-path" + for pathspec in self._get_matching_pathspecs(steps, metadata_key, pattern): + yield pathspec + + @property + def parent_tasks(self) -> Iterator["Task"]: + """ + Yields all parent tasks of the current task if one exists. + + Yields + ------ + Task + Parent task of the current task + """ + parent_task_pathspecs = self.parent_task_pathspecs + for pathspec in parent_task_pathspecs: + yield Task(pathspec=pathspec, _namespace_check=False) + + @property + def child_tasks(self) -> Iterator["Task"]: + """ + Yields all child tasks of the current task if one exists. + + Yields + ------ + Task + Child task of the current task + """ + for pathspec in self.child_task_pathspecs: + yield Task(pathspec=pathspec, _namespace_check=False) @property def metadata(self) -> List[Metadata]: diff --git a/metaflow/client/filecache.py b/metaflow/client/filecache.py index 83a38811eff..980b5f34cf0 100644 --- a/metaflow/client/filecache.py +++ b/metaflow/client/filecache.py @@ -1,5 +1,6 @@ from __future__ import print_function from collections import OrderedDict +import json import os import sys import time @@ -10,13 +11,14 @@ from metaflow.datastore import FlowDataStore from metaflow.datastore.content_addressed_store import BlobCache +from metaflow.datastore.flow_datastore import MetadataCache from metaflow.exception import MetaflowException from metaflow.metaflow_config import ( CLIENT_CACHE_PATH, CLIENT_CACHE_MAX_SIZE, CLIENT_CACHE_MAX_FLOWDATASTORE_COUNT, - CLIENT_CACHE_MAX_TASKDATASTORE_COUNT, ) +from metaflow.metaflow_profile import from_start from metaflow.plugins import DATASTORES @@ -63,8 +65,8 @@ def __init__(self, cache_dir=None, max_size=None): # when querying for sizes of artifacts. Once we have queried for the size # of one artifact in a TaskDatastore, caching this means that any # queries on that same TaskDatastore will be quick (since we already - # have all the metadata) - self._task_metadata_caches = OrderedDict() + # have all the metadata). We keep track of this in a file so it persists + # across processes. 
@property def cache_dir(self): @@ -87,7 +89,7 @@ def get_log_legacy( ): ds_cls = self._get_datastore_storage_impl(ds_type) ds_root = ds_cls.path_join(*ds_cls.path_split(location)[:-5]) - cache_id = self._flow_ds_id(ds_type, ds_root, flow_name) + cache_id = self.flow_ds_id(ds_type, ds_root, flow_name) token = ( "%s.cached" @@ -311,13 +313,13 @@ def _index_objects(self): self._objects = sorted(objects, reverse=False) @staticmethod - def _flow_ds_id(ds_type, ds_root, flow_name): + def flow_ds_id(ds_type, ds_root, flow_name): p = urlparse(ds_root) sanitized_root = (p.netloc + p.path).replace("/", "_") return ".".join([ds_type, sanitized_root, flow_name]) @staticmethod - def _task_ds_id(ds_type, ds_root, flow_name, run_id, step_name, task_id, attempt): + def task_ds_id(ds_type, ds_root, flow_name, run_id, step_name, task_id, attempt): p = urlparse(ds_root) sanitized_root = (p.netloc + p.path).replace("/", "_") return ".".join( @@ -365,7 +367,7 @@ def _get_datastore_storage_impl(ds_type): return storage_impl[0] def _get_flow_datastore(self, ds_type, ds_root, flow_name): - cache_id = self._flow_ds_id(ds_type, ds_root, flow_name) + cache_id = self.flow_ds_id(ds_type, ds_root, flow_name) cached_flow_datastore = self._store_caches.get(cache_id) if cached_flow_datastore: @@ -380,9 +382,14 @@ def _get_flow_datastore(self, ds_type, ds_root, flow_name): ds_root=ds_root, ) blob_cache = self._blob_caches.setdefault( - cache_id, FileBlobCache(self, cache_id) + cache_id, + ( + FileBlobCache(self, cache_id), + TaskMetadataCache(self, ds_type, ds_root, flow_name), + ), ) - cached_flow_datastore.ca_store.set_blob_cache(blob_cache) + cached_flow_datastore.ca_store.set_blob_cache(blob_cache[0]) + cached_flow_datastore.set_metadata_cache(blob_cache[1]) self._store_caches[cache_id] = cached_flow_datastore if len(self._store_caches) > CLIENT_CACHE_MAX_FLOWDATASTORE_COUNT: cache_id_to_remove, _ = self._store_caches.popitem(last=False) @@ -393,32 +400,49 @@ def _get_task_datastore( self, ds_type, ds_root, flow_name, run_id, step_name, task_id, attempt ): flow_ds = self._get_flow_datastore(ds_type, ds_root, flow_name) - cached_metadata = None - if attempt is not None: - cache_id = self._task_ds_id( - ds_type, ds_root, flow_name, run_id, step_name, task_id, attempt - ) - cached_metadata = self._task_metadata_caches.get(cache_id) - if cached_metadata: - od_move_to_end(self._task_metadata_caches, cache_id) - return flow_ds.get_task_datastore( - run_id, - step_name, - task_id, - attempt=attempt, - data_metadata=cached_metadata, - ) - # If we are here, we either have attempt=None or nothing in the cache - task_ds = flow_ds.get_task_datastore( - run_id, step_name, task_id, attempt=attempt + + return flow_ds.get_task_datastore(run_id, step_name, task_id, attempt=attempt) + + +class TaskMetadataCache(MetadataCache): + def __init__(self, filecache, ds_type, ds_root, flow_name): + self._ds_type = ds_type + self._ds_root = ds_root + self._flow_name = flow_name + self._filecache = filecache + + def _path(self, run_id, step_name, task_id, attempt): + if attempt is None: + return None + cache_id = self._filecache.task_ds_id( + self._ds_type, + self._ds_root, + self._flow_name, + run_id, + step_name, + task_id, + attempt, + ) + token = ( + "%s.cached" + % sha1( + os.path.join( + run_id, step_name, task_id, str(attempt), "metadata" + ).encode("utf-8") + ).hexdigest() ) - cache_id = self._task_ds_id( - ds_type, ds_root, flow_name, run_id, step_name, task_id, task_ds.attempt + return os.path.join(self._filecache.cache_dir, cache_id, 
token[:2], token) + + def load_metadata(self, run_id, step_name, task_id, attempt): + d = self._filecache.read_file(self._path(run_id, step_name, task_id, attempt)) + if d: + return json.loads(d) + + def store_metadata(self, run_id, step_name, task_id, attempt, metadata_dict): + self._filecache.create_file( + self._path(run_id, step_name, task_id, attempt), + json.dumps(metadata_dict).encode("utf-8"), ) - self._task_metadata_caches[cache_id] = task_ds.ds_metadata - if len(self._task_metadata_caches) > CLIENT_CACHE_MAX_TASKDATASTORE_COUNT: - self._task_metadata_caches.popitem(last=False) - return task_ds class FileBlobCache(BlobCache): diff --git a/metaflow/datastore/__init__.py b/metaflow/datastore/__init__.py index 793251b0cff..65bb33b0eb9 100644 --- a/metaflow/datastore/__init__.py +++ b/metaflow/datastore/__init__.py @@ -2,3 +2,4 @@ from .flow_datastore import FlowDataStore from .datastore_set import TaskDataStoreSet from .task_datastore import TaskDataStore +from .spin_datastore import SpinTaskDatastore diff --git a/metaflow/datastore/content_addressed_store.py b/metaflow/datastore/content_addressed_store.py index e0533565ffa..75203174d9d 100644 --- a/metaflow/datastore/content_addressed_store.py +++ b/metaflow/datastore/content_addressed_store.py @@ -38,7 +38,7 @@ def __init__(self, prefix, storage_impl): def set_blob_cache(self, blob_cache): self._blob_cache = blob_cache - def save_blobs(self, blob_iter, raw=False, len_hint=0): + def save_blobs(self, blob_iter, raw=False, len_hint=0, _is_transfer=False): """ Saves blobs of data to the datastore @@ -65,6 +65,9 @@ def save_blobs(self, blob_iter, raw=False, len_hint=0): Whether to save the bytes directly or process them, by default False len_hint : Hint of the number of blobs that will be produced by the iterator, by default 0 + _is_transfer : bool, default False + If True, this indicates we are saving blobs directly from the output of another + content addressed store's Returns ------- @@ -76,6 +79,20 @@ def save_blobs(self, blob_iter, raw=False, len_hint=0): def packing_iter(): for blob in blob_iter: + if _is_transfer: + key, blob_data, meta = blob + path = self._storage_impl.path_join(self._prefix, key[:2], key) + # Transfer data is always raw/decompressed, so mark it as such + meta_corrected = {"cas_raw": True, "cas_version": 1} + + results.append( + self.save_blobs_result( + uri=self._storage_impl.full_uri(path), + key=key, + ) + ) + yield path, (BytesIO(blob_data), meta_corrected) + continue sha = sha1(blob).hexdigest() path = self._storage_impl.path_join(self._prefix, sha[:2], sha) results.append( @@ -100,7 +117,7 @@ def packing_iter(): self._storage_impl.save_bytes(packing_iter(), overwrite=True, len_hint=len_hint) return results - def load_blobs(self, keys, force_raw=False): + def load_blobs(self, keys, force_raw=False, _is_transfer=False): """ Mirror function of save_blobs @@ -111,15 +128,20 @@ def load_blobs(self, keys, force_raw=False): ---------- keys : List of string Key describing the object to load - force_raw : bool, optional + force_raw : bool, default False Support for backward compatibility with previous datastores. If True, this will force the key to be loaded as is (raw). By default, False + _is_transfer : bool, default False + If True, this indicates we are loading blobs to transfer them directly + to another datastore. We will, in this case, also transfer the metdata + and do minimal processing. This is for internal use only. 
Returns ------- Returns an iterator of (string, bytes) tuples; the iterator may return keys - in a different order than were passed in. + in a different order than were passed in. If _is_transfer is True, the tuple + has three elements with the third one being the metadata. """ load_paths = [] for key in keys: @@ -127,13 +149,18 @@ def load_blobs(self, keys, force_raw=False): if self._blob_cache: blob = self._blob_cache.load_key(key) if blob is not None: - yield key, blob + if _is_transfer: + # Cached blobs are decompressed/processed bytes regardless of original format + yield key, blob, {"cas_raw": False, "cas_version": 1} + else: + yield key, blob else: path = self._storage_impl.path_join(self._prefix, key[:2], key) load_paths.append((key, path)) with self._storage_impl.load_bytes([p for _, p in load_paths]) as loaded: for path_key, file_path, meta in loaded: + print(f"path_key: {path_key}, file_path: {file_path}, meta: {meta}") key = self._storage_impl.path_split(path_key)[-1] # At this point, we either return the object as is (if raw) or # decode it according to the encoding version @@ -169,7 +196,10 @@ def load_blobs(self, keys, force_raw=False): if self._blob_cache: self._blob_cache.store_key(key, blob) - yield key, blob + if _is_transfer: + yield key, blob, meta # Preserve exact original metadata from storage + else: + yield key, blob def _unpack_backward_compatible(self, blob): # This is the backward compatible unpack diff --git a/metaflow/datastore/datastore_set.py b/metaflow/datastore/datastore_set.py index f60642de73f..80cc4c690a4 100644 --- a/metaflow/datastore/datastore_set.py +++ b/metaflow/datastore/datastore_set.py @@ -21,9 +21,18 @@ def __init__( pathspecs=None, prefetch_data_artifacts=None, allow_not_done=False, + join_type=None, + orig_flow_datastore=None, + spin_artifacts=None, ): self.task_datastores = flow_datastore.get_task_datastores( - run_id, steps=steps, pathspecs=pathspecs, allow_not_done=allow_not_done + run_id, + steps=steps, + pathspecs=pathspecs, + allow_not_done=allow_not_done, + join_type=join_type, + orig_flow_datastore=orig_flow_datastore, + spin_artifacts=spin_artifacts, ) if prefetch_data_artifacts: diff --git a/metaflow/datastore/flow_datastore.py b/metaflow/datastore/flow_datastore.py index 16318ed7693..43cd6e8bc14 100644 --- a/metaflow/datastore/flow_datastore.py +++ b/metaflow/datastore/flow_datastore.py @@ -5,6 +5,8 @@ from .content_addressed_store import ContentAddressedStore from .task_datastore import TaskDataStore +from .spin_datastore import SpinTaskDatastore +from ..metaflow_profile import from_start class FlowDataStore(object): @@ -63,10 +65,16 @@ def __init__( self._storage_impl.path_join(self.flow_name, "data"), self._storage_impl ) + # Private + self._metadata_cache = None + @property def datastore_root(self): return self._storage_impl.datastore_root + def set_metadata_cache(self, cache): + self._metadata_cache = cache + def get_task_datastores( self, run_id=None, @@ -76,6 +84,9 @@ def get_task_datastores( attempt=None, include_prior=False, mode="r", + join_type=None, + orig_flow_datastore=None, + spin_artifacts=None, ): """ Return a list of TaskDataStore for a subset of the tasks. @@ -95,7 +106,7 @@ def get_task_datastores( Steps to get the tasks from. If run_id is specified, this must also be specified, by default None pathspecs : List[str], optional - Full task specs (run_id/step_name/task_id). Can be used instead of + Full task specs (run_id/step_name/task_id[/attempt]). 
Can be used instead of specifying run_id and steps, by default None allow_not_done : bool, optional If True, returns the latest attempt of a task even if that attempt @@ -106,6 +117,16 @@ def get_task_datastores( If True, returns all attempts up to and including attempt. mode : str, default "r" Mode to initialize the returned TaskDataStores in. + join_type : str, optional + If specified, the join type for the task. This is used to determine + the user specified artifacts for the task in case of a spin task. + orig_flow_datastore : MetadataProvider, optional + The metadata provider in case of a spin task. If provided, the + returned TaskDataStore will be a SpinTaskDatastore instead of a + TaskDataStore. + spin_artifacts : Dict[str, Any], optional + Artifacts provided by user that can override the artifacts fetched via the + spin pathspec. Returns ------- @@ -145,7 +166,13 @@ def get_task_datastores( if attempt is not None and attempt <= metaflow_config.MAX_ATTEMPTS - 1: attempt_range = range(attempt + 1) if include_prior else [attempt] for task_url in task_urls: - for attempt in attempt_range: + task_splits = task_url.split("/") + # Usually it is flow, run, step, task (so 4 components) -- if we have a + # fifth one, there is a specific attempt number listed as well. + task_attempt_range = attempt_range + if len(task_splits) == 5: + task_attempt_range = [int(task_splits[4])] + for attempt in task_attempt_range: for suffix in [ TaskDataStore.METADATA_DATA_SUFFIX, TaskDataStore.METADATA_ATTEMPT_SUFFIX, @@ -198,7 +225,18 @@ def get_task_datastores( else (latest_started_attempts & done_attempts) ) latest_to_fetch = [ - (v[0], v[1], v[2], v[3], data_objs.get(v), mode, allow_not_done) + ( + v[0], + v[1], + v[2], + v[3], + data_objs.get(v), + mode, + allow_not_done, + join_type, + orig_flow_datastore, + spin_artifacts, + ) for v in latest_to_fetch ] return list(itertools.starmap(self.get_task_datastore, latest_to_fetch)) @@ -212,8 +250,64 @@ def get_task_datastore( data_metadata=None, mode="r", allow_not_done=False, + join_type=None, + orig_flow_datastore=None, + spin_artifacts=None, + persist=True, ): - return TaskDataStore( + if orig_flow_datastore is not None: + # In spin step subprocess, use SpinTaskDatastore for accessing artifacts + if join_type is not None: + # If join_type is specified, we need to use the artifacts corresponding + # to that particular join index, specified by the parent task pathspec. + spin_artifacts = spin_artifacts.get( + f"{run_id}/{step_name}/{task_id}", {} + ) + from_start( + "FlowDataStore: get_task_datastore for spin task for type %s %s metadata" + % (self.TYPE, "without" if data_metadata is None else "with") + ) + # Get the task datastore for the spun task. + orig_datastore = orig_flow_datastore.get_task_datastore( + run_id, + step_name, + task_id, + attempt=attempt, + data_metadata=data_metadata, + mode=mode, + allow_not_done=allow_not_done, + join_type=join_type, + persist=persist, + ) + + return SpinTaskDatastore( + self.flow_name, + run_id, + step_name, + task_id, + orig_datastore, + spin_artifacts, + ) + + cache_hit = False + if ( + self._metadata_cache is not None + and data_metadata is None + and attempt is not None + and allow_not_done is False + ): + # If we have a metadata cache, we can try to load the metadata + # from the cache if it is not provided. 
+ data_metadata = self._metadata_cache.load_metadata( + run_id, step_name, task_id, attempt + ) + cache_hit = data_metadata is not None + + from_start( + "FlowDataStore: get_task_datastore for regular task for type %s %s metadata" + % (self.TYPE, "without" if data_metadata is None else "with") + ) + task_datastore = TaskDataStore( self, run_id, step_name, @@ -222,8 +316,23 @@ def get_task_datastore( data_metadata=data_metadata, mode=mode, allow_not_done=allow_not_done, + persist=persist, ) + # Only persist in cache if it is non-changing (so done only) and we have + # a non-None attempt + if ( + not cache_hit + and self._metadata_cache is not None + and allow_not_done is False + and attempt is not None + ): + self._metadata_cache.store_metadata( + run_id, step_name, task_id, attempt, task_datastore.ds_metadata + ) + + return task_datastore + def save_data(self, data_iter, len_hint=0): """Saves data to the underlying content-addressed store @@ -265,3 +374,11 @@ def load_data(self, keys, force_raw=False): """ for key, blob in self.ca_store.load_blobs(keys, force_raw=force_raw): yield key, blob + + +class MetadataCache(object): + def load_metadata(self, run_id, step_name, task_id, attempt): + pass + + def store_metadata(self, run_id, step_name, task_id, attempt, metadata_dict): + pass diff --git a/metaflow/datastore/spin_datastore.py b/metaflow/datastore/spin_datastore.py new file mode 100644 index 00000000000..f45856c4c51 --- /dev/null +++ b/metaflow/datastore/spin_datastore.py @@ -0,0 +1,91 @@ +from typing import Dict, Any +from .task_datastore import TaskDataStore, require_mode +from ..metaflow_profile import from_start + + +class SpinTaskDatastore(object): + def __init__( + self, + flow_name: str, + run_id: str, + step_name: str, + task_id: str, + orig_datastore: TaskDataStore, + spin_artifacts: Dict[str, Any], + ): + """ + SpinTaskDatastore is a datastore for a task that is used to retrieve + artifacts and attributes for a spin step. It uses the task pathspec + from a previous execution of the step to access the artifacts and attributes. + + Parameters: + ----------- + flow_name : str + Name of the flow + run_id : str + Run ID of the flow + step_name : str + Name of the step + task_id : str + Task ID of the step + orig_datastore : TaskDataStore + The datastore for the underlying task that is being spun. + spin_artifacts : Dict[str, Any] + User provided artifacts that are to be used in the spin task. This is a dictionary + where keys are artifact names and values are the actual data or metadata. + """ + self.flow_name = flow_name + self.run_id = run_id + self.step_name = step_name + self.task_id = task_id + self.orig_datastore = orig_datastore + self.spin_artifacts = spin_artifacts + self._task = None + + # Update _objects and _info in order to persist artifacts + # See `persist` method in `TaskDatastore` for more details + self._objects = self.orig_datastore._objects.copy() + self._info = self.orig_datastore._info.copy() + + # We strip out some of the control ones + for key in ("_transition",): + if key in self._objects: + del self._objects[key] + del self._info[key] + + from_start("SpinTaskDatastore: Initialized artifacts") + + @require_mode(None) + def __getitem__(self, name): + try: + # Check if it's an artifact in the spin_artifacts + return self.spin_artifacts[name] + except KeyError: + try: + # Check if it's an attribute of the task + # _foreach_stack, _foreach_index, ... 
+ return self.orig_datastore[name] + except (KeyError, AttributeError) as e: + raise KeyError( + f"Attribute '{name}' not found in the previous execution " + f"of the tasks for `{self.step_name}`." + ) from e + + @require_mode(None) + def is_none(self, name): + val = self.__getitem__(name) + return val is None + + @require_mode(None) + def __contains__(self, name): + try: + _ = self.__getitem__(name) + return True + except KeyError: + return False + + @require_mode(None) + def items(self): + if self._objects: + return self._objects.items() + return {} diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index ebfed2d55d4..7ec20825b7b 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -98,6 +98,7 @@ def __init__( data_metadata=None, mode="r", allow_not_done=False, + persist=True, ): self._storage_impl = flow_datastore._storage_impl self.TYPE = self._storage_impl.TYPE @@ -113,6 +114,7 @@ def __init__( self._attempt = attempt self._metadata = flow_datastore.metadata self._parent = flow_datastore + self._persist = persist # The GZIP encodings are for backward compatibility self._encodings = {"pickle-v2", "gzip+pickle-v2"} @@ -148,6 +150,8 @@ def __init__( ) if self.has_metadata(check_meta, add_attempt=False): max_attempt = i + elif max_attempt is not None: + break if self._attempt is None: self._attempt = max_attempt elif max_attempt is None or self._attempt > max_attempt: @@ -253,6 +257,70 @@ def init_task(self): """ self.save_metadata({self.METADATA_ATTEMPT_SUFFIX: {"time": time.time()}}) + @only_if_not_done + @require_mode("w") + def transfer_artifacts(self, other_datastore, names=None): + """ + Copies the blobs from other_datastore to this datastore if the datastore roots + are different. + + This is used specifically for spin so we can bring in artifacts from the original + datastore. + + Parameters + ---------- + other_datastore : TaskDataStore + Other datastore from which to copy artifacts from + names : List[string], optional, default None + If provided, only transfer the artifacts with these names. If None, + transfer all artifacts from the other datastore. 
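        For example (the artifact names below are hypothetical), the spin task's
        datastore copies over only the artifacts that the spun step did not
        overwrite:

            spin_task_datastore.transfer_artifacts(orig_task_datastore, names=["model", "train_df"])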
+ """ + if ( + other_datastore.TYPE == self.TYPE + and other_datastore._storage_impl.datastore_root + == self._storage_impl.datastore_root + ): + # Nothing to transfer -- artifacts are already saved properly + return + + # Determine which artifacts need to be transferred + if names is None: + # Transfer all artifacts from other datastore + artifacts_to_transfer = list(other_datastore._objects.keys()) + else: + # Transfer only specified artifacts + artifacts_to_transfer = [ + name for name in names if name in other_datastore._objects + ] + + if not artifacts_to_transfer: + return + + # Get SHA keys for artifacts to transfer + shas_to_transfer = [ + other_datastore._objects[name] for name in artifacts_to_transfer + ] + + # Check which blobs are missing locally + missing_shas = [] + for sha in shas_to_transfer: + local_path = self._ca_store._storage_impl.path_join( + self._ca_store._prefix, sha[:2], sha + ) + if not self._ca_store._storage_impl.is_file([local_path])[0]: + missing_shas.append(sha) + + if not missing_shas: + return # All blobs already exist locally + + # Load blobs from other datastore in transfer mode + transfer_blobs = other_datastore._ca_store.load_blobs( + missing_shas, _is_transfer=True + ) + + # Save blobs to local datastore in transfer mode + self._ca_store.save_blobs(transfer_blobs, _is_transfer=True) + @only_if_not_done @require_mode("w") def save_artifacts(self, artifacts_iter, len_hint=0): @@ -683,14 +751,16 @@ def persist(self, flow): flow : FlowSpec Flow to persist """ + if not self._persist: + return if flow._datastore: self._objects.update(flow._datastore._objects) self._info.update(flow._datastore._info) - # we create a list of valid_artifacts in advance, outside of - # artifacts_iter, so we can provide a len_hint below + # Scan flow object FIRST valid_artifacts = [] + current_artifact_names = set() for var in dir(flow): if var.startswith("__") or var in flow._EPHEMERAL: continue @@ -707,6 +777,17 @@ def persist(self, flow): or isinstance(val, Parameter) ): valid_artifacts.append((var, val)) + current_artifact_names.add(var) + + # Transfer ONLY artifacts that aren't being overridden + if hasattr(flow._datastore, "orig_datastore"): + parent_artifacts = set(flow._datastore._objects.keys()) + unchanged_artifacts = parent_artifacts - current_artifact_names + print(f"Transferring unchanged artifacts: {unchanged_artifacts}") + if unchanged_artifacts: + self.transfer_artifacts( + flow._datastore.orig_datastore, names=list(unchanged_artifacts) + ) def artifacts_iter(): # we consume the valid_artifacts list destructively to @@ -722,6 +803,7 @@ def artifacts_iter(): delattr(flow, var) yield var, val + # Save current artifacts self.save_artifacts(artifacts_iter(), len_hint=len(valid_artifacts)) @only_if_not_done diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index bb42e0f5e3c..cde0b5b37f9 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -47,6 +47,13 @@ "DEFAULT_FROM_DEPLOYMENT_IMPL", "argo-workflows" ) +### +# Spin configuration +### +SPIN_ALLOWED_DECORATORS = from_conf( + "SPIN_ALLOWED_DECORATORS", ["conda", "pypi", "environment"] +) + ### # User configuration ### @@ -57,6 +64,7 @@ # Datastore configuration ### DATASTORE_SYSROOT_LOCAL = from_conf("DATASTORE_SYSROOT_LOCAL") +DATASTORE_SYSROOT_SPIN = from_conf("DATASTORE_SYSROOT_SPIN", "/tmp/metaflow") # S3 bucket and prefix to store artifacts for 's3' datastore. 
DATASTORE_SYSROOT_S3 = from_conf("DATASTORE_SYSROOT_S3") # Azure Blob Storage container and blob prefix @@ -464,6 +472,10 @@ ### FEAT_ALWAYS_UPLOAD_CODE_PACKAGE = from_conf("FEAT_ALWAYS_UPLOAD_CODE_PACKAGE", False) ### +# Profile +### +PROFILE_FROM_START = from_conf("PROFILE_FROM_START", False) +### # Debug configuration ### DEBUG_OPTIONS = [ diff --git a/metaflow/metaflow_profile.py b/metaflow/metaflow_profile.py index 39ecf42cdc3..1757aedf3fb 100644 --- a/metaflow/metaflow_profile.py +++ b/metaflow/metaflow_profile.py @@ -2,6 +2,24 @@ from contextlib import contextmanager +from .metaflow_config import PROFILE_FROM_START + +init_time = None + + +if PROFILE_FROM_START: + + def from_start(msg: str): + global init_time + if init_time is None: + init_time = time.time() + print("From start: %s took %dms" % (msg, int((time.time() - init_time) * 1000))) + +else: + + def from_start(_msg: str): + pass + @contextmanager def profile(label, stats_dict=None): diff --git a/metaflow/plugins/cards/card_decorator.py b/metaflow/plugins/cards/card_decorator.py index 28c0c7f8f10..daa667fa2a7 100644 --- a/metaflow/plugins/cards/card_decorator.py +++ b/metaflow/plugins/cards/card_decorator.py @@ -171,6 +171,7 @@ def step_init( self._flow_datastore = flow_datastore self._environment = environment self._logger = logger + self.card_options = None # We check for configuration options. We do this here before they are diff --git a/metaflow/runner/metaflow_runner.py b/metaflow/runner/metaflow_runner.py index 84eb8ac4284..f2f142a4653 100644 --- a/metaflow/runner/metaflow_runner.py +++ b/metaflow/runner/metaflow_runner.py @@ -21,27 +21,33 @@ from .subprocess_manager import CommandManager, SubprocessManager -class ExecutingRun(object): +class ExecutingProcess(object): """ - This class contains a reference to a `metaflow.Run` object representing - the currently executing or finished run, as well as metadata related - to the process. + This is a base class for `ExecutingRun` and `ExecutingTask` classes. + The `ExecutingRun` and `ExecutingTask` classes are returned by methods + in `Runner` and `NBRunner`, and they are subclasses of this class. - `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not - meant to be instantiated directly. + The `ExecutingRun` class for instance contains a reference to a `metaflow.Run` + object representing the currently executing or finished run, as well as the metadata + related to the process. + + Similarly, the `ExecutingTask` class contains a reference to a `metaflow.Task` + object representing the currently executing or finished task, as well as the metadata + related to the process. + + This class or its subclasses are not meant to be instantiated directly. The class + works as a context manager, allowing you to use a pattern like: - This class works as a context manager, allowing you to use a pattern like ```python with Runner(...).run() as running: ... ``` - Note that you should use either this object as the context manager or - `Runner`, not both in a nested manner. + + Note that you should use either this object as the context manager or `Runner`, not both + in a nested manner. 
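    The same pattern applies to a spun task (the step name below is illustrative):

    ```python
    with Runner(...).spin("train") as running:
        ...
    ```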
""" - def __init__( - self, runner: "Runner", command_obj: CommandManager, run_obj: Run - ) -> None: + def __init__(self, runner: "Runner", command_obj: CommandManager) -> None: """ Create a new ExecutingRun -- this should not be done by the user directly but instead user Runner.run() @@ -57,7 +63,6 @@ def __init__( """ self.runner = runner self.command_obj = command_obj - self.run = run_obj def __enter__(self) -> "ExecutingRun": return self @@ -193,6 +198,76 @@ async def stream_log( yield position, line +class ExecutingTask(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Task` object representing + the currently executing or finished task, as well as metadata related + to the process. + `ExecutingTask` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).spin() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, task_obj: "metaflow.Task" + ) -> None: + """ + Create a new ExecutingTask -- this should not be done by the user directly but + instead user Runner.spin() + Parameters + ---------- + runner : Runner + Parent runner for this task. + command_obj : CommandManager + CommandManager containing the subprocess executing this task. + task_obj : Task + Task object corresponding to this task. + """ + super().__init__(runner, command_obj) + self.task = task_obj + + +class ExecutingRun(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Run` object representing + the currently executing or finished run, as well as metadata related + to the process. + `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).run() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, run_obj: Run + ) -> None: + """ + Create a new ExecutingRun -- this should not be done by the user directly but + instead user Runner.run() + Parameters + ---------- + runner : Runner + Parent runner for this run. + command_obj : CommandManager + CommandManager containing the subprocess executing this run. + run_obj : Run + Run object corresponding to this run. 
+ """ + super().__init__(runner, command_obj) + self.run = run_obj + + class RunnerMeta(type): def __new__(mcs, name, bases, dct): cls = super().__new__(mcs, name, bases, dct) @@ -275,7 +350,7 @@ def __init__( env: Optional[Dict[str, str]] = None, cwd: Optional[str] = None, file_read_timeout: int = 3600, - **kwargs + **kwargs, ): # these imports are required here and not at the top # since they interfere with the user defined Parameters @@ -397,6 +472,73 @@ def run(self, **kwargs) -> ExecutingRun: return self.__get_executing_run(attribute_file_fd, command_obj) + def __get_executing_task(self, attribute_file_fd, command_obj): + content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout) + + command_obj.sync_wait() + + content = json.loads(content) + pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}" + + # Set the correct metadata from the runner_attribute file corresponding to this run. + metadata_for_flow = content.get("metadata") + from metaflow import Task + + task_object = Task( + pathspec, _namespace_check=False, _current_metadata=metadata_for_flow + ) + return ExecutingTask(self, command_obj, task_object) + + async def __async_get_executing_task(self, attribute_file_fd, command_obj): + content = await async_handle_timeout( + attribute_file_fd, command_obj, self.file_read_timeout + ) + content = json.loads(content) + pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}" + + # Set the correct metadata from the runner_attribute file corresponding to this run. + metadata_for_flow = content.get("metadata") + + from metaflow import Task + + task_object = Task( + pathspec, _namespace_check=False, _current_metadata=metadata_for_flow + ) + return ExecutingTask(self, command_obj, task_object) + + def spin(self, step_name, **kwargs): + """ + Blocking spin execution of the run. + This method will wait until the spun run has completed execution. + Parameters + ---------- + step_name : str + The name of the step to spin. + **kwargs : Any + Additional arguments that you would pass to `python ./myflow.py` after + the `spin` command. + Returns + ------- + ExecutingTask + ExecutingTask containing the results of the spun task. + """ + with temporary_fifo() as (attribute_file_path, attribute_file_fd): + command = self.api(**self.top_level_kwargs).spin( + step_name=step_name, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + + pid = self.spm.run_command( + [sys.executable, *command], + env=self.env_vars, + cwd=self.cwd, + show_output=self.show_output, + ) + command_obj = self.spm.get(pid) + + return self.__get_executing_task(attribute_file_fd, command_obj) + def resume(self, **kwargs) -> ExecutingRun: """ Blocking resume execution of the run. @@ -510,6 +652,42 @@ async def async_resume(self, **kwargs) -> ExecutingRun: return await self.__async_get_executing_run(attribute_file_fd, command_obj) + async def async_spin(self, step_name, spin_pathspec, **kwargs) -> ExecutingTask: + """ + Non-blocking spin execution of the run. + This method will return as soon as the spun task has launched. + + Note that this method is asynchronous and needs to be `await`ed. + + Parameters + ---------- + step_name : str + The name of the step to spin. + **kwargs : Any + Additional arguments that you would pass to `python ./myflow.py` after + the `spin` command. + + Returns + ------- + ExecutingTask + ExecutingTask representing the spun task that was started. 
+ """ + with temporary_fifo() as (attribute_file_path, attribute_file_fd): + command = self.api(**self.top_level_kwargs).spin( + step_name=step_name, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + + pid = await self.spm.async_run_command( + [sys.executable, *command], + env=self.env_vars, + cwd=self.cwd, + ) + command_obj = self.spm.get(pid) + + return await self.__async_get_executing_task(attribute_file_fd, command_obj) + def __exit__(self, exc_type, exc_value, traceback): self.spm.cleanup() diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 8f3864df3af..f54ec2789b7 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -26,20 +26,23 @@ from contextlib import contextmanager from . import get_namespace +from .client.filecache import FileCache, FileBlobCache, TaskMetadataCache from .metadata_provider import MetaDatum -from .metaflow_config import FEAT_ALWAYS_UPLOAD_CODE_PACKAGE, MAX_ATTEMPTS, UI_URL +from .metaflow_config import FEAT_ALWAYS_UPLOAD_CODE_PACKAGE, MAX_ATTEMPTS, UI_URL, SPIN_ALLOWED_DECORATORS +from .metaflow_profile import from_start +from .plugins import DATASTORES from .exception import ( MetaflowException, MetaflowInternalError, METAFLOW_EXIT_DISALLOW_RETRY, ) from . import procpoll -from .datastore import TaskDataStoreSet +from .datastore import FlowDataStore, TaskDataStoreSet from .debug import debug from .decorators import flow_decorators from .flowspec import _FlowState from .mflog import mflog, RUNTIME_LOG_SOURCE -from .util import to_unicode, compress_list, unicode_type +from .util import to_unicode, compress_list, unicode_type, get_latest_task_pathspec from .clone_util import clone_task_helper from .unbounded_foreach import ( CONTROL_TASK_TAG, @@ -85,6 +88,226 @@ class LoopBehavior(Enum): # TODO option: output dot graph periodically about execution +class SpinRuntime(object): + def __init__( + self, + flow, + graph, + flow_datastore, + metadata, + environment, + package, + logger, + entrypoint, + event_logger, + monitor, + step_func, + step_name, + spin_pathspec, + skip_decorators=False, + artifacts_module=None, + persist=True, + max_log_size=MAX_LOG_SIZE, + ): + from metaflow import Task + + self._flow = flow + self._graph = graph + self._flow_datastore = flow_datastore + self._metadata = metadata + self._environment = environment + self._package = package + self._logger = logger + self._entrypoint = entrypoint + self._event_logger = event_logger + self._monitor = monitor + + self._step_func = step_func + + # Verify whether the use has provided step-name or spin-pathspec + if not spin_pathspec: + task = get_latest_task_pathspec(flow.name, step_name) + logger("For faster spin, use --spin-pathspec %s" % task.pathspec) + else: + # The user already provided a spin-pathspec, verify if its valid + try: + task = Task(spin_pathspec, _namespace_check=False) + except Exception: + raise MetaflowException( + f"Invalid spin-pathspec: {spin_pathspec} for step: {step_name}" + ) + from_start("SpinRuntime: after getting task") + + # Get the original FlowDatastore so we can use it to access artifacts from the + # spun task + meta_dict = task.metadata_dict + ds_type = meta_dict["ds-type"] + ds_root = meta_dict["ds-root"] + orig_datastore_impl = [d for d in DATASTORES if d.TYPE == ds_type][0] + orig_datastore_impl.datastore_root = ds_root + spin_pathspec = task.pathspec + orig_flow_datastore = FlowDataStore( + flow.name, + environment=None, + storage_impl=orig_datastore_impl, + ds_root=ds_root, + ) + + self._filecache = FileCache() + 
orig_flow_datastore.set_metadata_cache( + TaskMetadataCache(self._filecache, ds_type, ds_root, flow.name) + ) + orig_flow_datastore.ca_store.set_blob_cache( + FileBlobCache( + self._filecache, FileCache.flow_ds_id(ds_type, ds_root, flow.name) + ) + ) + + self._orig_flow_datastore = orig_flow_datastore + self._spin_pathspec = spin_pathspec + self._persist = persist + self._spin_task = task + self._input_paths = None + self._split_index = None + self._whitelist_decorators = None + self._config_file_name = None + self._skip_decorators = skip_decorators + self._artifacts_module = artifacts_module + self._max_log_size = max_log_size + self._encoding = sys.stdout.encoding or "UTF-8" + + # Create a new run_id for the spin task + self.run_id = self._metadata.new_run_id() + for deco in self.whitelist_decorators: + deco.runtime_init(flow, graph, package, self.run_id) + from_start("SpinRuntime: after init decorators") + + @property + def split_index(self): + """ + Returns the split index, caching the result after the first access. + """ + if self._split_index is None: + self._split_index = getattr(self._spin_task, "index", None) + + return self._split_index + + @property + def input_paths(self): + def _format_input_paths(task_pathspec, attempt): + _, run_id, step_name, task_id = task_pathspec.split("/") + return f"{run_id}/{step_name}/{task_id}/{attempt}" + + if self._input_paths: + return self._input_paths + + if self._step_func.name == "start": + from metaflow import Step + + flow_name, run_id, _, _ = self._spin_pathspec.split("/") + task = Step( + f"{flow_name}/{run_id}/_parameters", _namespace_check=False + ).task + self._input_paths = [ + _format_input_paths(task.pathspec, task.current_attempt) + ] + else: + parent_tasks = self._spin_task.parent_tasks + self._input_paths = [ + _format_input_paths(t.pathspec, t.current_attempt) for t in parent_tasks + ] + return self._input_paths + + @property + def whitelist_decorators(self): + if self._skip_decorators: + return [] + if self._whitelist_decorators: + return self._whitelist_decorators + self._whitelist_decorators = [ + deco + for deco in self._step_func.decorators + if any(deco.name.startswith(prefix) for prefix in SPIN_ALLOWED_DECORATORS) + ] + return self._whitelist_decorators + + def _new_task(self, step, input_paths=None, **kwargs): + return Task( + flow_datastore=self._flow_datastore, + flow=self._flow, + step=step, + run_id=self.run_id, + metadata=self._metadata, + environment=self._environment, + entrypoint=self._entrypoint, + event_logger=self._event_logger, + monitor=self._monitor, + input_paths=input_paths, + decos=self.whitelist_decorators, + logger=self._logger, + split_index=self.split_index, + **kwargs, + ) + + def execute(self): + exception = None + with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as config_file: + config_value = dump_config_values(self._flow) + if config_value: + json.dump(config_value, config_file) + config_file.flush() + self._config_file_name = config_file.name + else: + self._config_file_name = None + from_start("SpinRuntime: config values processed") + self.task = self._new_task(self._step_func.name, self.input_paths) + try: + self._launch_and_monitor_task() + except Exception as ex: + self._logger("Task failed.", system_msg=True, bad=True) + exception = ex + raise + finally: + for deco in self.whitelist_decorators: + deco.runtime_finished(exception) + + def _launch_and_monitor_task(self): + worker = Worker( + self.task, + self._max_log_size, + self._config_file_name, + 
orig_flow_datastore=self._orig_flow_datastore, + spin_pathspec=self._spin_pathspec, + whitelist_decorators=self.whitelist_decorators, + artifacts_module=self._artifacts_module, + persist=self._persist, + ) + from_start("SpinRuntime: created worker") + + poll = procpoll.make_poll() + fds = worker.fds() + for fd in fds: + poll.add(fd) + + active_fds = set(fds) + + while active_fds: + events = poll.poll(POLL_TIMEOUT) + for event in events: + if event.can_read: + worker.read_logline(event.fd) + if event.is_terminated: + poll.remove(event.fd) + active_fds.remove(event.fd) + from_start("SpinRuntime: read loglines") + returncode = worker.terminate() + from_start("SpinRuntime: worker terminated") + if returncode != 0: + raise TaskFailed(self.task, f"Task failed with return code {returncode}") + else: + self._logger("Task finished successfully.", system_msg=True) + + class NativeRuntime(object): def __init__( self, @@ -1769,8 +1992,27 @@ class CLIArgs(object): for step execution in StepDecorator.runtime_step_cli(). """ - def __init__(self, task): + def __init__( + self, + task, + orig_flow_datastore=None, + spin_pathspec=None, + whitelist_decorators=None, + artifacts_module=None, + persist=True, + ): self.task = task + if orig_flow_datastore is not None: + self.orig_flow_datastore = "%s@%s" % ( + orig_flow_datastore.TYPE, + orig_flow_datastore.datastore_root, + ) + else: + self.orig_flow_datastore = None + self.spin_pathspec = spin_pathspec + self.whitelist_decorators = whitelist_decorators + self.artifacts_module = artifacts_module + self.persist = persist self.entrypoint = list(task.entrypoint) step_obj = getattr(self.task.flow, self.task.step) self.top_level_options = { @@ -1808,19 +2050,48 @@ def __init__(self, task): (k, ConfigInput.make_key_name(k)) for k in configs ] + if spin_pathspec: + self.spin_args() + else: + self.default_args() + + def default_args(self): self.commands = ["step"] self.command_args = [self.task.step] self.command_options = { - "run-id": task.run_id, - "task-id": task.task_id, - "input-paths": compress_list(task.input_paths), - "split-index": task.split_index, - "retry-count": task.retries, - "max-user-code-retries": task.user_code_retries, - "tag": task.tags, + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "input-paths": compress_list(self.task.input_paths), + "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, + "tag": self.task.tags, + "namespace": get_namespace() or "", + "ubf-context": self.task.ubf_context, + } + self.env = {} + + def spin_args(self): + self.commands = ["spin-step"] + self.command_args = [self.task.step] + + whitelist_decos = [deco.name for deco in self.whitelist_decorators] + + self.command_options = { + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "input-paths": compress_list(self.task.input_paths), + "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, "namespace": get_namespace() or "", - "ubf-context": task.ubf_context, + "orig-flow-datastore": self.orig_flow_datastore, + "spin-pathspec": self.spin_pathspec, + "whitelist-decorators": compress_list(whitelist_decos), + "artifacts-module": self.artifacts_module, } + if self.persist: + self.command_options["persist"] = True self.env = {} def get_args(self): @@ -1861,9 +2132,24 @@ def __str__(self): class Worker(object): - def __init__(self, task, max_logs_size, config_file_name): + def __init__( + self, + task, 
+ max_logs_size, + config_file_name, + orig_flow_datastore=None, + spin_pathspec=None, + whitelist_decorators=None, + artifacts_module=None, + persist=True, + ): self.task = task self._config_file_name = config_file_name + self._orig_flow_datastore = orig_flow_datastore + self._spin_pathspec = spin_pathspec + self._whitelist_decorators = whitelist_decorators + self._artifacts_module = artifacts_module + self._persist = persist self._proc = self._launch() if task.retries > task.user_code_retries: @@ -1895,7 +2181,14 @@ def __init__(self, task, max_logs_size, config_file_name): # not it is properly shut down) def _launch(self): - args = CLIArgs(self.task) + args = CLIArgs( + self.task, + orig_flow_datastore=self._orig_flow_datastore, + spin_pathspec=self._spin_pathspec, + whitelist_decorators=self._whitelist_decorators, + artifacts_module=self._artifacts_module, + persist=self._persist, + ) env = dict(os.environ) if self.task.clone_run_id: @@ -2050,13 +2343,14 @@ def terminate(self): else: self.emit_log(b"Task failed.", self._stderr, system_msg=True) else: - num = self.task.results["_foreach_num_splits"] - if num: - self.task.log( - "Foreach yields %d child steps." % num, - system_msg=True, - pid=self._proc.pid, - ) + if not self._spin_pathspec: + num = self.task.results["_foreach_num_splits"] + if num: + self.task.log( + "Foreach yields %d child steps." % num, + system_msg=True, + pid=self._proc.pid, + ) self.task.log( "Task finished successfully.", system_msg=True, pid=self._proc.pid ) diff --git a/metaflow/task.py b/metaflow/task.py index 4239832a386..a811bec485c 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -6,7 +6,6 @@ import time import traceback - from types import MethodType, FunctionType from metaflow.sidecar import Message, MessageTypes @@ -14,6 +13,7 @@ from .metaflow_config import MAX_ATTEMPTS from .metadata_provider import MetaDatum +from .metaflow_profile import from_start from .mflog import TASK_LOG_SOURCE from .datastore import Inputs, TaskDataStoreSet from .exception import ( @@ -49,6 +49,8 @@ def __init__( event_logger, monitor, ubf_context, + orig_flow_datastore=None, + spin_artifacts=None, ): self.flow = flow self.flow_datastore = flow_datastore @@ -58,6 +60,8 @@ def __init__( self.event_logger = event_logger self.monitor = monitor self.ubf_context = ubf_context + self.orig_flow_datastore = orig_flow_datastore + self.spin_artifacts = spin_artifacts def _exec_step_function(self, step_function, orig_step_func, input_obj=None): wrappers_stack = [] @@ -234,7 +238,6 @@ def property_setter( lambda _, parameter_ds=parameter_ds: parameter_ds["_graph_info"], ) all_vars.append("_graph_info") - if passdown: self.flow._datastore.passdown_partial(parameter_ds, all_vars) return param_only_vars @@ -262,6 +265,9 @@ def _init_data(self, run_id, join_type, input_paths): run_id, pathspecs=input_paths, prefetch_data_artifacts=prefetch_data_artifacts, + join_type=join_type, + orig_flow_datastore=self.orig_flow_datastore, + spin_artifacts=self.spin_artifacts, ) ds_list = [ds for ds in datastore_set] if len(ds_list) != len(input_paths): @@ -273,10 +279,27 @@ def _init_data(self, run_id, join_type, input_paths): # initialize directly in the single input case. 
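+            # Spin supplies input paths with an explicit attempt component
+            # ("run_id/step/task_id/attempt"), while regular paths have three
+            # parts, so the loop below handles both forms.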
ds_list = [] for input_path in input_paths: - run_id, step_name, task_id = input_path.split("/") + parts = input_path.split("/") + if len(parts) == 3: + run_id, step_name, task_id = parts + attempt = None + else: + run_id, step_name, task_id, attempt = parts + attempt = int(attempt) + ds_list.append( - self.flow_datastore.get_task_datastore(run_id, step_name, task_id) + self.flow_datastore.get_task_datastore( + run_id, + step_name, + task_id, + attempt=attempt, + join_type=join_type, + orig_flow_datastore=self.orig_flow_datastore, + spin_artifacts=self.spin_artifacts, + ) ) + from_start("MetaflowTask: got datastore for input path %s" % input_path) + if not ds_list: # this guards against errors in input paths raise MetaflowDataMissing( @@ -547,6 +570,8 @@ def run_step( split_index, retry_count, max_user_code_retries, + whitelist_decorators=None, + persist=True, ): if run_id and task_id: self.metadata.register_run_id(run_id) @@ -605,7 +630,12 @@ def run_step( step_func = getattr(self.flow, step_name) decorators = step_func.decorators - + if self.orig_flow_datastore: + # We filter only the whitelisted decorators in case of spin step. + decorators = [ + deco for deco in decorators if deco.name in whitelist_decorators + ] + from_start("MetaflowTask: decorators initialized") node = self.flow._graph[step_name] join_type = None if node.type == "join": @@ -613,17 +643,20 @@ def run_step( # 1. initialize output datastore output = self.flow_datastore.get_task_datastore( - run_id, step_name, task_id, attempt=retry_count, mode="w" + run_id, step_name, task_id, attempt=retry_count, mode="w", persist=persist ) output.init_task() + from_start("MetaflowTask: output datastore initialized") if input_paths: # 2. initialize input datastores inputs = self._init_data(run_id, join_type, input_paths) + from_start("MetaflowTask: input datastores initialized") # 3. initialize foreach state self._init_foreach(step_name, join_type, inputs, split_index) + from_start("MetaflowTask: foreach state initialized") # 4. initialize the iteration state is_recursive_step = ( @@ -682,7 +715,7 @@ def run_step( ), ] ) - + from_start("MetaflowTask: finished input processing") self.metadata.register_metadata( run_id, step_name, @@ -736,8 +769,11 @@ def run_step( "project_flow_name": current.get("project_flow_name"), "trace_id": trace_id or None, } + + from_start("MetaflowTask: task metadata initialized") start = time.time() self.metadata.start_task_heartbeat(self.flow.name, run_id, step_name, task_id) + from_start("MetaflowTask: heartbeat started") with self.monitor.measure("metaflow.task.duration"): try: with self.monitor.count("metaflow.task.start"): @@ -757,7 +793,6 @@ def run_step( # should either be set prior to running the user code or listed in # FlowSpec._EPHEMERAL to allow for proper merging/importing of # user artifacts in the user's step code. 
- if join_type: # Join step: @@ -816,6 +851,7 @@ def run_step( "graph_info": self.flow._graph_info, } ) + from_start("MetaflowTask: before pre-step decorators") for deco in decorators: deco.task_pre_step( step_name, @@ -846,12 +882,12 @@ def run_step( max_user_code_retries, self.ubf_context, ) - + from_start("MetaflowTask: finished decorator processing") if join_type: self._exec_step_function(step_func, orig_step_func, input_obj) else: self._exec_step_function(step_func, orig_step_func) - + from_start("MetaflowTask: step function executed") for deco in decorators: deco.task_post_step( step_name, @@ -894,6 +930,7 @@ def run_step( raise finally: + from_start("MetaflowTask: decorators finalized") if self.ubf_context == UBF_CONTROL: self._finalize_control_task() @@ -933,7 +970,7 @@ def run_step( ) output.save_metadata({"task_end": {}}) - + from_start("MetaflowTask: output persisted") # this writes a success marker indicating that the # "transaction" is done output.done() @@ -962,3 +999,4 @@ def run_step( name="duration", payload={**task_payload, "msg": str(duration)}, ) + from_start("MetaflowTask: task run completed") diff --git a/metaflow/util.py b/metaflow/util.py index 6742cc46a8a..ed34b8802e0 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -9,7 +9,8 @@ from functools import wraps from io import BytesIO from itertools import takewhile -from typing import Generator, List, Optional, Tuple +import re + try: # python2 diff --git a/test/unit/spin/artifacts/complex_dag_step_a.py b/test/unit/spin/artifacts/complex_dag_step_a.py new file mode 100644 index 00000000000..b7e81bf1b6f --- /dev/null +++ b/test/unit/spin/artifacts/complex_dag_step_a.py @@ -0,0 +1 @@ +ARTIFACTS = {"my_output": [10, 11, 12]} diff --git a/test/unit/spin/artifacts/complex_dag_step_d.py b/test/unit/spin/artifacts/complex_dag_step_d.py new file mode 100644 index 00000000000..5aa40d64766 --- /dev/null +++ b/test/unit/spin/artifacts/complex_dag_step_d.py @@ -0,0 +1,11 @@ +from metaflow import Run + + +def _get_artifact(): + task = Run("ComplexDAGFlow/2")["step_d"].task + task_pathspec = next(task.parent_task_pathspecs) + _, inp_path = task_pathspec.split("/", 1) + return {inp_path: {"my_output": [-1]}} + + +ARTIFACTS = _get_artifact() diff --git a/test/unit/spin/complex_dag_flow.py b/test/unit/spin/complex_dag_flow.py new file mode 100644 index 00000000000..04b185fe40f --- /dev/null +++ b/test/unit/spin/complex_dag_flow.py @@ -0,0 +1,116 @@ +from metaflow import FlowSpec, step, project, conda, Task, pypi + + +class ComplexDAGFlow(FlowSpec): + @step + def start(self): + self.split_start = [1, 2, 3] + self.my_output = [] + print("My output is: ", self.my_output) + self.next(self.step_a, foreach="split_start") + + @step + def step_a(self): + self.split_a = [4, 5] + self.my_output = self.my_output + [self.input] + print("My output is: ", self.my_output) + self.next(self.step_b, foreach="split_a") + + @step + def step_b(self): + self.split_b = [6, 7, 8] + self.my_output = self.my_output + [self.input] + print("My output is: ", self.my_output) + self.next(self.step_c, foreach="split_b") + + @conda(libraries={"numpy": "2.1.1"}) + @step + def step_c(self): + import numpy as np + + self.np_version = np.__version__ + print(f"numpy version: {self.np_version}") + self.my_output = self.my_output + [self.input] + [9, 10] + print("My output is: ", self.my_output) + self.next(self.step_d) + + @step + def step_d(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + 
self.next(self.step_e) + + @step + def step_e(self): + print(f"I am step E. Input is: {self.input}") + self.split_e = [9, 10] + print("My output is: ", self.my_output) + self.next(self.step_f, foreach="split_e") + + @step + def step_f(self): + self.my_output = self.my_output + [self.input] + print("My output is: ", self.my_output) + self.next(self.step_g) + + @step + def step_g(self): + print("My output is: ", self.my_output) + self.next(self.step_h) + + @step + def step_h(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.step_i) + + @step + def step_i(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.step_j) + + @step + def step_j(self): + print("My output is: ", self.my_output) + self.next(self.step_k, self.step_l) + + @step + def step_k(self): + self.my_output = self.my_output + [11] + print("My output is: ", self.my_output) + self.next(self.step_m) + + @step + def step_l(self): + print(f"I am step L. Input is: {self.input}") + self.my_output = self.my_output + [12] + print("My output is: ", self.my_output) + self.next(self.step_m) + + @conda(libraries={"scikit-learn": "1.3.0"}) + @step + def step_m(self, inputs): + import sklearn + + self.sklearn_version = sklearn.__version__ + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("Sklearn version: ", self.sklearn_version) + print("My output is: ", self.my_output) + self.next(self.step_n) + + @step + def step_n(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.end) + + @step + def end(self): + self.my_output = self.my_output + [13] + print("My output is: ", self.my_output) + print("Flow is complete!") + + +if __name__ == "__main__": + ComplexDAGFlow() diff --git a/test/unit/spin/merge_artifacts_flow.py b/test/unit/spin/merge_artifacts_flow.py new file mode 100644 index 00000000000..59f1390e052 --- /dev/null +++ b/test/unit/spin/merge_artifacts_flow.py @@ -0,0 +1,63 @@ +from metaflow import FlowSpec, step + + +class MergeArtifactsFlow(FlowSpec): + + @step + def start(self): + self.pass_down = "a" + self.next(self.a, self.b) + + @step + def a(self): + self.common = 5 + self.x = 1 + self.y = 3 + self.from_a = 6 + self.next(self.join) + + @step + def b(self): + self.common = 5 + self.x = 2 + self.y = 4 + self.next(self.join) + + @step + def join(self, inputs): + print(f"In join step, self._datastore: {(type(self._datastore))}") + self.x = inputs.a.x + self.merge_artifacts(inputs, exclude=["y"]) + print("x is %s" % self.x) + print("pass_down is %s" % self.pass_down) + print("common is %d" % self.common) + print("from_a is %d" % self.from_a) + self.next(self.c) + + @step + def c(self): + self.next(self.d, self.e) + + @step + def d(self): + self.conflicting = 7 + self.next(self.join2) + + @step + def e(self): + self.conflicting = 8 + self.next(self.join2) + + @step + def join2(self, inputs): + self.merge_artifacts(inputs, include=["pass_down", "common"]) + print("Only pass_down and common exist here") + self.next(self.end) + + @step + def end(self): + pass + + +if __name__ == "__main__": + MergeArtifactsFlow() diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py new file mode 100644 index 00000000000..88941df578f --- /dev/null +++ b/test/unit/spin/test_spin.py @@ -0,0 +1,138 @@ +import pytest +from metaflow import Runner, Run + + +@pytest.fixture 
+def complex_dag_run(): + # with Runner('complex_dag_flow.py').run() as running: + # yield running.run + return Run("ComplexDAGFlow/5", _namespace_check=False) + + +@pytest.fixture +def merge_artifacts_run(): + # with Runner('merge_artifacts_flow.py').run() as running: + # yield running.run + return Run("MergeArtifactsFlow/55", _namespace_check=False) + + +def _assert_artifacts(task, spin_task): + spin_task_artifacts = { + artifact.id: artifact.data for artifact in spin_task.artifacts + } + print(f"Spin task artifacts: {spin_task_artifacts}") + for artifact in task.artifacts: + assert ( + artifact.id in spin_task_artifacts + ), f"Artifact {artifact.id} not found in spin task" + assert ( + artifact.data == spin_task_artifacts[artifact.id] + ), f"Expected {artifact.data} but got {spin_task_artifacts[artifact.id]} for artifact {artifact.id}" + + +def _run_step(file_name, run, step_name, is_conda=False): + task = run[step_name].task + if not is_conda: + with Runner(file_name).spin(step_name, spin_pathspec=task.pathspec) as spin: + print("-" * 50) + print( + f"Running test for step: {step_name} with task pathspec: {task.pathspec}" + ) + _assert_artifacts(task, spin.task) + else: + with Runner(file_name, environment="conda").spin( + step_name, + spin_pathspec=task.pathspec, + ) as spin: + print("-" * 50) + print( + f"Running test for step: {step_name} with task pathspec: {task.pathspec}" + ) + print(f"Spin task artifacts: {spin.task.artifacts}") + _assert_artifacts(task, spin.task) + + +def test_complex_dag_flow(complex_dag_run): + print(f"Running test for ComplexDAGFlow flow: {complex_dag_run}") + for step in complex_dag_run.steps(): + print("-" * 100) + _run_step("complex_dag_flow.py", complex_dag_run, step.id, is_conda=True) + + +def test_merge_artifacts_flow(merge_artifacts_run): + print(f"Running test for merge artifacts flow: {merge_artifacts_run}") + for step in merge_artifacts_run.steps(): + print("-" * 100) + _run_step("merge_artifacts_flow.py", merge_artifacts_run, step.id) + + +def test_artifacts_module(complex_dag_run): + print(f"Running test for artifacts module in ComplexDAGFlow: {complex_dag_run}") + step_name = "step_a" + task = complex_dag_run[step_name].task + with Runner("complex_dag_flow.py", environment="conda").spin( + step_name, + spin_pathspec=task.pathspec, + artifacts_module="./artifacts/complex_dag_step_a.py", + ) as spin: + print("-" * 50) + print(f"Running test for step: step_a with task pathspec: {task.pathspec}") + spin_task = spin.task + print(f"my_output: {spin_task['my_output']}") + assert spin_task["my_output"].data == [10, 11, 12, 3] + + +def test_artifacts_module_join_step(complex_dag_run): + print(f"Running test for artifacts module in ComplexDAGFlow: {complex_dag_run}") + step_name = "step_d" + task = complex_dag_run[step_name].task + with Runner("complex_dag_flow.py", environment="conda").spin( + step_name, + spin_pathspec=task.pathspec, + artifacts_module="./artifacts/complex_dag_step_d.py", + ) as spin: + print("-" * 50) + print(f"Running test for step: step_a with task pathspec: {task.pathspec}") + spin_task = spin.task + assert spin_task["my_output"].data == [-1] + + +def test_skip_decorators(complex_dag_run): + print(f"Running test for skip decorator in ComplexDAGFlow: {complex_dag_run}") + step_name = "step_m" + task = complex_dag_run[step_name].task + # Check if sklearn is available in the outer environment + # If not, this test will fail as it requires sklearn to be installed and skip_decorator + # is set to True + is_sklearn = True + try: + import 
sklearn + except ImportError: + is_sklearn = False + if is_sklearn: + # We verify that the sklearn version is the same as the one in the outside environment + with Runner("complex_dag_flow.py", environment="conda").spin( + step_name, + spin_pathspec=task.pathspec, + skip_decorators=True, + ) as spin: + print("-" * 50) + print( + f"Running test for step: {step_name} with task pathspec: {task.pathspec}" + ) + spin_task = spin.task + import sklearn + + expected_version = sklearn.__version__ + assert ( + spin_task["sklearn_version"].data == expected_version + ), f"Expected sklearn version {expected_version} but got {spin_task['sklearn_version']}" + else: + # We assert that an exception is raised when trying to run the step with skip_decorators=True + with pytest.raises(Exception) as exc_info: + with Runner("complex_dag_flow.py", environment="conda").spin( + step_name, + spin_pathspec=task.pathspec, + skip_decorators=True, + ): + pass From e14967ce1b5264fb6787dc79f0ab343c1d1fc479 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Sun, 20 Jul 2025 14:50:39 -0700 Subject: [PATCH 02/21] Update test suite --- metaflow/datastore/content_addressed_store.py | 1 - .../unit/spin/artifacts/complex_dag_step_d.py | 14 +--- test/unit/spin/conftest.py | 39 +++++++++++ .../unit/spin/{ => flows}/complex_dag_flow.py | 0 .../spin/{ => flows}/merge_artifacts_flow.py | 0 test/unit/spin/flows/simple_parameter_flow.py | 27 ++++++++ test/unit/spin/test_spin.py | 67 +++++++++++-------- 7 files changed, 108 insertions(+), 40 deletions(-) create mode 100644 test/unit/spin/conftest.py rename test/unit/spin/{ => flows}/complex_dag_flow.py (100%) rename test/unit/spin/{ => flows}/merge_artifacts_flow.py (100%) create mode 100644 test/unit/spin/flows/simple_parameter_flow.py diff --git a/metaflow/datastore/content_addressed_store.py b/metaflow/datastore/content_addressed_store.py index 75203174d9d..a8f2e0e4805 100644 --- a/metaflow/datastore/content_addressed_store.py +++ b/metaflow/datastore/content_addressed_store.py @@ -160,7 +160,6 @@ def load_blobs(self, keys, force_raw=False, _is_transfer=False): with self._storage_impl.load_bytes([p for _, p in load_paths]) as loaded: for path_key, file_path, meta in loaded: - print(f"path_key: {path_key}, file_path: {file_path}, meta: {meta}") key = self._storage_impl.path_split(path_key)[-1] # At this point, we either return the object as is (if raw) or # decode it according to the encoding version diff --git a/test/unit/spin/artifacts/complex_dag_step_d.py b/test/unit/spin/artifacts/complex_dag_step_d.py index 5aa40d64766..20bb0376e8d 100644 --- a/test/unit/spin/artifacts/complex_dag_step_d.py +++ b/test/unit/spin/artifacts/complex_dag_step_d.py @@ -1,11 +1,3 @@ -from metaflow import Run - - -def _get_artifact(): - task = Run("ComplexDAGFlow/2")["step_d"].task - task_pathspec = next(task.parent_task_pathspecs) - _, inp_path = task_pathspec.split("/", 1) - return {inp_path: {"my_output": [-1]}} - - -ARTIFACTS = _get_artifact() +# This file is kept for backwards compatibility but should not be used directly +# The artifacts are now generated dynamically via pytest fixtures +ARTIFACTS = {} diff --git a/test/unit/spin/conftest.py b/test/unit/spin/conftest.py new file mode 100644 index 00000000000..6c084b1c375 --- /dev/null +++ b/test/unit/spin/conftest.py @@ -0,0 +1,39 @@ +import pytest +from metaflow import Runner +import os + +# Get the directory containing the flows +FLOWS_DIR = os.path.join(os.path.dirname(__file__), "flows") + + +@pytest.fixture(scope="session") +def 
complex_dag_run(): + """Run ComplexDAGFlow and return the completed run.""" + flow_path = os.path.join(FLOWS_DIR, "complex_dag_flow.py") + with Runner(flow_path, environment="conda").run() as running: + return running.run + + +@pytest.fixture(scope="session") +def merge_artifacts_run(): + """Run MergeArtifactsFlow and return the completed run.""" + flow_path = os.path.join(FLOWS_DIR, "merge_artifacts_flow.py") + with Runner(flow_path).run() as running: + return running.run + + +@pytest.fixture(scope="session") +def simple_parameter_run(): + """Run SimpleParameterFlow and return the completed run.""" + flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") + with Runner(flow_path).run(alpha=0.05) as running: + return running.run + + +@pytest.fixture +def complex_dag_step_d_artifacts(complex_dag_run): + """Generate dynamic artifacts for complex_dag step_d tests.""" + task = complex_dag_run["step_d"].task + task_pathspec = next(task.parent_task_pathspecs) + _, inp_path = task_pathspec.split("/", 1) + return {inp_path: {"my_output": [-1]}} diff --git a/test/unit/spin/complex_dag_flow.py b/test/unit/spin/flows/complex_dag_flow.py similarity index 100% rename from test/unit/spin/complex_dag_flow.py rename to test/unit/spin/flows/complex_dag_flow.py diff --git a/test/unit/spin/merge_artifacts_flow.py b/test/unit/spin/flows/merge_artifacts_flow.py similarity index 100% rename from test/unit/spin/merge_artifacts_flow.py rename to test/unit/spin/flows/merge_artifacts_flow.py diff --git a/test/unit/spin/flows/simple_parameter_flow.py b/test/unit/spin/flows/simple_parameter_flow.py new file mode 100644 index 00000000000..b2d3410d2ad --- /dev/null +++ b/test/unit/spin/flows/simple_parameter_flow.py @@ -0,0 +1,27 @@ +from metaflow import FlowSpec, step, Parameter + + +class SimpleParameterFlow(FlowSpec): + alpha = Parameter("alpha", help="Learning rate", default=0.01) + + @step + def start(self): + print("SimpleParameterFlow is starting.") + print(f"Parameter alpha is set to: {self.alpha}") + self.a = 10 + self.b = 20 + self.next(self.end) + + @step + def end(self): + self.a = 50 + self.x = 100 + self.y = 200 + print("Parameter alpha in end step is: ", self.alpha) + del self.a + del self.x + print("SimpleParameterFlow is all done.") + + +if __name__ == "__main__": + SimpleParameterFlow() diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py index 88941df578f..63408bbeee1 100644 --- a/test/unit/spin/test_spin.py +++ b/test/unit/spin/test_spin.py @@ -1,19 +1,9 @@ import pytest -from metaflow import Runner, Run +from metaflow import Runner +import os - -@pytest.fixture -def complex_dag_run(): - # with Runner('complex_dag_flow.py').run() as running: - # yield running.run - return Run("ComplexDAGFlow/5", _namespace_check=False) - - -@pytest.fixture -def merge_artifacts_run(): - # with Runner('merge_artifacts_flow.py').run() as running: - # yield running.run - return Run("MergeArtifactsFlow/55", _namespace_check=False) +FLOWS_DIR = os.path.join(os.path.dirname(__file__), "flows") +ARTIFACTS_DIR = os.path.join(os.path.dirname(__file__), "artifacts") def _assert_artifacts(task, spin_task): @@ -30,17 +20,19 @@ def _assert_artifacts(task, spin_task): ), f"Expected {artifact.data} but got {spin_task_artifacts[artifact.id]} for artifact {artifact.id}" -def _run_step(file_name, run, step_name, is_conda=False): +def _run_step(flow_file, run, step_name, is_conda=False): task = run[step_name].task + flow_path = os.path.join(FLOWS_DIR, flow_file) + if not is_conda: - with 
Runner(file_name).spin(step_name, spin_pathspec=task.pathspec) as spin: + with Runner(flow_path).spin(step_name, spin_pathspec=task.pathspec) as spin: print("-" * 50) print( f"Running test for step: {step_name} with task pathspec: {task.pathspec}" ) _assert_artifacts(task, spin.task) else: - with Runner(file_name, environment="conda").spin( + with Runner(flow_path, environment="conda").spin( step_name, spin_pathspec=task.pathspec, ) as spin: @@ -66,14 +58,24 @@ def test_merge_artifacts_flow(merge_artifacts_run): _run_step("merge_artifacts_flow.py", merge_artifacts_run, step.id) +def test_simple_parameter_flow(simple_parameter_run): + print(f"Running test for SimpleParameterFlow: {simple_parameter_run}") + for step in simple_parameter_run.steps(): + print("-" * 100) + _run_step("simple_parameter_flow.py", simple_parameter_run, step.id) + + def test_artifacts_module(complex_dag_run): print(f"Running test for artifacts module in ComplexDAGFlow: {complex_dag_run}") step_name = "step_a" task = complex_dag_run[step_name].task - with Runner("complex_dag_flow.py", environment="conda").spin( + flow_path = os.path.join(FLOWS_DIR, "complex_dag_flow.py") + artifacts_path = os.path.join(ARTIFACTS_DIR, "complex_dag_step_a.py") + + with Runner(flow_path, environment="conda").spin( step_name, spin_pathspec=task.pathspec, - artifacts_module="./artifacts/complex_dag_step_a.py", + artifacts_module=artifacts_path, ) as spin: print("-" * 50) print(f"Running test for step: step_a with task pathspec: {task.pathspec}") @@ -82,17 +84,25 @@ def test_artifacts_module(complex_dag_run): assert spin_task["my_output"].data == [10, 11, 12, 3] -def test_artifacts_module_join_step(complex_dag_run): +def test_artifacts_module_join_step( + complex_dag_run, complex_dag_step_d_artifacts, tmp_path +): print(f"Running test for artifacts module in ComplexDAGFlow: {complex_dag_run}") step_name = "step_d" task = complex_dag_run[step_name].task - with Runner("complex_dag_flow.py", environment="conda").spin( + flow_path = os.path.join(FLOWS_DIR, "complex_dag_flow.py") + + # Create a temporary artifacts file with dynamic data + temp_artifacts_file = tmp_path / "temp_complex_dag_step_d.py" + temp_artifacts_file.write_text(f"ARTIFACTS = {repr(complex_dag_step_d_artifacts)}") + + with Runner(flow_path, environment="conda").spin( step_name, spin_pathspec=task.pathspec, - artifacts_module="./artifacts/complex_dag_step_d.py", + artifacts_module=str(temp_artifacts_file), ) as spin: print("-" * 50) - print(f"Running test for step: step_a with task pathspec: {task.pathspec}") + print(f"Running test for step: step_d with task pathspec: {task.pathspec}") spin_task = spin.task assert spin_task["my_output"].data == [-1] @@ -101,17 +111,18 @@ def test_skip_decorators(complex_dag_run): print(f"Running test for skip decorator in ComplexDAGFlow: {complex_dag_run}") step_name = "step_m" task = complex_dag_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "complex_dag_flow.py") + # Check if sklearn is available in the outer environment - # If not, this test will fail as it requires sklearn to be installed and skip_decorator - # is set to True is_sklearn = True try: import sklearn except ImportError: is_sklearn = False + if is_sklearn: # We verify that the sklearn version is the same as the one in the outside environment - with Runner("complex_dag_flow.py", environment="conda").spin( + with Runner(flow_path, environment="conda").spin( step_name, spin_pathspec=task.pathspec, skip_decorators=True, @@ -129,8 +140,8 @@ def 
test_skip_decorators(complex_dag_run): ), f"Expected sklearn version {expected_version} but got {spin_task['sklearn_version']}" else: # We assert that an exception is raised when trying to run the step with skip_decorators=True - with pytest.raises(Exception) as exc_info: - with Runner("complex_dag_flow.py", environment="conda").spin( + with pytest.raises(Exception): + with Runner(flow_path, environment="conda").spin( step_name, spin_pathspec=task.pathspec, skip_decorators=True, From 9f3e8c96feced576f96fdd49315915693775396c Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Sun, 20 Jul 2025 14:58:22 -0700 Subject: [PATCH 03/21] Run black --- metaflow/runtime.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/metaflow/runtime.py b/metaflow/runtime.py index f54ec2789b7..a668dc1d6dd 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -28,7 +28,12 @@ from . import get_namespace from .client.filecache import FileCache, FileBlobCache, TaskMetadataCache from .metadata_provider import MetaDatum -from .metaflow_config import FEAT_ALWAYS_UPLOAD_CODE_PACKAGE, MAX_ATTEMPTS, UI_URL, SPIN_ALLOWED_DECORATORS +from .metaflow_config import ( + FEAT_ALWAYS_UPLOAD_CODE_PACKAGE, + MAX_ATTEMPTS, + UI_URL, + SPIN_ALLOWED_DECORATORS, +) from .metaflow_profile import from_start from .plugins import DATASTORES from .exception import ( From 824b926eb638ab9b0f4f0a48adc79428f13a25d1 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Sun, 20 Jul 2025 15:08:03 -0700 Subject: [PATCH 04/21] Update runtime dag test --- test/core/tests/runtime_dag.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/core/tests/runtime_dag.py b/test/core/tests/runtime_dag.py index 8bc54985d20..cfad22ee17e 100644 --- a/test/core/tests/runtime_dag.py +++ b/test/core/tests/runtime_dag.py @@ -71,7 +71,9 @@ def _equals_task(task1, task2): if name not in [ "parent_tasks", + "parent_task_pathspecs", "child_tasks", + "child_task_pathspecs", "metadata", "data", "artifacts", From 8f3c9dc5f568ea79ef12738f0e9546e6f03a821e Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Thu, 21 Aug 2025 11:48:11 -0700 Subject: [PATCH 05/21] Update skip decorator and datastore root logic --- metaflow/cli.py | 26 ++++++++--- metaflow/cli_components/run_cmds.py | 45 ++++++++++--------- metaflow/cli_components/step_cmd.py | 1 + metaflow/datastore/task_datastore.py | 1 - metaflow/decorators.py | 7 ++- metaflow/metaflow_config.py | 3 +- metaflow/plugins/datastores/local_storage.py | 29 +++++++++--- metaflow/runtime.py | 4 +- test/unit/spin/flows/simple_parameter_flow.py | 3 +- 9 files changed, 79 insertions(+), 40 deletions(-) diff --git a/metaflow/cli.py b/metaflow/cli.py index 5e5f95a0ab2..37f1823b7ee 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -25,8 +25,6 @@ DEFAULT_METADATA, DEFAULT_MONITOR, DEFAULT_PACKAGE_SUFFIXES, - DATASTORE_SYSROOT_SPIN, - DATASTORE_LOCAL_DIR, ) from .metaflow_current import current from .metaflow_profile import from_start @@ -42,10 +40,9 @@ ) from .pylint_wrapper import PyLint from .R import metaflow_r_version, use_r -from .util import get_latest_run_id, resolve_identity +from .util import get_latest_run_id, resolve_identity, decompress_list from .user_configs.config_options import LocalFileInput, config_options from .user_configs.config_parameters import ConfigValue -from .util import get_latest_run_id, resolve_identity ERASE_TO_EOL = "\033[K" HIGHLIGHT = "red" @@ -527,9 +524,12 @@ def start( ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor ) 
ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] - # Set datastore_root to be DATASTORE_SYSROOT_SPIN if not provided - datastore_root = os.path.join(DATASTORE_SYSROOT_SPIN, DATASTORE_LOCAL_DIR) + # Set a separate datastore root for spin + datastore_root = ctx.obj.datastore_impl.get_datastore_root_from_config( + ctx.obj.echo, create_on_absent=True, use_spin_dir=True + ) ctx.obj.datastore_impl.datastore_root = datastore_root + ctx.obj.flow_datastore = FlowDataStore( ctx.obj.flow.name, ctx.obj.environment, # Same environment as run/resume @@ -550,6 +550,7 @@ def start( # It is important to initialize flow decorators early as some of the # things they provide may be used by some of the objects initialized after. + from_start(f"I am just above _init_flow_decorators") decorators._init_flow_decorators( ctx.obj.flow, ctx.obj.graph, @@ -593,6 +594,15 @@ def start( ): # run/resume are special cases because they can add more decorators with --with, # so they have to take care of themselves. + whitelist_decorators = None + if "--whitelist-decorators" in ctx.saved_args: + # If whitelist-decorators is specified, we only will run the decorators hooks + # for the decorators that are whitelisted. + idx = ctx.saved_args.index("--whitelist-decorators") + whitelist_decorators = ctx.saved_args[idx + 1] + whitelist_decorators = ( + decompress_list(whitelist_decorators) if whitelist_decorators else [] + ) all_decospecs = ctx.obj.tl_decospecs + list( ctx.obj.environment.decospecs() or [] @@ -603,6 +613,9 @@ def start( # or a scheduler setting them up in their own way. if ctx.saved_args[0] not in ("step", "init"): all_decospecs += DEFAULT_DECOSPECS.split() + elif ctx.saved_args[0] == "spin-step": + # If we are in spin-args, we will not attach any decorators + all_decospecs = [] if all_decospecs: decorators._attach_decorators(ctx.obj.flow, all_decospecs) decorators._init(ctx.obj.flow) @@ -616,6 +629,7 @@ def start( ctx.obj.environment, ctx.obj.flow_datastore, ctx.obj.logger, + whitelist_decorators, ) # Check the graph again (mutators may have changed it) diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index 91a06e0137a..1a0b60d6cc9 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -20,7 +20,7 @@ from ..util import get_latest_run_id, write_latest_run_id -def before_run(obj, tags, decospecs): +def before_run(obj, tags, decospecs, skip_decorators=False): validate_tags(tags) # There's a --with option both at the top-level and for the run/resume/spin @@ -37,26 +37,27 @@ def before_run(obj, tags, decospecs): # - run level decospecs # - top level decospecs # - environment decospecs - all_decospecs = ( - list(decospecs or []) - + obj.tl_decospecs - + list(obj.environment.decospecs() or []) - ) - if all_decospecs: - # These decospecs are the ones from run/resume/spin PLUS the ones from the - # environment (for example the @conda) - decorators._attach_decorators(obj.flow, all_decospecs) - decorators._init(obj.flow) - # Regenerate graph if we attached more decorators - obj.flow.__class__._init_attrs() - obj.graph = obj.flow._graph - - obj.check(obj.graph, obj.flow, obj.environment, pylint=obj.pylint) - # obj.environment.init_environment(obj.logger) - - decorators._init_step_decorators( - obj.flow, obj.graph, obj.environment, obj.flow_datastore, obj.logger - ) + if not skip_decorators: + all_decospecs = ( + list(decospecs or []) + + obj.tl_decospecs + + list(obj.environment.decospecs() or []) + ) + if all_decospecs: + # 
These decospecs are the ones from run/resume/spin PLUS the ones from the + # environment (for example the @conda) + decorators._attach_decorators(obj.flow, all_decospecs) + decorators._init(obj.flow) + # Regenerate graph if we attached more decorators + obj.flow.__class__._init_attrs() + obj.graph = obj.flow._graph + + obj.check(obj.graph, obj.flow, obj.environment, pylint=obj.pylint) + # obj.environment.init_environment(obj.logger) + + decorators._init_step_decorators( + obj.flow, obj.graph, obj.environment, obj.flow_datastore, obj.logger + ) # Re-read graph since it may have been modified by mutators obj.graph = obj.flow._graph @@ -471,7 +472,7 @@ def spin( runner_attribute_file=None, **kwargs, ): - before_run(obj, [], []) + before_run(obj, [], [], skip_decorators) obj.echo(f"Spinning up step *{step_name}* locally for flow *{obj.flow.name}*") obj.flow._set_constants(obj.graph, kwargs, obj.config_options) step_func = getattr(obj.flow, step_name, None) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index 309a082ee68..ca244ff6bfa 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -272,6 +272,7 @@ def spin_step( artifacts_module=None, persist=True, ): + from_start("I am in spin step") import time if ctx.obj.is_quiet: diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index 7ec20825b7b..11cf3904f98 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -783,7 +783,6 @@ def persist(self, flow): if hasattr(flow._datastore, "orig_datastore"): parent_artifacts = set(flow._datastore._objects.keys()) unchanged_artifacts = parent_artifacts - current_artifact_names - print(f"Transferring unchanged artifacts: {unchanged_artifacts}") if unchanged_artifacts: self.transfer_artifacts( flow._datastore.orig_datastore, names=list(unchanged_artifacts) diff --git a/metaflow/decorators.py b/metaflow/decorators.py index 583f8a4515e..18160642156 100644 --- a/metaflow/decorators.py +++ b/metaflow/decorators.py @@ -714,7 +714,9 @@ def _init_flow_decorators( ) -def _init_step_decorators(flow, graph, environment, flow_datastore, logger): +def _init_step_decorators( + flow, graph, environment, flow_datastore, logger, whitelist_decorators=None +): # NOTE: We don't need graph but keeping it for backwards compatibility with # extensions that use it directly. We will remove it at some point. @@ -785,6 +787,9 @@ def _init_step_decorators(flow, graph, environment, flow_datastore, logger): for step in flow: for deco in step.decorators: + # We skip decorators that are not in the whitelist + if not whitelist_decorators and deco.name not in whitelist_decorators: + continue deco.step_init( flow, graph, diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index cde0b5b37f9..a0b831d9877 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -21,6 +21,7 @@ # Path to the local directory to store artifacts for 'local' datastore. DATASTORE_LOCAL_DIR = ".metaflow" +DATASTORE_SPIN_LOCAL_DIR = ".spin_metaflow" # Local configuration file (in .metaflow) containing overrides per-project LOCAL_CONFIG_FILE = "config.json" @@ -64,7 +65,7 @@ # Datastore configuration ### DATASTORE_SYSROOT_LOCAL = from_conf("DATASTORE_SYSROOT_LOCAL") -DATASTORE_SYSROOT_SPIN = from_conf("DATASTORE_SYSROOT_SPIN", "/tmp/metaflow") +DATASTORE_SYSROOT_SPIN = from_conf("DATASTORE_SYSROOT_SPIN", "/tmp") # S3 bucket and prefix to store artifacts for 's3' datastore. 
DATASTORE_SYSROOT_S3 = from_conf("DATASTORE_SYSROOT_S3") # Azure Blob Storage container and blob prefix diff --git a/metaflow/plugins/datastores/local_storage.py b/metaflow/plugins/datastores/local_storage.py index 4077a9404dd..8aa0e1741e5 100644 --- a/metaflow/plugins/datastores/local_storage.py +++ b/metaflow/plugins/datastores/local_storage.py @@ -1,7 +1,12 @@ import json import os -from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, DATASTORE_SYSROOT_LOCAL +from metaflow.metaflow_config import ( + DATASTORE_LOCAL_DIR, + DATASTORE_SYSROOT_LOCAL, + DATASTORE_SPIN_LOCAL_DIR, + DATASTORE_SYSROOT_SPIN, +) from metaflow.datastore.datastore_storage import CloseAfterUse, DataStoreStorage @@ -10,15 +15,24 @@ class LocalStorage(DataStoreStorage): METADATA_DIR = "_meta" @classmethod - def get_datastore_root_from_config(cls, echo, create_on_absent=True): - result = DATASTORE_SYSROOT_LOCAL + def get_datastore_root_from_config( + cls, echo, create_on_absent=True, use_spin_dir=False + ): + if use_spin_dir: + datastore_dir = DATASTORE_SPIN_LOCAL_DIR + sysroot_var = DATASTORE_SYSROOT_LOCAL + else: + datastore_dir = DATASTORE_LOCAL_DIR # ".metaflow" + sysroot_var = DATASTORE_SYSROOT_LOCAL + + result = sysroot_var if result is None: try: # Python2 current_path = os.getcwdu() except: # noqa E722 current_path = os.getcwd() - check_dir = os.path.join(current_path, DATASTORE_LOCAL_DIR) + check_dir = os.path.join(current_path, datastore_dir) check_dir = os.path.realpath(check_dir) orig_path = check_dir top_level_reached = False @@ -28,12 +42,13 @@ def get_datastore_root_from_config(cls, echo, create_on_absent=True): top_level_reached = True break # We are no longer making upward progress current_path = new_path - check_dir = os.path.join(current_path, DATASTORE_LOCAL_DIR) + check_dir = os.path.join(current_path, datastore_dir) if top_level_reached: if create_on_absent: # Could not find any directory to use so create a new one + dir_type = "spin datastore" if use_spin_dir else "local datastore" echo( - "Creating local datastore in current directory (%s)" % orig_path + "Creating %s in current directory (%s)" % (dir_type, orig_path) ) os.mkdir(orig_path) result = orig_path @@ -42,7 +57,7 @@ def get_datastore_root_from_config(cls, echo, create_on_absent=True): else: result = check_dir else: - result = os.path.join(result, DATASTORE_LOCAL_DIR) + result = os.path.join(result, datastore_dir) return result @staticmethod diff --git a/metaflow/runtime.py b/metaflow/runtime.py index a668dc1d6dd..e6e659b2559 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -226,7 +226,8 @@ def _format_input_paths(task_pathspec, attempt): @property def whitelist_decorators(self): if self._skip_decorators: - return [] + self._whitelist_decorators = [] + return self._whitelist_decorators if self._whitelist_decorators: return self._whitelist_decorators self._whitelist_decorators = [ @@ -2226,6 +2227,7 @@ def _launch(self): # by read_logline() below that relies on readline() not blocking # print('running', args) cmdline = args.get_args() + from_start(f"Command line: {' '.join(cmdline)}") debug.subcommand_exec(cmdline) return subprocess.Popen( cmdline, diff --git a/test/unit/spin/flows/simple_parameter_flow.py b/test/unit/spin/flows/simple_parameter_flow.py index b2d3410d2ad..0a36ff007ce 100644 --- a/test/unit/spin/flows/simple_parameter_flow.py +++ b/test/unit/spin/flows/simple_parameter_flow.py @@ -1,9 +1,10 @@ -from metaflow import FlowSpec, step, Parameter +from metaflow import FlowSpec, step, Parameter, titus class 
SimpleParameterFlow(FlowSpec): alpha = Parameter("alpha", help="Learning rate", default=0.01) + @titus @step def start(self): print("SimpleParameterFlow is starting.") From 56cc2e95d710afbf6d3e8f4612e395ce44980985 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Sun, 24 Aug 2025 13:55:56 -0700 Subject: [PATCH 06/21] Create new spin datastore, metadata and support spin with local metadata --- .gitignore | 2 + metaflow/cli.py | 40 ++-- metaflow/cli_components/run_cmds.py | 34 ++-- metaflow/cli_components/step_cmd.py | 16 +- metaflow/client/core.py | 1 + metaflow/decorators.py | 79 +++++++- metaflow/metaflow_config.py | 15 +- metaflow/plugins/__init__.py | 2 + metaflow/plugins/cards/card_datastore.py | 12 +- metaflow/plugins/datastores/local_storage.py | 27 +-- metaflow/plugins/datastores/spin_storage.py | 12 ++ metaflow/plugins/metadata_providers/local.py | 158 +++++++-------- metaflow/plugins/metadata_providers/spin.py | 16 ++ metaflow/runner/metaflow_runner.py | 48 +++-- metaflow/runtime.py | 91 +++++++-- metaflow/task.py | 10 +- metaflow/util.py | 148 +++++++++++++- test/unit/spin/conftest.py | 83 ++++++-- test/unit/spin/flows/merge_artifacts_flow.py | 1 - test/unit/spin/flows/myconfig.json | 1 + test/unit/spin/flows/simple_card_flow.py | 27 +++ test/unit/spin/flows/simple_config_flow.py | 22 ++ test/unit/spin/flows/simple_parameter_flow.py | 9 +- test/unit/spin/spin_test_helpers.py | 32 +++ test/unit/spin/test_spin.py | 188 ++++++++---------- 25 files changed, 754 insertions(+), 320 deletions(-) create mode 100644 metaflow/plugins/datastores/spin_storage.py create mode 100644 metaflow/plugins/metadata_providers/spin.py create mode 100644 test/unit/spin/flows/myconfig.json create mode 100644 test/unit/spin/flows/simple_card_flow.py create mode 100644 test/unit/spin/flows/simple_config_flow.py create mode 100644 test/unit/spin/spin_test_helpers.py diff --git a/.gitignore b/.gitignore index b36301c7c44..3557ef49ad6 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ __pycache__/ *.py[cod] *$py.class *.metaflow +*.spin_metaflow +metaflow_card_cache/ build/ dist/ diff --git a/metaflow/cli.py b/metaflow/cli.py index 37f1823b7ee..fc9e853ad4f 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -477,19 +477,12 @@ def start( # set force rebuild flag for environments that support it. 
ctx.obj.environment._force_rebuild = force_rebuild_environments ctx.obj.environment.validate_environment(ctx.obj.logger, datastore) - ctx.obj.event_logger = LOGGING_SIDECARS[event_logger]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.event_logger.start() - _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger) - ctx.obj.monitor = MONITOR_SIDECARS[monitor]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.monitor.start() - _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) - ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == metadata][0]( ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor ) @@ -504,29 +497,33 @@ def start( ctx.obj.config_options = config_options ctx.obj.is_spin = False + ctx.obj.skip_decorators = False # Override values for spin if hasattr(ctx, "saved_args") and ctx.saved_args and "spin" in ctx.saved_args[0]: - # To minimize side-effects for spin, we will only use the following: + # To minimize side effects for spin, we will only use the following: # - local metadata provider, # - local datastore, # - local environment, # - null event logger, # - null monitor ctx.obj.is_spin = True + if "--skip-decorators" in ctx.saved_args: + ctx.obj.skip_decorators = True + ctx.obj.event_logger = LOGGING_SIDECARS["nullSidecarLogger"]( flow=ctx.obj.flow, env=ctx.obj.environment ) ctx.obj.monitor = MONITOR_SIDECARS["nullSidecarMonitor"]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "local"][0]( + # Use spin metadata, spin datastore, and spin datastore root + ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "spin"][0]( ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor ) - ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] - # Set a separate datastore root for spin + ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == "spin"][0] datastore_root = ctx.obj.datastore_impl.get_datastore_root_from_config( - ctx.obj.echo, create_on_absent=True, use_spin_dir=True + ctx.obj.echo, create_on_absent=True ) ctx.obj.datastore_impl.datastore_root = datastore_root @@ -550,7 +547,6 @@ def start( # It is important to initialize flow decorators early as some of the # things they provide may be used by some of the objects initialized after. - from_start(f"I am just above _init_flow_decorators") decorators._init_flow_decorators( ctx.obj.flow, ctx.obj.graph, @@ -560,6 +556,8 @@ def start( ctx.obj.logger, echo, deco_options, + ctx.obj.is_spin, + ctx.obj.skip_decorators, ) # In the case of run/resume/spin, we will want to apply the TL decospecs @@ -592,18 +590,8 @@ def start( and ctx.saved_args and ctx.saved_args[0] not in ("run", "resume", "spin") ): - # run/resume are special cases because they can add more decorators with --with, + # run/resume/spin are special cases because they can add more decorators with --with, # so they have to take care of themselves. - whitelist_decorators = None - if "--whitelist-decorators" in ctx.saved_args: - # If whitelist-decorators is specified, we only will run the decorators hooks - # for the decorators that are whitelisted. 
- idx = ctx.saved_args.index("--whitelist-decorators") - whitelist_decorators = ctx.saved_args[idx + 1] - whitelist_decorators = ( - decompress_list(whitelist_decorators) if whitelist_decorators else [] - ) - all_decospecs = ctx.obj.tl_decospecs + list( ctx.obj.environment.decospecs() or [] ) @@ -629,7 +617,9 @@ def start( ctx.obj.environment, ctx.obj.flow_datastore, ctx.obj.logger, - whitelist_decorators, + # The last two arguments are only used for spin steps + ctx.obj.is_spin, + ctx.obj.skip_decorators, ) # Check the graph again (mutators may have changed it) diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index 1a0b60d6cc9..c957acc22a4 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -17,7 +17,7 @@ # from ..client.core import Run from ..tagging_util import validate_tags -from ..util import get_latest_run_id, write_latest_run_id +from ..util import get_latest_run_id, write_latest_run_id, parse_spin_pathspec def before_run(obj, tags, decospecs, skip_decorators=False): @@ -37,6 +37,9 @@ def before_run(obj, tags, decospecs, skip_decorators=False): # - run level decospecs # - top level decospecs # - environment decospecs + from_start( + f"Inside before_run, skip_decorators={skip_decorators}, is_spin={obj.is_spin}" + ) if not skip_decorators: all_decospecs = ( list(decospecs or []) @@ -56,7 +59,13 @@ def before_run(obj, tags, decospecs, skip_decorators=False): # obj.environment.init_environment(obj.logger) decorators._init_step_decorators( - obj.flow, obj.graph, obj.environment, obj.flow_datastore, obj.logger + obj.flow, + obj.graph, + obj.environment, + obj.flow_datastore, + obj.logger, + obj.is_spin, + skip_decorators, ) # Re-read graph since it may have been modified by mutators obj.graph = obj.flow._graph @@ -418,19 +427,14 @@ def run( @parameters.add_custom_parameters(deploy_mode=True) @click.command(help="Spins up a task for a given step from a previous run locally.") -@click.argument("step-name") -@click.option( - "--spin-pathspec", - default=None, - type=str, - help="Use specified task pathspec from a previous run to spin up the step.", -) +@tracing.cli("cli/spin") +@click.argument("pathspec") @click.option( "--skip-decorators/--no-skip-decorators", is_flag=True, default=False, show_default=True, - help="Skip decorators attached to the step.", + help="Skip decorators attached to the step or flow.", ) @click.option( "--artifacts-module", @@ -462,8 +466,7 @@ def run( @click.pass_obj def spin( obj, - step_name, - spin_pathspec=None, + pathspec, persist=True, artifacts_module=None, skip_decorators=False, @@ -472,6 +475,9 @@ def spin( runner_attribute_file=None, **kwargs, ): + # Parse the pathspec argument to extract step name and full pathspec + step_name, parsed_pathspec = parse_spin_pathspec(pathspec, obj.flow.name) + before_run(obj, [], [], skip_decorators) obj.echo(f"Spinning up step *{step_name}* locally for flow *{obj.flow.name}*") obj.flow._set_constants(obj.graph, kwargs, obj.config_options) @@ -495,7 +501,7 @@ def spin( obj.monitor, step_func, step_name, - spin_pathspec, + parsed_pathspec, skip_decorators, artifacts_module, persist, @@ -503,7 +509,6 @@ def spin( ) write_latest_run_id(obj, spin_runtime.run_id) write_file(run_id_file, spin_runtime.run_id) - # datastore_root is os.path.join(DATASTORE_SYSROOT_SPIN, DATASTORE_LOCAL_DIR) # We only need the root for the metadata, i.e. 
the portion before DATASTORE_LOCAL_DIR datastore_root = spin_runtime._flow_datastore._storage_impl.datastore_root orig_task_metadata_root = datastore_root.rsplit("/", 1)[0] @@ -521,6 +526,7 @@ def spin( "flow_name": obj.flow.name, # Store metadata in a format that can be used by the Runner API "metadata": f"{obj.metadata.__class__.TYPE}@{orig_task_metadata_root}", + # "metadata": f"spin@{orig_task_metadata_root}", }, f, ) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index ca244ff6bfa..0a36792dca0 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -6,6 +6,7 @@ from ..datastore.flow_datastore import FlowDataStore from ..exception import CommandException from ..client.filecache import FileCache, FileBlobCache, TaskMetadataCache +from ..metaflow_config import SPIN_ALLOWED_DECORATORS from ..metaflow_profile import from_start from ..plugins import DATASTORES from ..task import MetaflowTask @@ -235,8 +236,11 @@ def step( help="Change namespace from the default (your username) to the specified tag.", ) @click.option( - "--whitelist-decorators", - help="A comma-separated list of whitelisted decorators to use for the spin step", + "--skip-decorators/--no-skip-decorators", + is_flag=True, + default=False, + show_default=True, + help="Skip decorators attached to the step or flow.", ) @click.option( "--persist/--no-persist", @@ -268,11 +272,10 @@ def spin_step( retry_count=None, max_user_code_retries=None, opt_namespace=None, - whitelist_decorators=None, + skip_decorators=False, artifacts_module=None, persist=True, ): - from_start("I am in spin step") import time if ctx.obj.is_quiet: @@ -285,9 +288,8 @@ def spin_step( input_paths = decompress_list(input_paths) if input_paths else [] - whitelist_decorators = ( - decompress_list(whitelist_decorators) if whitelist_decorators else [] - ) + skip_decorators = skip_decorators + whitelist_decorators = [] if skip_decorators else SPIN_ALLOWED_DECORATORS from_start("SpinStep: initialized decorators") spin_artifacts = read_artifacts_module(artifacts_module) if artifacts_module else {} from_start("SpinStep: read artifacts module") diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 6a8f6f3af19..58aecb12afa 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -277,6 +277,7 @@ def __init__( self._attempt = attempt self._current_namespace = _current_namespace or get_namespace() self._namespace_check = _namespace_check + # If the current namespace is False, we disable checking for namespace for this # and all children objects. Not setting namespace_check to False has the consequence # of preventing access to children objects after the namespace changes diff --git a/metaflow/decorators.py b/metaflow/decorators.py index 18160642156..267d875f684 100644 --- a/metaflow/decorators.py +++ b/metaflow/decorators.py @@ -27,7 +27,7 @@ UserStepDecoratorBase, UserStepDecoratorMeta, ) - +from .metaflow_config import SPIN_ALLOWED_DECORATORS from metaflow._vendor import click @@ -658,6 +658,53 @@ def _attach_decorators_to_step(step, decospecs): step_deco.add_or_raise(step, False, 1, None) +def _should_skip_decorator_for_spin( + deco, is_spin, skip_decorators, logger, decorator_type="decorator" +): + """ + Determine if a decorator should be skipped for spin steps. 
+ + Parameters: + ----------- + deco : Decorator + The decorator instance to check + is_spin : bool + Whether this is a spin step + skip_decorators : bool + Whether to skip all decorators + logger : callable + Logger function for warnings + decorator_type : str + Type of decorator ("Flow decorator" or "Step decorator") for logging + + Returns: + -------- + bool + True if the decorator should be skipped, False otherwise + """ + if not is_spin: + return False + + # Skip all decorator hooks if skip_decorators is True + if skip_decorators: + return True + + # Run decorator hooks for spin steps only if they are in the whitelist + if deco.name not in SPIN_ALLOWED_DECORATORS: + logger( + f"[Warning] {decorator_type} '{deco.name}' is not supported in spin steps. " + f"Supported decorators are: [{', '.join(SPIN_ALLOWED_DECORATORS)}]. " + f"Skipping this decorator as it is not in the whitelist.\n" + f"Alternatively, you can use the --skip-decorators flag to skip running all decorators in spin steps.", + system_msg=True, + timestamp=False, + bad=True, + ) + return True + + return False + + def _init(flow, only_non_static=False): for decorators in flow._flow_decorators.values(): for deco in decorators: @@ -673,7 +720,16 @@ def _init(flow, only_non_static=False): def _init_flow_decorators( - flow, graph, environment, flow_datastore, metadata, logger, echo, deco_options + flow, + graph, + environment, + flow_datastore, + metadata, + logger, + echo, + deco_options, + is_spin=False, + skip_decorators=False, ): # Since all flow decorators are stored as `{key:[deco]}` we iterate through each of them. for decorators in flow._flow_decorators.values(): @@ -702,6 +758,10 @@ def _init_flow_decorators( for option, option_info in deco.options.items() } for deco in decorators: + if _should_skip_decorator_for_spin( + deco, is_spin, skip_decorators, logger, "Flow decorator" + ): + continue deco.flow_init( flow, graph, @@ -715,9 +775,15 @@ def _init_flow_decorators( def _init_step_decorators( - flow, graph, environment, flow_datastore, logger, whitelist_decorators=None + flow, + graph, + environment, + flow_datastore, + logger, + is_spin=False, + skip_decorators=False, ): - # NOTE: We don't need graph but keeping it for backwards compatibility with + # NOTE: We don't need the graph but keeping it for backwards compatibility with # extensions that use it directly. We will remove it at some point. # We call the mutate method for both the flow and step mutators. 
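Illustrative sketch (not applied by this patch) of how the new `_should_skip_decorator_for_spin` helper gates hooks, assuming the default `SPIN_ALLOWED_DECORATORS` above; `SimpleNamespace` objects stand in for real decorator instances, and the logger only needs to tolerate the warning kwargs:

    from types import SimpleNamespace

    log = lambda msg, **kwargs: print(msg)          # accepts system_msg/timestamp/bad kwargs
    conda_deco = SimpleNamespace(name="conda")      # whitelisted for spin
    retry_deco = SimpleNamespace(name="retry")      # not whitelisted

    _should_skip_decorator_for_spin(conda_deco, False, False, log)  # False: non-spin runs are untouched
    _should_skip_decorator_for_spin(conda_deco, True,  False, log)  # False: 'conda' is allowed for spin
    _should_skip_decorator_for_spin(retry_deco, True,  False, log)  # True:  warns, then skips the hook
    _should_skip_decorator_for_spin(retry_deco, True,  True,  log)  # True:  --skip-decorators skips everything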
@@ -787,8 +853,9 @@ def _init_step_decorators( for step in flow: for deco in step.decorators: - # We skip decorators that are not in the whitelist - if not whitelist_decorators and deco.name not in whitelist_decorators: + if _should_skip_decorator_for_spin( + deco, is_spin, skip_decorators, logger, "Step decorator" + ): continue deco.step_init( flow, diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index a0b831d9877..0dc4526833c 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -52,7 +52,18 @@ # Spin configuration ### SPIN_ALLOWED_DECORATORS = from_conf( - "SPIN_ALLOWED_DECORATORS", ["conda", "pypi", "environment"] + "SPIN_ALLOWED_DECORATORS", + [ + "conda", + "pypi", + "conda_base", + "pypi_base", + "environment", + "project", + "timeout", + "conda_env_internal", + "card", + ], ) ### @@ -65,7 +76,7 @@ # Datastore configuration ### DATASTORE_SYSROOT_LOCAL = from_conf("DATASTORE_SYSROOT_LOCAL") -DATASTORE_SYSROOT_SPIN = from_conf("DATASTORE_SYSROOT_SPIN", "/tmp") +DATASTORE_SYSROOT_SPIN = from_conf("DATASTORE_SYSROOT_SPIN") # S3 bucket and prefix to store artifacts for 's3' datastore. DATASTORE_SYSROOT_S3 = from_conf("DATASTORE_SYSROOT_S3") # Azure Blob Storage container and blob prefix diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index c70e532c093..5758ed7406a 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -83,11 +83,13 @@ METADATA_PROVIDERS_DESC = [ ("service", ".metadata_providers.service.ServiceMetadataProvider"), ("local", ".metadata_providers.local.LocalMetadataProvider"), + ("spin", ".metadata_providers.spin.SpinMetadataProvider"), ] # Add datastore here DATASTORES_DESC = [ ("local", ".datastores.local_storage.LocalStorage"), + ("spin", ".datastores.spin_storage.SpinStorage"), ("s3", ".datastores.s3_storage.S3Storage"), ("azure", ".datastores.azure_storage.AzureStorage"), ("gs", ".datastores.gs_storage.GSStorage"), diff --git a/metaflow/plugins/cards/card_datastore.py b/metaflow/plugins/cards/card_datastore.py index 18ce52ca463..431166f5278 100644 --- a/metaflow/plugins/cards/card_datastore.py +++ b/metaflow/plugins/cards/card_datastore.py @@ -9,6 +9,7 @@ CARD_S3ROOT, CARD_LOCALROOT, DATASTORE_LOCAL_DIR, + DATASTORE_SPIN_LOCAL_DIR, CARD_SUFFIX, CARD_AZUREROOT, CARD_GSROOT, @@ -58,12 +59,17 @@ def get_storage_root(cls, storage_type): return CARD_AZUREROOT elif storage_type == "gs": return CARD_GSROOT - elif storage_type == "local": + elif storage_type == "local" or storage_type == "spin": # Borrowing some of the logic from LocalStorage.get_storage_root result = CARD_LOCALROOT + local_dir = ( + DATASTORE_SPIN_LOCAL_DIR + if storage_type == "spin" + else DATASTORE_LOCAL_DIR + ) if result is None: current_path = os.getcwd() - check_dir = os.path.join(current_path, DATASTORE_LOCAL_DIR) + check_dir = os.path.join(current_path, local_dir) check_dir = os.path.realpath(check_dir) orig_path = check_dir while not os.path.isdir(check_dir): @@ -73,7 +79,7 @@ def get_storage_root(cls, storage_type): # return the top level path return os.path.join(orig_path, CARD_SUFFIX) current_path = new_path - check_dir = os.path.join(current_path, DATASTORE_LOCAL_DIR) + check_dir = os.path.join(current_path, local_dir) return os.path.join(check_dir, CARD_SUFFIX) else: # Let's make it obvious we need to update this block for each new datastore backend... 
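Illustrative trace (not part of the diff) of the new "spin" branch in `get_storage_root` above; directory names are hypothetical and `CARD_SUFFIX` is whatever the config defines:

    # cwd = /home/user/project/nb
    #   /home/user/project/nb/.metaflow_spin   -> not a directory, walk up
    #   /home/user/project/.metaflow_spin      -> found
    # get_storage_root("spin")  == "/home/user/project/.metaflow_spin/" + CARD_SUFFIX
    # get_storage_root("local") == same walk, but looking for ".metaflow" exactly as before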
diff --git a/metaflow/plugins/datastores/local_storage.py b/metaflow/plugins/datastores/local_storage.py index 8aa0e1741e5..bb4791df8d3 100644 --- a/metaflow/plugins/datastores/local_storage.py +++ b/metaflow/plugins/datastores/local_storage.py @@ -4,8 +4,6 @@ from metaflow.metaflow_config import ( DATASTORE_LOCAL_DIR, DATASTORE_SYSROOT_LOCAL, - DATASTORE_SPIN_LOCAL_DIR, - DATASTORE_SYSROOT_SPIN, ) from metaflow.datastore.datastore_storage import CloseAfterUse, DataStoreStorage @@ -13,26 +11,19 @@ class LocalStorage(DataStoreStorage): TYPE = "local" METADATA_DIR = "_meta" + DATASTORE_DIR = DATASTORE_LOCAL_DIR # ".metaflow" + SYSROOT_VAR = DATASTORE_SYSROOT_LOCAL @classmethod - def get_datastore_root_from_config( - cls, echo, create_on_absent=True, use_spin_dir=False - ): - if use_spin_dir: - datastore_dir = DATASTORE_SPIN_LOCAL_DIR - sysroot_var = DATASTORE_SYSROOT_LOCAL - else: - datastore_dir = DATASTORE_LOCAL_DIR # ".metaflow" - sysroot_var = DATASTORE_SYSROOT_LOCAL - - result = sysroot_var + def get_datastore_root_from_config(cls, echo, create_on_absent=True): + result = cls.SYSROOT_VAR if result is None: try: # Python2 current_path = os.getcwdu() except: # noqa E722 current_path = os.getcwd() - check_dir = os.path.join(current_path, datastore_dir) + check_dir = os.path.join(current_path, cls.DATASTORE_DIR) check_dir = os.path.realpath(check_dir) orig_path = check_dir top_level_reached = False @@ -42,13 +33,13 @@ def get_datastore_root_from_config( top_level_reached = True break # We are no longer making upward progress current_path = new_path - check_dir = os.path.join(current_path, datastore_dir) + check_dir = os.path.join(current_path, cls.DATASTORE_DIR) if top_level_reached: if create_on_absent: # Could not find any directory to use so create a new one - dir_type = "spin datastore" if use_spin_dir else "local datastore" echo( - "Creating %s in current directory (%s)" % (dir_type, orig_path) + "Creating %s datastore in current directory (%s)" + % (cls.TYPE, orig_path) ) os.mkdir(orig_path) result = orig_path @@ -57,7 +48,7 @@ def get_datastore_root_from_config( else: result = check_dir else: - result = os.path.join(result, datastore_dir) + result = os.path.join(result, cls.DATASTORE_DIR) return result @staticmethod diff --git a/metaflow/plugins/datastores/spin_storage.py b/metaflow/plugins/datastores/spin_storage.py new file mode 100644 index 00000000000..29fd5bdcc30 --- /dev/null +++ b/metaflow/plugins/datastores/spin_storage.py @@ -0,0 +1,12 @@ +from metaflow.metaflow_config import ( + DATASTORE_SPIN_LOCAL_DIR, + DATASTORE_SYSROOT_SPIN, +) +from metaflow.plugins.datastores.local_storage import LocalStorage + + +class SpinStorage(LocalStorage): + TYPE = "spin" + METADATA_DIR = "_meta" + DATASTORE_DIR = DATASTORE_SPIN_LOCAL_DIR # ".spin_metaflow" + SYSROOT_VAR = DATASTORE_SYSROOT_SPIN diff --git a/metaflow/plugins/metadata_providers/local.py b/metaflow/plugins/metadata_providers/local.py index 74de40a61e8..424812c810f 100644 --- a/metaflow/plugins/metadata_providers/local.py +++ b/metaflow/plugins/metadata_providers/local.py @@ -18,6 +18,14 @@ class LocalMetadataProvider(MetadataProvider): TYPE = "local" + DATASTORE_DIR = DATASTORE_LOCAL_DIR # ".metaflow" + + @classmethod + def _get_storage_class(cls): + # This method is meant to be overridden + from metaflow.plugins.datastores.local_storage import LocalStorage + + return LocalStorage def __init__(self, environment, flow, event_logger, monitor): super(LocalMetadataProvider, self).__init__( @@ -26,30 +34,28 @@ def __init__(self, 
environment, flow, event_logger, monitor): @classmethod def compute_info(cls, val): - from metaflow.plugins.datastores.local_storage import LocalStorage + storage_class = cls._get_storage_class() - v = os.path.realpath(os.path.join(val, DATASTORE_LOCAL_DIR)) + v = os.path.realpath(os.path.join(val, cls.DATASTORE_DIR)) if os.path.isdir(v): - LocalStorage.datastore_root = v + storage_class.datastore_root = v return val raise ValueError( - "Could not find directory %s in directory %s" % (DATASTORE_LOCAL_DIR, val) + "Could not find directory %s in directory %s" % (cls.DATASTORE_DIR, val) ) @classmethod def default_info(cls): - from metaflow.plugins.datastores.local_storage import LocalStorage + storage_class = cls._get_storage_class() def print_clean(line, **kwargs): print(line) - v = LocalStorage.get_datastore_root_from_config( + v = storage_class.get_datastore_root_from_config( print_clean, create_on_absent=False ) if v is None: - return ( - "" % DATASTORE_LOCAL_DIR - ) + return "" % cls.DATASTORE_DIR return os.path.dirname(v) def version(self): @@ -102,7 +108,7 @@ def register_task_id( def register_data_artifacts( self, run_id, step_name, task_id, attempt_id, artifacts ): - meta_dir = self._create_and_get_metadir( + meta_dir = self.__class__._create_and_get_metadir( self._flow_name, run_id, step_name, task_id ) artlist = self._artifacts_to_json( @@ -112,7 +118,7 @@ def register_data_artifacts( self._save_meta(meta_dir, artdict) def register_metadata(self, run_id, step_name, task_id, metadata): - meta_dir = self._create_and_get_metadir( + meta_dir = self.__class__._create_and_get_metadir( self._flow_name, run_id, step_name, task_id ) metalist = self._metadata_to_json(run_id, step_name, task_id, metadata) @@ -132,9 +138,7 @@ def _mutate_user_tags_for_run( def _optimistically_mutate(): # get existing tags - run = LocalMetadataProvider.get_object( - "run", "self", {}, None, flow_id, run_id - ) + run = cls.get_object("run", "self", {}, None, flow_id, run_id) if not run: raise MetaflowTaggingError( msg="Run not found (%s, %s)" % (flow_id, run_id) @@ -167,15 +171,13 @@ def _optimistically_mutate(): validate_tags(next_user_tags_set, existing_tags=existing_user_tag_set) # write new tag set to file system - LocalMetadataProvider._persist_tags_for_run( + cls._persist_tags_for_run( flow_id, run_id, next_user_tags_set, existing_system_tag_set ) # read tags back from file system to see if our optimism is misplaced # I.e. did a concurrent mutate overwrite our change - run = LocalMetadataProvider.get_object( - "run", "self", {}, None, flow_id, run_id - ) + run = cls.get_object("run", "self", {}, None, flow_id, run_id) if not run: raise MetaflowTaggingError( msg="Run not found for read-back check (%s, %s)" % (flow_id, run_id) @@ -279,8 +281,6 @@ def _get_object_internal( if obj_type not in ("root", "flow", "run", "step", "task", "artifact"): raise MetaflowInternalError(msg="Unexpected object type %s" % obj_type) - from metaflow.plugins.datastores.local_storage import LocalStorage - if obj_type == "artifact": # Artifacts are actually part of the tasks in the filesystem # E.g. 
we get here for (obj_type, sub_type) == (artifact, self) @@ -307,13 +307,13 @@ def _get_object_internal( # Special handling of self, artifact, and metadata if sub_type == "self": - meta_path = LocalMetadataProvider._get_metadir(*args[:obj_order]) + meta_path = cls._get_metadir(*args[:obj_order]) if meta_path is None: return None self_file = os.path.join(meta_path, "_self.json") if os.path.isfile(self_file): obj = MetadataProvider._apply_filter( - [LocalMetadataProvider._read_json_file(self_file)], filters + [cls._read_json_file(self_file)], filters )[0] # For non-descendants of a run, we are done @@ -324,7 +324,7 @@ def _get_object_internal( raise MetaflowInternalError( msg="Unexpected object type %s" % obj_type ) - run = LocalMetadataProvider.get_object( + run = cls.get_object( "run", "self", {}, None, *args[:RUN_ORDER] # *[flow_id, run_id] ) if not run: @@ -341,7 +341,7 @@ def _get_object_internal( if obj_type not in ("root", "flow", "run", "step", "task"): raise MetaflowInternalError(msg="Unexpected object type %s" % obj_type) - meta_path = LocalMetadataProvider._get_metadir(*args[:obj_order]) + meta_path = cls._get_metadir(*args[:obj_order]) result = [] if meta_path is None: return result @@ -352,9 +352,7 @@ def _get_object_internal( attempts_done = sorted(glob.iglob(attempt_done_files)) if attempts_done: successful_attempt = int( - LocalMetadataProvider._read_json_file(attempts_done[-1])[ - "value" - ] + cls._read_json_file(attempts_done[-1])["value"] ) if successful_attempt is not None: which_artifact = "*" @@ -365,10 +363,10 @@ def _get_object_internal( "%d_artifact_%s.json" % (successful_attempt, which_artifact), ) for obj in glob.iglob(artifact_files): - result.append(LocalMetadataProvider._read_json_file(obj)) + result.append(cls._read_json_file(obj)) # We are getting artifacts. 
We should overlay with ancestral run's tags - run = LocalMetadataProvider.get_object( + run = cls.get_object( "run", "self", {}, None, *args[:RUN_ORDER] # *[flow_id, run_id] ) if not run: @@ -388,12 +386,12 @@ def _get_object_internal( if obj_type not in ("root", "flow", "run", "step", "task"): raise MetaflowInternalError(msg="Unexpected object type %s" % obj_type) result = [] - meta_path = LocalMetadataProvider._get_metadir(*args[:obj_order]) + meta_path = cls._get_metadir(*args[:obj_order]) if meta_path is None: return result files = os.path.join(meta_path, "sysmeta_*") for obj in glob.iglob(files): - result.append(LocalMetadataProvider._read_json_file(obj)) + result.append(cls._read_json_file(obj)) return result # For the other types, we locate all the objects we need to find and return them @@ -401,14 +399,13 @@ def _get_object_internal( raise MetaflowInternalError(msg="Unexpected object type %s" % obj_type) if sub_type not in ("flow", "run", "step", "task"): raise MetaflowInternalError(msg="unexpected sub type %s" % sub_type) - obj_path = LocalMetadataProvider._make_path( - *args[:obj_order], create_on_absent=False - ) + obj_path = cls._make_path(*args[:obj_order], create_on_absent=False) result = [] if obj_path is None: return result skip_dirs = "*/" * (sub_order - obj_order) - all_meta = os.path.join(obj_path, skip_dirs, LocalStorage.METADATA_DIR) + storage_class = cls._get_storage_class() + all_meta = os.path.join(obj_path, skip_dirs, storage_class.METADATA_DIR) SelfInfo = collections.namedtuple("SelfInfo", ["filepath", "run_id"]) self_infos = [] for meta_path in glob.iglob(all_meta): @@ -418,9 +415,7 @@ def _get_object_internal( run_id = None # flow and run do not need info from ancestral run if sub_type in ("step", "task"): - run_id = LocalMetadataProvider._deduce_run_id_from_meta_dir( - meta_path, sub_type - ) + run_id = cls._deduce_run_id_from_meta_dir(meta_path, sub_type) # obj_type IS run, or more granular than run, let's do sanity check vs args if obj_order >= RUN_ORDER: if run_id != args[RUN_ORDER - 1]: @@ -430,10 +425,10 @@ def _get_object_internal( self_infos.append(SelfInfo(filepath=self_file, run_id=run_id)) for self_info in self_infos: - obj = LocalMetadataProvider._read_json_file(self_info.filepath) + obj = cls._read_json_file(self_info.filepath) if self_info.run_id: flow_id_from_args = args[0] - run = LocalMetadataProvider.get_object( + run = cls.get_object( "run", "self", {}, @@ -452,8 +447,8 @@ def _get_object_internal( return MetadataProvider._apply_filter(result, filters) - @staticmethod - def _deduce_run_id_from_meta_dir(meta_dir_path, sub_type): + @classmethod + def _deduce_run_id_from_meta_dir(cls, meta_dir_path, sub_type): curr_order = ObjectOrder.type_to_order(sub_type) levels_to_ascend = curr_order - ObjectOrder.type_to_order("run") if levels_to_ascend < 0: @@ -468,8 +463,8 @@ def _deduce_run_id_from_meta_dir(meta_dir_path, sub_type): ) return run_id - @staticmethod - def _makedirs(path): + @classmethod + def _makedirs(cls, path): # this is for python2 compatibility. # Python3 has os.makedirs(exist_ok=True). 
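Illustrative sketch (not part of the diff) of what the staticmethod-to-classmethod conversion buys: the `SpinMetadataProvider` subclass added later in this patch reuses every helper unchanged, with only the storage class and datastore directory differing. Paths are hypothetical:

    # LocalMetadataProvider._make_path("MyFlow", "3", "start", "1")
    #   -> <local datastore root (.metaflow)>/MyFlow/3/start/1
    # SpinMetadataProvider._make_path("MyFlow", "3", "start", "1")
    #   -> <spin datastore root (.metaflow_spin)>/MyFlow/3/start/1
    # _create_and_get_metadir(...) then appends the storage class's METADATA_DIR ("_meta")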
try: @@ -481,17 +476,15 @@ def _makedirs(path): else: raise - @staticmethod - def _persist_tags_for_run(flow_id, run_id, tags, system_tags): - subpath = LocalMetadataProvider._create_and_get_metadir( - flow_name=flow_id, run_id=run_id - ) + @classmethod + def _persist_tags_for_run(cls, flow_id, run_id, tags, system_tags): + subpath = cls._create_and_get_metadir(flow_name=flow_id, run_id=run_id) selfname = os.path.join(subpath, "_self.json") if not os.path.isfile(selfname): raise MetaflowInternalError( msg="Could not verify Run existence on disk - missing %s" % selfname ) - LocalMetadataProvider._save_meta( + cls._save_meta( subpath, { "_self": MetadataProvider._run_to_json_static( @@ -508,11 +501,11 @@ def _ensure_meta( tags = set() if sys_tags is None: sys_tags = set() - subpath = self._create_and_get_metadir( + subpath = self.__class__._create_and_get_metadir( self._flow_name, run_id, step_name, task_id ) selfname = os.path.join(subpath, "_self.json") - self._makedirs(subpath) + self.__class__._makedirs(subpath) if os.path.isfile(selfname): # There is a race here, but we are not aiming to make this as solid as # the metadata service. This is used primarily for concurrent resumes, @@ -549,26 +542,31 @@ def _new_task( self._register_system_metadata(run_id, step_name, task_id, attempt) return to_return - @staticmethod + @classmethod def _make_path( - flow_name=None, run_id=None, step_name=None, task_id=None, create_on_absent=True + cls, + flow_name=None, + run_id=None, + step_name=None, + task_id=None, + create_on_absent=True, ): - from metaflow.plugins.datastores.local_storage import LocalStorage + storage_class = cls._get_storage_class() - if LocalStorage.datastore_root is None: + if storage_class.datastore_root is None: def print_clean(line, **kwargs): print(line) - LocalStorage.datastore_root = LocalStorage.get_datastore_root_from_config( + storage_class.datastore_root = storage_class.get_datastore_root_from_config( print_clean, create_on_absent=create_on_absent ) - if LocalStorage.datastore_root is None: + if storage_class.datastore_root is None: return None if flow_name is None: - return LocalStorage.datastore_root + return storage_class.datastore_root components = [] if flow_name: components.append(flow_name) @@ -578,37 +576,35 @@ def print_clean(line, **kwargs): components.append(step_name) if task_id: components.append(task_id) - return LocalStorage().full_uri(LocalStorage.path_join(*components)) + return storage_class().full_uri(storage_class.path_join(*components)) - @staticmethod + @classmethod def _create_and_get_metadir( - flow_name=None, run_id=None, step_name=None, task_id=None + cls, flow_name=None, run_id=None, step_name=None, task_id=None ): - from metaflow.plugins.datastores.local_storage import LocalStorage + storage_class = cls._get_storage_class() - root_path = LocalMetadataProvider._make_path( - flow_name, run_id, step_name, task_id - ) - subpath = os.path.join(root_path, LocalStorage.METADATA_DIR) - LocalMetadataProvider._makedirs(subpath) + root_path = cls._make_path(flow_name, run_id, step_name, task_id) + subpath = os.path.join(root_path, storage_class.METADATA_DIR) + cls._makedirs(subpath) return subpath - @staticmethod - def _get_metadir(flow_name=None, run_id=None, step_name=None, task_id=None): - from metaflow.plugins.datastores.local_storage import LocalStorage + @classmethod + def _get_metadir(cls, flow_name=None, run_id=None, step_name=None, task_id=None): + storage_class = cls._get_storage_class() - root_path = LocalMetadataProvider._make_path( + 
root_path = cls._make_path( flow_name, run_id, step_name, task_id, create_on_absent=False ) if root_path is None: return None - subpath = os.path.join(root_path, LocalStorage.METADATA_DIR) + subpath = os.path.join(root_path, storage_class.METADATA_DIR) if os.path.isdir(subpath): return subpath return None - @staticmethod - def _dump_json_to_file(filepath, data, allow_overwrite=False): + @classmethod + def _dump_json_to_file(cls, filepath, data, allow_overwrite=False): if os.path.isfile(filepath) and not allow_overwrite: return try: @@ -622,15 +618,13 @@ def _dump_json_to_file(filepath, data, allow_overwrite=False): if f and os.path.isfile(f.name): os.remove(f.name) - @staticmethod - def _read_json_file(filepath): + @classmethod + def _read_json_file(cls, filepath): with open(filepath, "r") as f: return json.load(f) - @staticmethod - def _save_meta(root_dir, metadict, allow_overwrite=False): + @classmethod + def _save_meta(cls, root_dir, metadict, allow_overwrite=False): for name, datum in metadict.items(): filename = os.path.join(root_dir, "%s.json" % name) - LocalMetadataProvider._dump_json_to_file( - filename, datum, allow_overwrite=allow_overwrite - ) + cls._dump_json_to_file(filename, datum, allow_overwrite=allow_overwrite) diff --git a/metaflow/plugins/metadata_providers/spin.py b/metaflow/plugins/metadata_providers/spin.py new file mode 100644 index 00000000000..ee77f2077b7 --- /dev/null +++ b/metaflow/plugins/metadata_providers/spin.py @@ -0,0 +1,16 @@ +from metaflow.plugins.metadata_providers.local import LocalMetadataProvider +from metaflow.metaflow_config import DATASTORE_SPIN_LOCAL_DIR + + +class SpinMetadataProvider(LocalMetadataProvider): + TYPE = "spin" + DATASTORE_DIR = DATASTORE_SPIN_LOCAL_DIR # ".spin_metaflow" + + @classmethod + def _get_storage_class(cls): + from metaflow.plugins.datastores.spin_storage import SpinStorage + + return SpinStorage + + def version(self): + return "spin" diff --git a/metaflow/runner/metaflow_runner.py b/metaflow/runner/metaflow_runner.py index f2f142a4653..240f124e0b9 100644 --- a/metaflow/runner/metaflow_runner.py +++ b/metaflow/runner/metaflow_runner.py @@ -506,14 +506,14 @@ async def __async_get_executing_task(self, attribute_file_fd, command_obj): ) return ExecutingTask(self, command_obj, task_object) - def spin(self, step_name, **kwargs): + def spin(self, pathspec, **kwargs): """ Blocking spin execution of the run. This method will wait until the spun run has completed execution. Parameters ---------- - step_name : str - The name of the step to spin. + pathspec : str + The pathspec of the step/task to spin. **kwargs : Any Additional arguments that you would pass to `python ./myflow.py` after the `spin` command. @@ -523,11 +523,19 @@ def spin(self, step_name, **kwargs): ExecutingTask containing the results of the spun task. 
""" with temporary_fifo() as (attribute_file_path, attribute_file_fd): - command = self.api(**self.top_level_kwargs).spin( - step_name=step_name, - runner_attribute_file=attribute_file_path, - **kwargs, - ) + if CLICK_API_PROCESS_CONFIG: + with with_dir(self.cwd): + command = self.api(**self.top_level_kwargs).spin( + pathspec=pathspec, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + else: + command = self.api(**self.top_level_kwargs).spin( + pathspec=pathspec, + runner_attribute_file=attribute_file_path, + **kwargs, + ) pid = self.spm.run_command( [sys.executable, *command], @@ -652,7 +660,7 @@ async def async_resume(self, **kwargs) -> ExecutingRun: return await self.__async_get_executing_run(attribute_file_fd, command_obj) - async def async_spin(self, step_name, spin_pathspec, **kwargs) -> ExecutingTask: + async def async_spin(self, pathspec, **kwargs) -> ExecutingTask: """ Non-blocking spin execution of the run. This method will return as soon as the spun task has launched. @@ -661,8 +669,8 @@ async def async_spin(self, step_name, spin_pathspec, **kwargs) -> ExecutingTask: Parameters ---------- - step_name : str - The name of the step to spin. + pathspec : str + The pathspec of the step/task to spin. **kwargs : Any Additional arguments that you would pass to `python ./myflow.py` after the `spin` command. @@ -673,11 +681,19 @@ async def async_spin(self, step_name, spin_pathspec, **kwargs) -> ExecutingTask: ExecutingTask representing the spun task that was started. """ with temporary_fifo() as (attribute_file_path, attribute_file_fd): - command = self.api(**self.top_level_kwargs).spin( - step_name=step_name, - runner_attribute_file=attribute_file_path, - **kwargs, - ) + if CLICK_API_PROCESS_CONFIG: + with with_dir(self.cwd): + command = self.api(**self.top_level_kwargs).spin( + pathspec=pathspec, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + else: + command = self.api(**self.top_level_kwargs).spin( + pathspec=pathspec, + runner_attribute_file=attribute_file_path, + **kwargs, + ) pid = await self.spm.async_run_command( [sys.executable, *command], diff --git a/metaflow/runtime.py b/metaflow/runtime.py index e6e659b2559..de6df4bc02c 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -129,18 +129,36 @@ def __init__( self._step_func = step_func - # Verify whether the use has provided step-name or spin-pathspec - if not spin_pathspec: - task = get_latest_task_pathspec(flow.name, step_name) - logger("For faster spin, use --spin-pathspec %s" % task.pathspec) - else: - # The user already provided a spin-pathspec, verify if its valid - try: - task = Task(spin_pathspec, _namespace_check=False) - except Exception: + # Determine if we have a complete pathspec or need to get the task + if spin_pathspec: + parts = spin_pathspec.split("/") + if len(parts) == 4: + # Complete pathspec: flow/run/step/task_id + try: + task = Task(spin_pathspec, _namespace_check=False) + except Exception: + raise MetaflowException( + f"Invalid pathspec: {spin_pathspec} for step: {step_name}" + ) + elif len(parts) == 3: + # Partial pathspec: flow/run/step - need to get the task + _, run_id, _ = parts + task = get_latest_task_pathspec(flow.name, step_name, run_id=run_id) + logger( + f"To make spin even faster, provide complete pathspec with task_id: {task.pathspec}", + system_msg=True, + ) + else: raise MetaflowException( - f"Invalid spin-pathspec: {spin_pathspec} for step: {step_name}" + f"Invalid pathspec format: {spin_pathspec}. 
Expected flow/run/step or flow/run/step/task_id" ) + else: + # No pathspec provided, get latest task for this step + task = get_latest_task_pathspec(flow.name, step_name) + logger( + f"To make spin even faster, provide complete pathspec {task.pathspec}", + system_msg=True, + ) from_start("SpinRuntime: after getting task") # Get the original FlowDatastore so we can use it to access artifacts from the @@ -181,6 +199,35 @@ def __init__( self._max_log_size = max_log_size self._encoding = sys.stdout.encoding or "UTF-8" + # If no artifacts module is provided, create a temporary one with parameter values + if not self._artifacts_module and hasattr(flow, "_get_parameters"): + import tempfile + import os + + # Collect parameter values from the flow + param_artifacts = {} + for var, param in flow._get_parameters(): + if hasattr(flow, var): + value = getattr(flow, var) + # Only add if it's an actual value, not the Parameter object + if value is not None and not hasattr(value, "IS_PARAMETER"): + param_artifacts[var] = value + + # If we have parameter values, create a temp module + if param_artifacts: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", delete=False + ) as f: + f.write( + "# Auto-generated artifacts module for spin step parameters\n" + ) + f.write("ARTIFACTS = {\n") + for key, value in param_artifacts.items(): + f.write(f" {repr(key)}: {repr(value)},\n") + f.write("}\n") + self._artifacts_module = f.name + self._temp_artifacts_file = f.name # Store for cleanup later + # Create a new run_id for the spin task self.run_id = self._metadata.new_run_id() for deco in self.whitelist_decorators: @@ -276,6 +323,14 @@ def execute(self): finally: for deco in self.whitelist_decorators: deco.runtime_finished(exception) + # Clean up temporary artifacts file if we created one + if hasattr(self, "_temp_artifacts_file"): + import os + + try: + os.unlink(self._temp_artifacts_file) + except: + pass def _launch_and_monitor_task(self): worker = Worker( @@ -284,9 +339,9 @@ def _launch_and_monitor_task(self): self._config_file_name, orig_flow_datastore=self._orig_flow_datastore, spin_pathspec=self._spin_pathspec, - whitelist_decorators=self.whitelist_decorators, artifacts_module=self._artifacts_module, persist=self._persist, + skip_decorators=self._skip_decorators, ) from_start("SpinRuntime: created worker") @@ -2003,9 +2058,9 @@ def __init__( task, orig_flow_datastore=None, spin_pathspec=None, - whitelist_decorators=None, artifacts_module=None, persist=True, + skip_decorators=False, ): self.task = task if orig_flow_datastore is not None: @@ -2016,9 +2071,9 @@ def __init__( else: self.orig_flow_datastore = None self.spin_pathspec = spin_pathspec - self.whitelist_decorators = whitelist_decorators self.artifacts_module = artifacts_module self.persist = persist + self.skip_decorators = skip_decorators self.entrypoint = list(task.entrypoint) step_obj = getattr(self.task.flow, self.task.step) self.top_level_options = { @@ -2081,8 +2136,6 @@ def spin_args(self): self.commands = ["spin-step"] self.command_args = [self.task.step] - whitelist_decos = [deco.name for deco in self.whitelist_decorators] - self.command_options = { "run-id": self.task.run_id, "task-id": self.task.task_id, @@ -2093,8 +2146,8 @@ def spin_args(self): "namespace": get_namespace() or "", "orig-flow-datastore": self.orig_flow_datastore, "spin-pathspec": self.spin_pathspec, - "whitelist-decorators": compress_list(whitelist_decos), "artifacts-module": self.artifacts_module, + "skip-decorators": self.skip_decorators, } if self.persist: 
self.command_options["persist"] = True @@ -2145,16 +2198,16 @@ def __init__( config_file_name, orig_flow_datastore=None, spin_pathspec=None, - whitelist_decorators=None, artifacts_module=None, persist=True, + skip_decorators=False, ): self.task = task self._config_file_name = config_file_name self._orig_flow_datastore = orig_flow_datastore self._spin_pathspec = spin_pathspec - self._whitelist_decorators = whitelist_decorators self._artifacts_module = artifacts_module + self._skip_decorators = skip_decorators self._persist = persist self._proc = self._launch() @@ -2191,9 +2244,9 @@ def _launch(self): self.task, orig_flow_datastore=self._orig_flow_datastore, spin_pathspec=self._spin_pathspec, - whitelist_decorators=self._whitelist_decorators, artifacts_module=self._artifacts_module, persist=self._persist, + skip_decorators=self._skip_decorators, ) env = dict(os.environ) diff --git a/metaflow/task.py b/metaflow/task.py index a811bec485c..52722a1be58 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -11,6 +11,7 @@ from metaflow.sidecar import Message, MessageTypes from metaflow.datastore.exceptions import DataException +from metaflow.plugins import METADATA_PROVIDERS from .metaflow_config import MAX_ATTEMPTS from .metadata_provider import MetaDatum from .metaflow_profile import from_start @@ -853,10 +854,17 @@ def run_step( ) from_start("MetaflowTask: before pre-step decorators") for deco in decorators: + if deco.name == "card" and self.orig_flow_datastore: + # if spin step and card decorator, pass spin metadata + metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "spin"][ + 0 + ](self.environment, self.flow, self.event_logger, self.monitor) + else: + metadata = self.metadata deco.task_pre_step( step_name, output, - self.metadata, + metadata, run_id, task_id, self.flow, diff --git a/metaflow/util.py b/metaflow/util.py index ed34b8802e0..03b8ce916e6 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -9,7 +9,7 @@ from functools import wraps from io import BytesIO from itertools import takewhile -import re +from typing import Dict, Any, Tuple, Optional, List try: @@ -180,6 +180,117 @@ def resolve_identity(): return "%s:%s" % (identity_type, identity_value) +def parse_spin_pathspec(pathspec: str, flow_name: str) -> Tuple: + """ + Parse various pathspec formats for the spin command. + + Parameters + ---------- + pathspec : str + The pathspec string in one of the following formats: + - step_name (e.g., 'start') + - run_id/step_name (e.g., '221165/start') + - run_id/step_name/task_id (e.g., '221165/start/1350987') + - flow_name/run_id/step_name (e.g., 'ScalableFlow/221165/start') + - flow_name/run_id/step_name/task_id (e.g., 'ScalableFlow/221165/start/1350987') + flow_name : str + The name of the current flow. + + Returns + ------- + Tuple + A tuple of (step_name, full_pathspec_or_none) + + Raises + ------ + CommandException + If the pathspec format is invalid or flow name doesn't match. 
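Illustrative behaviour (not part of the diff), following directly from the parsing rules of `parse_spin_pathspec`; the flow name, run id and task id are the same placeholders used in the docstring above:

    >>> parse_spin_pathspec("start", "ScalableFlow")
    ('start', None)
    >>> parse_spin_pathspec("221165/start", "ScalableFlow")
    ('start', 'ScalableFlow/221165/start')
    >>> parse_spin_pathspec("221165/start/1350987", "ScalableFlow")
    ('start', 'ScalableFlow/221165/start/1350987')
    >>> parse_spin_pathspec("ScalableFlow/221165/start/1350987", "ScalableFlow")
    ('start', 'ScalableFlow/221165/start/1350987')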
+ """ + from .exception import CommandException + + parts = pathspec.split("/") + + if len(parts) == 1: + # Just step name: 'start' + step_name = parts[0] + parsed_pathspec = None + elif len(parts) == 2: + # run_id/step_name: '221165/start' + run_id, step_name = parts + parsed_pathspec = f"{flow_name}/{run_id}/{step_name}" + elif len(parts) == 3: + # Could be run_id/step_name/task_id or flow_name/run_id/step_name + if parts[0] == flow_name: + # flow_name/run_id/step_name + _, run_id, step_name = parts + parsed_pathspec = f"{flow_name}/{run_id}/{step_name}" + else: + # run_id/step_name/task_id + run_id, step_name, task_id = parts + parsed_pathspec = f"{flow_name}/{run_id}/{step_name}/{task_id}" + elif len(parts) == 4: + # flow_name/run_id/step_name/task_id + parsed_flow_name, run_id, step_name, task_id = parts + if parsed_flow_name != flow_name: + raise CommandException( + f"Flow name '{parsed_flow_name}' in pathspec does not match current flow '{flow_name}'." + ) + parsed_pathspec = pathspec + else: + raise CommandException( + f"Invalid pathspec format: '{pathspec}'. \n" + "Expected formats:\n" + " - step_name (e.g., 'start')\n" + " - run_id/step_name (e.g., '221165/start')\n" + " - run_id/step_name/task_id (e.g., '221165/start/1350987')\n" + " - flow_name/run_id/step_name (e.g., 'ScalableFlow/221165/start')\n" + " - flow_name/run_id/step_name/task_id (e.g., 'ScalableFlow/221165/start/1350987')" + ) + + return step_name, parsed_pathspec + + +def get_latest_task_pathspec(flow_name: str, step_name: str, run_id: str = None) -> "metaflow.Task": + """ + Returns a task pathspec from the latest run (or specified run) of the flow for the queried step. + If the queried step has several tasks, the task pathspec of the first task is returned. + + Parameters + ---------- + flow_name : str + The name of the flow. + step_name : str + The name of the step. + run_id : str, optional + The run ID to use. If None, uses the latest run. + + Returns + ------- + Task + A Metaflow Task instance containing the latest task for the queried step. + + Raises + ------ + MetaflowNotFound + If no task or run is found for the queried step. + """ + from metaflow import Flow, Step + from metaflow.exception import MetaflowNotFound + + if not run_id: + flow = Flow(flow_name, _namespace_check=False) + run = flow.latest_run + if run is None: + raise MetaflowNotFound(f"No run found for flow {flow_name}") + run_id = run.id + + try: + task = Step(f"{flow_name}/{run_id}/{step_name}", _namespace_check=False).task + return task + except: + raise MetaflowNotFound(f"No task found for step {step_name} in run {run_id}") + + def get_latest_run_id(echo, flow_name): from metaflow.plugins.datastores.local_storage import LocalStorage @@ -475,6 +586,41 @@ def to_pod(value): from metaflow._vendor.packaging.version import parse as version_parse +def read_artifacts_module(file_path: str) -> Dict[str, Any]: + """ + Read a Python module from the given file path and return its ARTIFACTS variable. + + Parameters + ---------- + file_path : str + The path to the Python file containing the ARTIFACTS variable. + + Returns + ------- + Dict[str, Any] + A dictionary containing the ARTIFACTS variable from the module. + + Raises + ------- + MetaflowInternalError + If the file cannot be read or does not contain the ARTIFACTS variable. 
+ """ + import importlib.util + + try: + spec = importlib.util.spec_from_file_location("artifacts_module", file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + variables = vars(module) + if "ARTIFACTS" not in variables: + raise MetaflowInternalError( + f"Module {file_path} does not contain ARTIFACTS variable" + ) + return variables.get("ARTIFACTS") + except Exception as e: + raise MetaflowInternalError(f"Error reading file {file_path}") from e + + # this is os.walk(follow_symlinks=True) with cycle detection def walk_without_cycles( top_root: str, diff --git a/test/unit/spin/conftest.py b/test/unit/spin/conftest.py index 6c084b1c375..b98bbca1d9f 100644 --- a/test/unit/spin/conftest.py +++ b/test/unit/spin/conftest.py @@ -1,33 +1,78 @@ import pytest -from metaflow import Runner +from metaflow import Runner, Flow import os # Get the directory containing the flows FLOWS_DIR = os.path.join(os.path.dirname(__file__), "flows") -@pytest.fixture(scope="session") -def complex_dag_run(): - """Run ComplexDAGFlow and return the completed run.""" - flow_path = os.path.join(FLOWS_DIR, "complex_dag_flow.py") - with Runner(flow_path, environment="conda").run() as running: - return running.run +def pytest_addoption(parser): + """Add custom command line options.""" + parser.addoption( + "--use-latest", + action="store_true", + default=False, + help="Use latest run of each flow instead of running new ones", + ) -@pytest.fixture(scope="session") -def merge_artifacts_run(): - """Run MergeArtifactsFlow and return the completed run.""" - flow_path = os.path.join(FLOWS_DIR, "merge_artifacts_flow.py") - with Runner(flow_path).run() as running: - return running.run +def create_flow_fixture(flow_name, flow_file, run_params=None, runner_params=None): + """Factory function to create flow fixtures with common logic. 
+ Args: + flow_name: Name of the flow class + flow_file: Python file containing the flow + run_params: Parameters to pass to .run() method + runner_params: Parameters to pass to Runner() constructor + """ -@pytest.fixture(scope="session") -def simple_parameter_run(): - """Run SimpleParameterFlow and return the completed run.""" - flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") - with Runner(flow_path).run(alpha=0.05) as running: - return running.run + def flow_fixture(request): + if request.config.getoption("--use-latest"): + flow = Flow(flow_name, _namespace_check=False) + return flow.latest_run + else: + flow_path = os.path.join(FLOWS_DIR, flow_file) + runner_params_dict = runner_params or {} + runner_params_dict["cwd"] = FLOWS_DIR # Always set cwd to FLOWS_DIR + run_params_dict = run_params or {} + + with Runner(flow_path, **runner_params_dict).run( + **run_params_dict + ) as running: + return running.run + + return flow_fixture + + +# Create all the flow fixtures using the factory +complex_dag_run = pytest.fixture(scope="session")( + create_flow_fixture( + "ComplexDAGFlow", "complex_dag_flow.py", runner_params={"environment": "conda"} + ) +) + +merge_artifacts_run = pytest.fixture(scope="session")( + create_flow_fixture("MergeArtifactsFlow", "merge_artifacts_flow.py") +) + +simple_parameter_run = pytest.fixture(scope="session")( + create_flow_fixture( + "SimpleParameterFlow", "simple_parameter_flow.py", run_params={"alpha": 0.05} + ) +) + +simple_card_run = pytest.fixture(scope="session")( + create_flow_fixture( + "SimpleCardFlow", "simple_card_flow.py", run_params={"alpha": 0.05} + ) +) + +simple_config_run = pytest.fixture(scope="session")( + create_flow_fixture( + "TimeoutConfigFlow", + "simple_config_flow.py", + ) +) @pytest.fixture diff --git a/test/unit/spin/flows/merge_artifacts_flow.py b/test/unit/spin/flows/merge_artifacts_flow.py index 59f1390e052..fe49f8c10be 100644 --- a/test/unit/spin/flows/merge_artifacts_flow.py +++ b/test/unit/spin/flows/merge_artifacts_flow.py @@ -25,7 +25,6 @@ def b(self): @step def join(self, inputs): - print(f"In join step, self._datastore: {(type(self._datastore))}") self.x = inputs.a.x self.merge_artifacts(inputs, exclude=["y"]) print("x is %s" % self.x) diff --git a/test/unit/spin/flows/myconfig.json b/test/unit/spin/flows/myconfig.json new file mode 100644 index 00000000000..c24b31c1e41 --- /dev/null +++ b/test/unit/spin/flows/myconfig.json @@ -0,0 +1 @@ +{"timeout": 60} \ No newline at end of file diff --git a/test/unit/spin/flows/simple_card_flow.py b/test/unit/spin/flows/simple_card_flow.py new file mode 100644 index 00000000000..4da026aba8d --- /dev/null +++ b/test/unit/spin/flows/simple_card_flow.py @@ -0,0 +1,27 @@ +from metaflow import FlowSpec, step, card, Parameter, current +from metaflow.cards import Markdown + + +class SimpleCardFlow(FlowSpec): + + number = Parameter("number", default=3) + + @card(type="blank") + @step + def start(self): + current.card.append(Markdown("# Guess my number")) + if self.number > 5: + current.card.append(Markdown("My number is **smaller** ⬇️")) + elif self.number < 5: + current.card.append(Markdown("My number is **larger** ⬆️")) + else: + current.card.append(Markdown("## Correct! 
🎉")) + self.next(self.end) + + @step + def end(self): + pass + + +if __name__ == "__main__": + SimpleCardFlow() diff --git a/test/unit/spin/flows/simple_config_flow.py b/test/unit/spin/flows/simple_config_flow.py new file mode 100644 index 00000000000..d4e910e1e29 --- /dev/null +++ b/test/unit/spin/flows/simple_config_flow.py @@ -0,0 +1,22 @@ +import time +from metaflow import FlowSpec, step, Config, timeout + + +class TimeoutConfigFlow(FlowSpec): + config = Config("config", default="myconfig.json") + + @timeout(seconds=config.timeout) + @step + def start(self): + print(f"timing out after {self.config.timeout} seconds") + time.sleep(5) + print("success") + self.next(self.end) + + @step + def end(self): + print("full config", self.config) + + +if __name__ == "__main__": + TimeoutConfigFlow() diff --git a/test/unit/spin/flows/simple_parameter_flow.py b/test/unit/spin/flows/simple_parameter_flow.py index 0a36ff007ce..bf1969326a2 100644 --- a/test/unit/spin/flows/simple_parameter_flow.py +++ b/test/unit/spin/flows/simple_parameter_flow.py @@ -1,10 +1,10 @@ -from metaflow import FlowSpec, step, Parameter, titus +from metaflow import FlowSpec, step, Parameter, current, project +@project(name="simple_parameter_flow") class SimpleParameterFlow(FlowSpec): alpha = Parameter("alpha", help="Learning rate", default=0.01) - @titus @step def start(self): print("SimpleParameterFlow is starting.") @@ -19,6 +19,11 @@ def end(self): self.x = 100 self.y = 200 print("Parameter alpha in end step is: ", self.alpha) + print( + f"Pathspec: {current.pathspec}, flow_name: {current.flow_name}, run_id: {current.run_id}" + ) + print(f"step_name: {current.step_name}, task_id: {current.task_id}") + print(f"Project name: {current.project_name}, Namespace: {current.namespace}") del self.a del self.x print("SimpleParameterFlow is all done.") diff --git a/test/unit/spin/spin_test_helpers.py b/test/unit/spin/spin_test_helpers.py new file mode 100644 index 00000000000..b2c61e37457 --- /dev/null +++ b/test/unit/spin/spin_test_helpers.py @@ -0,0 +1,32 @@ +import os +from metaflow import Runner + +FLOWS_DIR = os.path.join(os.path.dirname(__file__), "flows") +ARTIFACTS_DIR = os.path.join(os.path.dirname(__file__), "artifacts") + + +def assert_artifacts(task, spin_task): + """Assert that artifacts match between original task and spin task.""" + spin_task_artifacts = { + artifact.id: artifact.data for artifact in spin_task.artifacts + } + print(f"Spin task artifacts: {spin_task_artifacts}") + for artifact in task.artifacts: + assert ( + artifact.id in spin_task_artifacts + ), f"Artifact {artifact.id} not found in spin task" + assert ( + artifact.data == spin_task_artifacts[artifact.id] + ), f"Expected {artifact.data} but got {spin_task_artifacts[artifact.id]} for artifact {artifact.id}" + + +def run_step(flow_file, run, step_name, **tl_kwargs): + """Run a step and assert artifacts match.""" + task = run[step_name].task + flow_path = os.path.join(FLOWS_DIR, flow_file) + print(f"FLOWS_DIR: {FLOWS_DIR}") + + with Runner(flow_path, cwd=FLOWS_DIR, **tl_kwargs).spin(task.pathspec) as spin: + print("-" * 50) + print(f"Running test for step: {step_name} with task pathspec: {task.pathspec}") + assert_artifacts(task, spin.task) diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py index 63408bbeee1..9f92e24ad09 100644 --- a/test/unit/spin/test_spin.py +++ b/test/unit/spin/test_spin.py @@ -1,68 +1,29 @@ import pytest from metaflow import Runner import os - -FLOWS_DIR = os.path.join(os.path.dirname(__file__), "flows") 
-ARTIFACTS_DIR = os.path.join(os.path.dirname(__file__), "artifacts") - - -def _assert_artifacts(task, spin_task): - spin_task_artifacts = { - artifact.id: artifact.data for artifact in spin_task.artifacts - } - print(f"Spin task artifacts: {spin_task_artifacts}") - for artifact in task.artifacts: - assert ( - artifact.id in spin_task_artifacts - ), f"Artifact {artifact.id} not found in spin task" - assert ( - artifact.data == spin_task_artifacts[artifact.id] - ), f"Expected {artifact.data} but got {spin_task_artifacts[artifact.id]} for artifact {artifact.id}" - - -def _run_step(flow_file, run, step_name, is_conda=False): - task = run[step_name].task - flow_path = os.path.join(FLOWS_DIR, flow_file) - - if not is_conda: - with Runner(flow_path).spin(step_name, spin_pathspec=task.pathspec) as spin: - print("-" * 50) - print( - f"Running test for step: {step_name} with task pathspec: {task.pathspec}" - ) - _assert_artifacts(task, spin.task) - else: - with Runner(flow_path, environment="conda").spin( - step_name, - spin_pathspec=task.pathspec, - ) as spin: - print("-" * 50) - print( - f"Running test for step: {step_name} with task pathspec: {task.pathspec}" - ) - print(f"Spin task artifacts: {spin.task.artifacts}") - _assert_artifacts(task, spin.task) - - -def test_complex_dag_flow(complex_dag_run): - print(f"Running test for ComplexDAGFlow flow: {complex_dag_run}") - for step in complex_dag_run.steps(): - print("-" * 100) - _run_step("complex_dag_flow.py", complex_dag_run, step.id, is_conda=True) - - -def test_merge_artifacts_flow(merge_artifacts_run): - print(f"Running test for merge artifacts flow: {merge_artifacts_run}") - for step in merge_artifacts_run.steps(): - print("-" * 100) - _run_step("merge_artifacts_flow.py", merge_artifacts_run, step.id) - - -def test_simple_parameter_flow(simple_parameter_run): - print(f"Running test for SimpleParameterFlow: {simple_parameter_run}") - for step in simple_parameter_run.steps(): +from spin_test_helpers import assert_artifacts, run_step, FLOWS_DIR, ARTIFACTS_DIR + + +@pytest.mark.parametrize( + "flow_file,fixture_name", + [ + ("merge_artifacts_flow.py", "merge_artifacts_run"), + ("simple_config_flow.py", "simple_config_run"), + ("simple_parameter_flow.py", "simple_parameter_run"), + ("complex_dag_flow.py", "complex_dag_run"), + ], + ids=["merge_artifacts", "simple_config", "simple_parameter", "complex_dag"], +) +def test_simple_flows(flow_file, fixture_name, request): + """Test simple flows that just need artifact validation.""" + run = request.getfixturevalue(fixture_name) + print(f"Running test for {flow_file}: {run}") + for step in run.steps(): print("-" * 100) - _run_step("simple_parameter_flow.py", simple_parameter_run, step.id) + if fixture_name == "complex_dag_run": + run_step(flow_file, run, step.id, environment="conda") + else: + run_step(flow_file, run, step.id) def test_artifacts_module(complex_dag_run): @@ -73,8 +34,7 @@ def test_artifacts_module(complex_dag_run): artifacts_path = os.path.join(ARTIFACTS_DIR, "complex_dag_step_a.py") with Runner(flow_path, environment="conda").spin( - step_name, - spin_pathspec=task.pathspec, + task.pathspec, artifacts_module=artifacts_path, ) as spin: print("-" * 50) @@ -97,8 +57,7 @@ def test_artifacts_module_join_step( temp_artifacts_file.write_text(f"ARTIFACTS = {repr(complex_dag_step_d_artifacts)}") with Runner(flow_path, environment="conda").spin( - step_name, - spin_pathspec=task.pathspec, + task.pathspec, artifacts_module=str(temp_artifacts_file), ) as spin: print("-" * 50) @@ -107,43 +66,64 @@ 
def test_artifacts_module_join_step( assert spin_task["my_output"].data == [-1] -def test_skip_decorators(complex_dag_run): - print(f"Running test for skip decorator in ComplexDAGFlow: {complex_dag_run}") - step_name = "step_m" - task = complex_dag_run[step_name].task - flow_path = os.path.join(FLOWS_DIR, "complex_dag_flow.py") +def test_timeout_decorator_enforcement(simple_config_run): + """Test that timeout decorator properly enforces timeout limits.""" + step_name = "start" + task = simple_config_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "simple_config_flow.py") + + # With decorator enabled (should timeout and raise exception) + with pytest.raises(Exception): + with Runner( + flow_path, cwd=FLOWS_DIR, config_value=[("config", {"timeout": 2})] + ).spin( + task.pathspec, + ): + pass + + +def test_skip_decorators_bypass(simple_config_run): + """Test that skip_decorators successfully bypasses timeout decorator.""" + step_name = "start" + task = simple_config_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "simple_config_flow.py") + + # With skip_decorators=True (should succeed despite timeout) + with Runner( + flow_path, cwd=FLOWS_DIR, config_value=[("config", {"timeout": 2})] + ).spin( + task.pathspec, + skip_decorators=True, + ) as spin: + print(f"Running test for step: {step_name} with skip_decorators=True") + # Should complete successfully even though sleep(5) > timeout(2) + spin_task = spin.task + assert spin_task.finished + + +def test_hidden_artifacts(simple_parameter_run): + """Test simple flows that just need artifact validation.""" + step_name = "start" + task = simple_parameter_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") + print(f"Running test for hidden artifacts in {flow_path}: {simple_parameter_run}") + + with Runner(flow_path, cwd=FLOWS_DIR).spin(task.pathspec) as spin: + spin_task = spin.task + assert "_graph_info" in spin_task + assert "_foreach_stack" in spin_task + + +def test_card_flow(simple_card_run): + """Test a simple flow that has @card decorator.""" + step_name = "start" + task = simple_card_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "simple_card_flow.py") + print(f"Running test for cards in {flow_path}: {simple_card_run}") + + with Runner(flow_path, cwd=FLOWS_DIR).spin(task.pathspec) as spin: + spin_task = spin.task + from metaflow.cards import get_cards - # Check if sklearn is available in the outer environment - is_sklearn = True - try: - import sklearn - except ImportError: - is_sklearn = False - - if is_sklearn: - # We verify that the sklearn version is the same as the one in the outside environment - with Runner(flow_path, environment="conda").spin( - step_name, - spin_pathspec=task.pathspec, - skip_decorators=True, - ) as spin: - print("-" * 50) - print( - f"Running test for step: {step_name} with task pathspec: {task.pathspec}" - ) - spin_task = spin.task - import sklearn - - expected_version = sklearn.__version__ - assert ( - spin_task["sklearn_version"].data == expected_version - ), f"Expected sklearn version {expected_version} but got {spin_task['sklearn_version']}" - else: - # We assert that an exception is raised when trying to run the step with skip_decorators=True - with pytest.raises(Exception): - with Runner(flow_path, environment="conda").spin( - step_name, - spin_pathspec=task.pathspec, - skip_decorators=True, - ): - pass + res = get_cards(spin_task, follow_resumed=False) + print(res) From 1b491271e1d6c5308adf42aabe555f8ddf2d81b7 Mon Sep 17 00:00:00 2001 From: 
Shashank Srikanth Date: Wed, 1 Oct 2025 17:09:56 -0700 Subject: [PATCH 07/21] Address feedback --- .gitignore | 2 +- metaflow/__init__.py | 1 + metaflow/cli.py | 13 +++++-- metaflow/cli_components/run_cmds.py | 1 - metaflow/cli_components/step_cmd.py | 7 ---- metaflow/client/__init__.py | 1 + metaflow/client/core.py | 9 +++++ metaflow/decorators.py | 5 +-- metaflow/metaflow_config.py | 2 +- metaflow/plugins/cards/card_cli.py | 2 -- metaflow/plugins/cards/card_modules/basic.py | 10 ++++-- metaflow/plugins/datastores/spin_storage.py | 2 +- metaflow/plugins/metadata_providers/spin.py | 2 +- metaflow/runtime.py | 3 +- test/unit/spin/conftest.py | 22 +++++++----- test/unit/spin/flows/hello_spin_flow.py | 26 ++++++++++++++ test/unit/spin/flows/simple_card_flow.py | 21 +++++++++++- test/unit/spin/test_spin.py | 36 ++++++++++++++++++++ 18 files changed, 132 insertions(+), 33 deletions(-) create mode 100644 test/unit/spin/flows/hello_spin_flow.py diff --git a/.gitignore b/.gitignore index 3557ef49ad6..2e62c56aced 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ __pycache__/ *.py[cod] *$py.class *.metaflow -*.spin_metaflow +*.metaflow_spin metaflow_card_cache/ build/ diff --git a/metaflow/__init__.py b/metaflow/__init__.py index 0eba0da3f33..9a0b005e286 100644 --- a/metaflow/__init__.py +++ b/metaflow/__init__.py @@ -146,6 +146,7 @@ class and related decorators. metadata, get_metadata, default_metadata, + inspect_spin, Metaflow, Flow, Run, diff --git a/metaflow/cli.py b/metaflow/cli.py index fc9e853ad4f..63fb39c925a 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -322,6 +322,13 @@ def version(obj): hidden=True, is_eager=True, ) +@click.option( + "--spin-mode", + is_flag=True, + default=False, + help="Enable spin mode for metaflow cli commands. Setting this flag will result " + "in using spin metadata and spin datastore for executions" +) @click.pass_context def start( ctx, @@ -339,6 +346,7 @@ def start( local_config_file=None, config=None, config_value=None, + spin_mode=False, **deco_options ): if quiet: @@ -371,6 +379,7 @@ def start( ctx.obj.check = functools.partial(_check, echo) ctx.obj.top_cli = cli ctx.obj.package_suffixes = package_suffixes.split(",") + ctx.obj.spin_mode = spin_mode ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == datastore][0] @@ -499,8 +508,8 @@ def start( ctx.obj.is_spin = False ctx.obj.skip_decorators = False - # Override values for spin - if hasattr(ctx, "saved_args") and ctx.saved_args and "spin" in ctx.saved_args[0]: + # Override values for spin steps, or if we are in spin mode + if hasattr(ctx, "saved_args") and ctx.saved_args and "spin" in ctx.saved_args[0] or ctx.obj.spin_mode: # To minimize side effects for spin, we will only use the following: # - local metadata provider, # - local datastore, diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index c957acc22a4..364ddd80b8b 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -526,7 +526,6 @@ def spin( "flow_name": obj.flow.name, # Store metadata in a format that can be used by the Runner API "metadata": f"{obj.metadata.__class__.TYPE}@{orig_task_metadata_root}", - # "metadata": f"spin@{orig_task_metadata_root}", }, f, ) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index 0a36792dca0..41a9a67674f 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -202,12 +202,6 @@ def step( show_default=True, help="Original datastore for the flow 
from which a task is being spun", ) -@click.option( - "--spin-pathspec", - default=None, - show_default=True, - help="Task Pathspec to be used in the spun step.", -) @click.option( "--input-paths", help="A comma-separated list of pathspecs specifying inputs for this step.", @@ -266,7 +260,6 @@ def spin_step( run_id=None, task_id=None, orig_flow_datastore=None, - spin_pathspec=None, input_paths=None, split_index=None, retry_count=None, diff --git a/metaflow/client/__init__.py b/metaflow/client/__init__.py index a06fbd290ba..9acf7c44c88 100644 --- a/metaflow/client/__init__.py +++ b/metaflow/client/__init__.py @@ -6,6 +6,7 @@ metadata, get_metadata, default_metadata, + inspect_spin, Metaflow, Flow, Run, diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 58aecb12afa..b6872cd3044 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -207,6 +207,15 @@ def default_namespace() -> str: return get_namespace() +def inspect_spin(datastore_root): + """ + Set metadata provider to spin metadata so that users can inspect spin + steps, tasks, and artifacts. + """ + metadata_str = f"spin@{datastore_root}" + metadata(metadata_str) + + MetaflowArtifacts = NamedTuple diff --git a/metaflow/decorators.py b/metaflow/decorators.py index 267d875f684..d6f91cd3066 100644 --- a/metaflow/decorators.py +++ b/metaflow/decorators.py @@ -692,10 +692,7 @@ def _should_skip_decorator_for_spin( # Run decorator hooks for spin steps only if they are in the whitelist if deco.name not in SPIN_ALLOWED_DECORATORS: logger( - f"[Warning] {decorator_type} '{deco.name}' is not supported in spin steps. " - f"Supported decorators are: [{', '.join(SPIN_ALLOWED_DECORATORS)}]. " - f"Skipping this decorator as it is not in the whitelist.\n" - f"Alternatively, you can use the --skip-decorators flag to skip running all decorators in spin steps.", + f"[Warning] Ignoring {decorator_type} '{deco.name}' as it is not supported in spin steps.", system_msg=True, timestamp=False, bad=True, diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index 0dc4526833c..21aa0b40acd 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -21,7 +21,7 @@ # Path to the local directory to store artifacts for 'local' datastore. DATASTORE_LOCAL_DIR = ".metaflow" -DATASTORE_SPIN_LOCAL_DIR = ".spin_metaflow" +DATASTORE_SPIN_LOCAL_DIR = ".metaflow_spin" # Local configuration file (in .metaflow) containing overrides per-project LOCAL_CONFIG_FILE = "config.json" diff --git a/metaflow/plugins/cards/card_cli.py b/metaflow/plugins/cards/card_cli.py index 9cb8b4bbb9d..766f23a70d8 100644 --- a/metaflow/plugins/cards/card_cli.py +++ b/metaflow/plugins/cards/card_cli.py @@ -335,7 +335,6 @@ def list_many_cards( def cli(): pass - @cli.group(help="Commands related to @card decorator.") @click.pass_context def card(ctx): @@ -343,7 +342,6 @@ def card(ctx): # Can work with the Metaflow client. # If we don't set the metadata here than the metaflow client picks the defaults when calling the `Task`/`Run` objects. 
These defaults can come from the `config.json` file or based on the `METAFLOW_PROFILE` from metaflow import metadata - setting_metadata = "@".join( [ctx.obj.metadata.TYPE, ctx.obj.metadata.default_info()] ) diff --git a/metaflow/plugins/cards/card_modules/basic.py b/metaflow/plugins/cards/card_modules/basic.py index 74dbe32d5a6..1081a2b916b 100644 --- a/metaflow/plugins/cards/card_modules/basic.py +++ b/metaflow/plugins/cards/card_modules/basic.py @@ -496,9 +496,13 @@ def render(self): ) # ignore the name as a parameter - param_ids = [ - p.id for p in self._task.parent.parent["_parameters"].task if p.id != "name" - ] + if "_parameters" not in self._task.parent.parent: + # In case of spin steps, there is no _parameters task + param_ids = [] + else: + param_ids = [ + p.id for p in self._task.parent.parent["_parameters"].task if p.id != "name" + ] if len(param_ids) > 0: # Extract parameter from the Parameter Task. That is less brittle. parameter_data = TaskToDict( diff --git a/metaflow/plugins/datastores/spin_storage.py b/metaflow/plugins/datastores/spin_storage.py index 29fd5bdcc30..d0f39baf62b 100644 --- a/metaflow/plugins/datastores/spin_storage.py +++ b/metaflow/plugins/datastores/spin_storage.py @@ -8,5 +8,5 @@ class SpinStorage(LocalStorage): TYPE = "spin" METADATA_DIR = "_meta" - DATASTORE_DIR = DATASTORE_SPIN_LOCAL_DIR # ".spin_metaflow" + DATASTORE_DIR = DATASTORE_SPIN_LOCAL_DIR # ".metaflow_spin" SYSROOT_VAR = DATASTORE_SYSROOT_SPIN diff --git a/metaflow/plugins/metadata_providers/spin.py b/metaflow/plugins/metadata_providers/spin.py index ee77f2077b7..e32fdc8ffe6 100644 --- a/metaflow/plugins/metadata_providers/spin.py +++ b/metaflow/plugins/metadata_providers/spin.py @@ -4,7 +4,7 @@ class SpinMetadataProvider(LocalMetadataProvider): TYPE = "spin" - DATASTORE_DIR = DATASTORE_SPIN_LOCAL_DIR # ".spin_metaflow" + DATASTORE_DIR = DATASTORE_SPIN_LOCAL_DIR # ".metaflow_spin" @classmethod def _get_storage_class(cls): diff --git a/metaflow/runtime.py b/metaflow/runtime.py index de6df4bc02c..c1a577d37e2 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -2145,12 +2145,13 @@ def spin_args(self): "max-user-code-retries": self.task.user_code_retries, "namespace": get_namespace() or "", "orig-flow-datastore": self.orig_flow_datastore, - "spin-pathspec": self.spin_pathspec, "artifacts-module": self.artifacts_module, "skip-decorators": self.skip_decorators, } if self.persist: self.command_options["persist"] = True + else: + self.command_options["no-persist"] = True self.env = {} def get_args(self): diff --git a/test/unit/spin/conftest.py b/test/unit/spin/conftest.py index b98bbca1d9f..a4f90937543 100644 --- a/test/unit/spin/conftest.py +++ b/test/unit/spin/conftest.py @@ -17,13 +17,19 @@ def pytest_addoption(parser): def create_flow_fixture(flow_name, flow_file, run_params=None, runner_params=None): - """Factory function to create flow fixtures with common logic. - - Args: - flow_name: Name of the flow class - flow_file: Python file containing the flow - run_params: Parameters to pass to .run() method - runner_params: Parameters to pass to Runner() constructor + """ + Factory function to create flow fixtures with common logic. 
+ + Parameters + ----------- + flow_name: str + Name of the flow class + flow_file: str + Python file containing the flow + run_params: dict, optional + Parameters to pass to .run() method + runner_params: dict, optional + Parameters to pass to Runner() """ def flow_fixture(request): @@ -63,7 +69,7 @@ def flow_fixture(request): simple_card_run = pytest.fixture(scope="session")( create_flow_fixture( - "SimpleCardFlow", "simple_card_flow.py", run_params={"alpha": 0.05} + "SimpleCardFlow", "simple_card_flow.py", ) ) diff --git a/test/unit/spin/flows/hello_spin_flow.py b/test/unit/spin/flows/hello_spin_flow.py new file mode 100644 index 00000000000..2df4a6aeee6 --- /dev/null +++ b/test/unit/spin/flows/hello_spin_flow.py @@ -0,0 +1,26 @@ +from metaflow import FlowSpec, step +import random + + +class HelloSpinFlow(FlowSpec): + + @step + def start(self): + chunk_size = 1024 * 1024 # 1 MB + total_size = 1024 * 1024 * 1000 # 1000 MB + + data = bytearray() + for _ in range(total_size // chunk_size): + data.extend(random.randbytes(chunk_size)) + + self.a = data + self.next(self.end) + + @step + def end(self): + print(f"Size of artifact a: {len(self.a)} bytes") + print("HelloSpinFlow completed.") + + +if __name__ == "__main__": + HelloSpinFlow() \ No newline at end of file diff --git a/test/unit/spin/flows/simple_card_flow.py b/test/unit/spin/flows/simple_card_flow.py index 4da026aba8d..83142d08ba8 100644 --- a/test/unit/spin/flows/simple_card_flow.py +++ b/test/unit/spin/flows/simple_card_flow.py @@ -1,10 +1,14 @@ from metaflow import FlowSpec, step, card, Parameter, current from metaflow.cards import Markdown +import requests, pandas, string + +URL = "https://upload.wikimedia.org/wikipedia/commons/4/45/Blue_Marble_rotating.gif" -class SimpleCardFlow(FlowSpec): +class SimpleCardFlow(FlowSpec): number = Parameter("number", default=3) + image_url = Parameter("image_url", default=URL) @card(type="blank") @step @@ -16,6 +20,21 @@ def start(self): current.card.append(Markdown("My number is **larger** ⬆️")) else: current.card.append(Markdown("## Correct! 
🎉")) + + self.next(self.a) + + @step + def a(self): + print(f"image: {self.image_url}") + self.image = requests.get( + self.image_url, headers={"user-agent": "metaflow-example"} + ).content + self.dataframe = pandas.DataFrame( + { + "lowercase": list(string.ascii_lowercase), + "uppercase": list(string.ascii_uppercase), + } + ) self.next(self.end) @step diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py index 9f92e24ad09..04e9a852c1e 100644 --- a/test/unit/spin/test_spin.py +++ b/test/unit/spin/test_spin.py @@ -127,3 +127,39 @@ def test_card_flow(simple_card_run): res = get_cards(spin_task, follow_resumed=False) print(res) + +def test_inspect_spin_client_access(simple_parameter_run): + """Test accessing spin artifacts using inspect_spin client directly.""" + from metaflow import inspect_spin, Task + import tempfile + + step_name = "start" + task = simple_parameter_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") + + with tempfile.TemporaryDirectory() as tmpdir: + # Run spin to generate artifacts + with Runner(flow_path, cwd=FLOWS_DIR).spin( + task.pathspec, + ) as spin: + spin_task = spin.task + spin_pathspec = spin_task.pathspec + assert spin_task['a'] is not None + assert spin_task['b'] is not None + + assert spin_pathspec is not None + + # Set metadata provider to spin + inspect_spin(FLOWS_DIR) + client_task = Task(spin_pathspec, _namespace_check=False) + + # Verify task is accessible + assert client_task is not None + + # Verify artifacts + assert hasattr(client_task, 'artifacts') + + # Verify artifact data + assert client_task.artifacts.a.data == 10 + assert client_task.artifacts.b.data == 20 + assert client_task.artifacts.alpha.data == 0.05 \ No newline at end of file From 9f31540141502447dadae37c195ef71e01ef2ba9 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Wed, 1 Oct 2025 18:50:15 -0700 Subject: [PATCH 08/21] Address minor nits --- metaflow/cli_components/run_cmds.py | 2 +- metaflow/cli_components/step_cmd.py | 7 +++---- metaflow/datastore/flow_datastore.py | 9 ++++++--- metaflow/datastore/task_datastore.py | 5 +++-- metaflow/runner/metaflow_runner.py | 21 +++++++++------------ metaflow/task.py | 2 +- metaflow/util.py | 4 +++- 7 files changed, 26 insertions(+), 24 deletions(-) diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index 364ddd80b8b..af4e1f2f234 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -451,7 +451,7 @@ def run( default=True, show_default=True, help="Whether to persist the artifacts in the spun step. 
If set to False, " - "the artifacts will notbe persisted and will not be available in the spun step's " + "the artifacts will not be persisted and will not be available in the spun step's " "datastore.", ) @click.option( diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index 41a9a67674f..79cc78ad584 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -122,7 +122,7 @@ def step( if ubf_context == "none": ubf_context = None if opt_namespace is not None: - namespace(opt_namespace or None) + namespace(opt_namespace) func = None try: @@ -198,7 +198,6 @@ def step( ) @click.option( "--orig-flow-datastore", - default=None, show_default=True, help="Original datastore for the flow from which a task is being spun", ) @@ -257,9 +256,9 @@ def step( def spin_step( ctx, step_name, + orig_flow_datastore, run_id=None, task_id=None, - orig_flow_datastore=None, input_paths=None, split_index=None, retry_count=None, @@ -277,7 +276,7 @@ def spin_step( echo = echo_always if opt_namespace is not None: - namespace(opt_namespace or None) + namespace(opt_namespace) input_paths = decompress_list(input_paths) if input_paths else [] diff --git a/metaflow/datastore/flow_datastore.py b/metaflow/datastore/flow_datastore.py index 43cd6e8bc14..4e1a73657c5 100644 --- a/metaflow/datastore/flow_datastore.py +++ b/metaflow/datastore/flow_datastore.py @@ -1,5 +1,6 @@ import itertools import json +from abc import ABC, abstractmethod from .. import metaflow_config @@ -376,9 +377,11 @@ def load_data(self, keys, force_raw=False): yield key, blob -class MetadataCache(object): +class MetadataCache(ABC): + @abstractmethod def load_metadata(self, run_id, step_name, task_id, attempt): - pass + raise NotImplementedError() + @abstractmethod def store_metadata(self, run_id, step_name, task_id, attempt, metadata_dict): - pass + raise NotImplementedError() diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index 11cf3904f98..0d846e7c88c 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -6,6 +6,7 @@ from functools import wraps from io import BufferedIOBase, FileIO, RawIOBase +from typing import List from types import MethodType, FunctionType from .. import metaflow_config @@ -259,7 +260,7 @@ def init_task(self): @only_if_not_done @require_mode("w") - def transfer_artifacts(self, other_datastore, names=None): + def transfer_artifacts(self, other_datastore : "TaskDataStore", names : List[str] =None): """ Copies the blobs from other_datastore to this datastore if the datastore roots are different. @@ -271,7 +272,7 @@ def transfer_artifacts(self, other_datastore, names=None): ---------- other_datastore : TaskDataStore Other datastore from which to copy artifacts from - names : List[string], optional, default None + names : List[str], optional, default None If provided, only transfer the artifacts with these names. If None, transfer all artifacts from the other datastore. 
""" diff --git a/metaflow/runner/metaflow_runner.py b/metaflow/runner/metaflow_runner.py index 240f124e0b9..ba663edae68 100644 --- a/metaflow/runner/metaflow_runner.py +++ b/metaflow/runner/metaflow_runner.py @@ -6,7 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple -from metaflow import Run +from metaflow import Run, Task from metaflow.metaflow_config import CLICK_API_PROCESS_CONFIG @@ -50,7 +50,7 @@ class ExecutingProcess(object): def __init__(self, runner: "Runner", command_obj: CommandManager) -> None: """ Create a new ExecutingRun -- this should not be done by the user directly but - instead user Runner.run() + instead use Runner.run() Parameters ---------- @@ -64,7 +64,7 @@ def __init__(self, runner: "Runner", command_obj: CommandManager) -> None: self.runner = runner self.command_obj = command_obj - def __enter__(self) -> "ExecutingRun": + def __enter__(self) -> "ExecutingProcess": return self def __exit__(self, exc_type, exc_value, traceback): @@ -72,7 +72,7 @@ def __exit__(self, exc_type, exc_value, traceback): async def wait( self, timeout: Optional[float] = None, stream: Optional[str] = None - ) -> "ExecutingRun": + ) -> "ExecutingProcess": """ Wait for this run to finish, optionally with a timeout and optionally streaming its output. @@ -91,7 +91,7 @@ async def wait( Returns ------- - ExecutingRun + ExecutingProcess This object, allowing you to chain calls. """ await self.command_obj.wait(timeout, stream) @@ -215,11 +215,11 @@ class ExecutingTask(ExecutingProcess): """ def __init__( - self, runner: "Runner", command_obj: CommandManager, task_obj: "metaflow.Task" + self, runner: "Runner", command_obj: CommandManager, task_obj: Task ) -> None: """ Create a new ExecutingTask -- this should not be done by the user directly but - instead user Runner.spin() + instead use Runner.spin() Parameters ---------- runner : Runner @@ -254,7 +254,7 @@ def __init__( ) -> None: """ Create a new ExecutingRun -- this should not be done by the user directly but - instead user Runner.run() + instead use Runner.run() Parameters ---------- runner : Runner @@ -482,7 +482,6 @@ def __get_executing_task(self, attribute_file_fd, command_obj): # Set the correct metadata from the runner_attribute file corresponding to this run. metadata_for_flow = content.get("metadata") - from metaflow import Task task_object = Task( pathspec, _namespace_check=False, _current_metadata=metadata_for_flow @@ -499,14 +498,12 @@ async def __async_get_executing_task(self, attribute_file_fd, command_obj): # Set the correct metadata from the runner_attribute file corresponding to this run. metadata_for_flow = content.get("metadata") - from metaflow import Task - task_object = Task( pathspec, _namespace_check=False, _current_metadata=metadata_for_flow ) return ExecutingTask(self, command_obj, task_object) - def spin(self, pathspec, **kwargs): + def spin(self, pathspec, **kwargs) -> ExecutingTask: """ Blocking spin execution of the run. This method will wait until the spun run has completed execution. diff --git a/metaflow/task.py b/metaflow/task.py index 52722a1be58..e8160945149 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -633,7 +633,7 @@ def run_step( decorators = step_func.decorators if self.orig_flow_datastore: # We filter only the whitelisted decorators in case of spin step. 
- decorators = [ + decorators = [] if not whitelist_decorators else [ deco for deco in decorators if deco.name in whitelist_decorators ] from_start("MetaflowTask: decorators initialized") diff --git a/metaflow/util.py b/metaflow/util.py index 03b8ce916e6..db0388e64c8 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -606,9 +606,11 @@ def read_artifacts_module(file_path: str) -> Dict[str, Any]: If the file cannot be read or does not contain the ARTIFACTS variable. """ import importlib.util + import os try: - spec = importlib.util.spec_from_file_location("artifacts_module", file_path) + module_name = os.path.splitext(os.path.basename(file_path))[0] + spec = importlib.util.spec_from_file_location(module_name, file_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) variables = vars(module) From ad4262a5bf58f85b64d362f941d7d66305ee0b7c Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Wed, 1 Oct 2025 18:53:08 -0700 Subject: [PATCH 09/21] Update ancestor task logic --- metaflow/client/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index b6872cd3044..cef8183e1b0 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1255,7 +1255,8 @@ def parent_task_pathspecs(self) -> Iterator[str]: # Get the parent steps steps = self._get_previous_steps(graph_info, step_name) node_type = graph_info["steps"][step_name]["type"] - current_path = metadata_dict.get("foreach-execution-path") + metadata_key = "foreach-execution-path" + current_path = metadata_dict.get(metadata_key) if len(steps) > 1: # Static join - use exact path matching @@ -1290,7 +1291,6 @@ def parent_task_pathspecs(self) -> Iterator[str]: target_depth = current_depth - 1 pattern = ",".join(current_path.split(",")[:target_depth]) - metadata_key = "foreach-execution-path" for pathspec in self._get_matching_pathspecs(steps, metadata_key, pattern): yield pathspec @@ -1312,7 +1312,8 @@ def child_task_pathspecs(self) -> Iterator[str]: steps = graph_info["steps"][step_name]["next"] node_type = graph_info["steps"][step_name]["type"] - current_path = self.metadata_dict.get("foreach-execution-path") + metadata_key = "foreach-execution-path" + current_path = metadata_dict.get(metadata_key) if len(steps) > 1: # Static split - use exact path matching @@ -1354,7 +1355,6 @@ def child_task_pathspecs(self) -> Iterator[str]: target_depth = current_depth - 1 pattern = ",".join(current_path.split(",")[:target_depth]) - metadata_key = "foreach-execution-path" for pathspec in self._get_matching_pathspecs(steps, metadata_key, pattern): yield pathspec From 0536e0296c7b77979afb0ca60a5c53ffa9b29a31 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Fri, 10 Oct 2025 08:49:36 -0700 Subject: [PATCH 10/21] Parameter fix --- metaflow/runtime.py | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/metaflow/runtime.py b/metaflow/runtime.py index c1a577d37e2..116ec116c63 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -199,35 +199,6 @@ def __init__( self._max_log_size = max_log_size self._encoding = sys.stdout.encoding or "UTF-8" - # If no artifacts module is provided, create a temporary one with parameter values - if not self._artifacts_module and hasattr(flow, "_get_parameters"): - import tempfile - import os - - # Collect parameter values from the flow - param_artifacts = {} - for var, param in flow._get_parameters(): - if hasattr(flow, var): - value = getattr(flow, var) - # Only 
add if it's an actual value, not the Parameter object - if value is not None and not hasattr(value, "IS_PARAMETER"): - param_artifacts[var] = value - - # If we have parameter values, create a temp module - if param_artifacts: - with tempfile.NamedTemporaryFile( - mode="w", suffix=".py", delete=False - ) as f: - f.write( - "# Auto-generated artifacts module for spin step parameters\n" - ) - f.write("ARTIFACTS = {\n") - for key, value in param_artifacts.items(): - f.write(f" {repr(key)}: {repr(value)},\n") - f.write("}\n") - self._artifacts_module = f.name - self._temp_artifacts_file = f.name # Store for cleanup later - # Create a new run_id for the spin task self.run_id = self._metadata.new_run_id() for deco in self.whitelist_decorators: @@ -323,14 +294,6 @@ def execute(self): finally: for deco in self.whitelist_decorators: deco.runtime_finished(exception) - # Clean up temporary artifacts file if we created one - if hasattr(self, "_temp_artifacts_file"): - import os - - try: - os.unlink(self._temp_artifacts_file) - except: - pass def _launch_and_monitor_task(self): worker = Worker( From a46e1c2b58c5197d2045ee27cb2517c7c93b5890 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Thu, 16 Oct 2025 14:00:24 -0700 Subject: [PATCH 11/21] Address comments --- metaflow/cli.py | 14 ++++----- metaflow/cli_components/run_cmds.py | 19 +++++++++--- metaflow/cli_components/step_cmd.py | 4 +-- metaflow/client/filecache.py | 2 +- metaflow/datastore/content_addressed_store.py | 18 +++++------ metaflow/datastore/flow_datastore.py | 7 ++--- metaflow/datastore/task_datastore.py | 8 ++--- metaflow/metaflow_config.py | 14 +++++++++ metaflow/runtime.py | 8 +++++ test/unit/spin/spin_test_helpers.py | 2 +- test/unit/spin/test_spin.py | 31 ++++++++++++++++--- 11 files changed, 91 insertions(+), 36 deletions(-) diff --git a/metaflow/cli.py b/metaflow/cli.py index 63fb39c925a..338b49953f1 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -323,11 +323,11 @@ def version(obj): is_eager=True, ) @click.option( - "--spin-mode", - is_flag=True, - default=False, - help="Enable spin mode for metaflow cli commands. Setting this flag will result " - "in using spin metadata and spin datastore for executions" + "--mode", + type=click.Choice(["spin"]), + default=None, + help="Execution mode for metaflow CLI commands. 
Use 'spin' to enable " + "spin metadata and spin datastore for executions" ) @click.pass_context def start( @@ -346,7 +346,7 @@ def start( local_config_file=None, config=None, config_value=None, - spin_mode=False, + mode=None, **deco_options ): if quiet: @@ -379,7 +379,7 @@ def start( ctx.obj.check = functools.partial(_check, echo) ctx.obj.top_cli = cli ctx.obj.package_suffixes = package_suffixes.split(",") - ctx.obj.spin_mode = spin_mode + ctx.obj.spin_mode = (mode == "spin") ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == datastore][0] diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index af4e1f2f234..2540e0c3f16 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -8,7 +8,7 @@ from ..exception import CommandException from ..graph import FlowGraph from ..metaflow_current import current -from ..metaflow_config import DEFAULT_DECOSPECS, FEAT_ALWAYS_UPLOAD_CODE_PACKAGE +from ..metaflow_config import DEFAULT_DECOSPECS, FEAT_ALWAYS_UPLOAD_CODE_PACKAGE, SPIN_PERSIST from ..metaflow_profile import from_start from ..package import MetaflowPackage from ..runtime import NativeRuntime, SpinRuntime @@ -425,13 +425,15 @@ def run( runtime.execute() -@parameters.add_custom_parameters(deploy_mode=True) +# @parameters.add_custom_parameters(deploy_mode=True) @click.command(help="Spins up a task for a given step from a previous run locally.") @tracing.cli("cli/spin") @click.argument("pathspec") @click.option( "--skip-decorators/--no-skip-decorators", is_flag=True, + # Default False matches the saved_args check in cli.py for spin steps - skip_decorators + # only becomes True when explicitly passed, otherwise decorators are applied by default default=False, show_default=True, help="Skip decorators attached to the step or flow.", @@ -448,7 +450,7 @@ def run( @click.option( "--persist/--no-persist", "persist", - default=True, + default=SPIN_PERSIST, show_default=True, help="Whether to persist the artifacts in the spun step. If set to False, " "the artifacts will not be persisted and will not be available in the spun step's " @@ -480,7 +482,16 @@ def spin( before_run(obj, [], [], skip_decorators) obj.echo(f"Spinning up step *{step_name}* locally for flow *{obj.flow.name}*") - obj.flow._set_constants(obj.graph, kwargs, obj.config_options) + # For spin, flow parameters come from the original run, but _set_constants + # requires them in kwargs. Use parameter defaults as placeholders - they'll be + # overwritten when the spin step loads artifacts from the original run. 
+ flow_param_defaults = {} + for var, param in obj.flow._get_parameters(): + if not param.IS_CONFIG_PARAMETER: + default_value = param.kwargs.get("default") + # Use None for required parameters without defaults + flow_param_defaults[param.name.replace("-", "_").lower()] = default_value + obj.flow._set_constants(obj.graph, flow_param_defaults, obj.config_options) step_func = getattr(obj.flow, step_name, None) if step_func is None: raise CommandException( diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index 79cc78ad584..24ca9d784a0 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -188,13 +188,13 @@ def step( "--run-id", default=None, required=True, - help="Run ID for the step that's about to be spun", + help="Original run ID for the step that will be spun", ) @click.option( "--task-id", default=None, required=True, - help="Task ID for the step that's about to be spun", + help="Original Task ID for the step that will be spun", ) @click.option( "--orig-flow-datastore", diff --git a/metaflow/client/filecache.py b/metaflow/client/filecache.py index 980b5f34cf0..bf397e91b0e 100644 --- a/metaflow/client/filecache.py +++ b/metaflow/client/filecache.py @@ -406,10 +406,10 @@ def _get_task_datastore( class TaskMetadataCache(MetadataCache): def __init__(self, filecache, ds_type, ds_root, flow_name): + self._filecache = filecache self._ds_type = ds_type self._ds_root = ds_root self._flow_name = flow_name - self._filecache = filecache def _path(self, run_id, step_name, task_id, attempt): if attempt is None: diff --git a/metaflow/datastore/content_addressed_store.py b/metaflow/datastore/content_addressed_store.py index a8f2e0e4805..195f6ef3ca2 100644 --- a/metaflow/datastore/content_addressed_store.py +++ b/metaflow/datastore/content_addressed_store.py @@ -38,7 +38,7 @@ def __init__(self, prefix, storage_impl): def set_blob_cache(self, blob_cache): self._blob_cache = blob_cache - def save_blobs(self, blob_iter, raw=False, len_hint=0, _is_transfer=False): + def save_blobs(self, blob_iter, raw=False, len_hint=0, is_transfer=False): """ Saves blobs of data to the datastore @@ -65,7 +65,7 @@ def save_blobs(self, blob_iter, raw=False, len_hint=0, _is_transfer=False): Whether to save the bytes directly or process them, by default False len_hint : Hint of the number of blobs that will be produced by the iterator, by default 0 - _is_transfer : bool, default False + is_transfer : bool, default False If True, this indicates we are saving blobs directly from the output of another content addressed store's @@ -79,7 +79,7 @@ def save_blobs(self, blob_iter, raw=False, len_hint=0, _is_transfer=False): def packing_iter(): for blob in blob_iter: - if _is_transfer: + if is_transfer: key, blob_data, meta = blob path = self._storage_impl.path_join(self._prefix, key[:2], key) # Transfer data is always raw/decompressed, so mark it as such @@ -117,7 +117,7 @@ def packing_iter(): self._storage_impl.save_bytes(packing_iter(), overwrite=True, len_hint=len_hint) return results - def load_blobs(self, keys, force_raw=False, _is_transfer=False): + def load_blobs(self, keys, force_raw=False, is_transfer=False): """ Mirror function of save_blobs @@ -132,15 +132,15 @@ def load_blobs(self, keys, force_raw=False, _is_transfer=False): Support for backward compatibility with previous datastores. If True, this will force the key to be loaded as is (raw). 
By default, False - _is_transfer : bool, default False + is_transfer : bool, default False If True, this indicates we are loading blobs to transfer them directly - to another datastore. We will, in this case, also transfer the metdata + to another datastore. We will, in this case, also transfer the metadata and do minimal processing. This is for internal use only. Returns ------- Returns an iterator of (string, bytes) tuples; the iterator may return keys - in a different order than were passed in. If _is_transfer is True, the tuple + in a different order than were passed in. If is_transfer is True, the tuple has three elements with the third one being the metadata. """ load_paths = [] @@ -149,7 +149,7 @@ def load_blobs(self, keys, force_raw=False, _is_transfer=False): if self._blob_cache: blob = self._blob_cache.load_key(key) if blob is not None: - if _is_transfer: + if is_transfer: # Cached blobs are decompressed/processed bytes regardless of original format yield key, blob, {"cas_raw": False, "cas_version": 1} else: @@ -195,7 +195,7 @@ def load_blobs(self, keys, force_raw=False, _is_transfer=False): if self._blob_cache: self._blob_cache.store_key(key, blob) - if _is_transfer: + if is_transfer: yield key, blob, meta # Preserve exact original metadata from storage else: yield key, blob diff --git a/metaflow/datastore/flow_datastore.py b/metaflow/datastore/flow_datastore.py index 4e1a73657c5..1352812a253 100644 --- a/metaflow/datastore/flow_datastore.py +++ b/metaflow/datastore/flow_datastore.py @@ -118,14 +118,14 @@ def get_task_datastores( If True, returns all attempts up to and including attempt. mode : str, default "r" Mode to initialize the returned TaskDataStores in. - join_type : str, optional + join_type : str, optional, default None If specified, the join type for the task. This is used to determine the user specified artifacts for the task in case of a spin task. - orig_flow_datastore : MetadataProvider, optional + orig_flow_datastore : MetadataProvider, optional, default None The metadata provider in case of a spin task. If provided, the returned TaskDataStore will be a SpinTaskDatastore instead of a TaskDataStore. - spin_artifacts : Dict[str, Any], optional + spin_artifacts : Dict[str, Any], optional, default None Artifacts provided by user that can override the artifacts fetched via the spin pathspec. @@ -277,7 +277,6 @@ def get_task_datastore( data_metadata=data_metadata, mode=mode, allow_not_done=allow_not_done, - join_type=join_type, persist=persist, ) diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index 0d846e7c88c..0bee452079e 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -6,7 +6,7 @@ from functools import wraps from io import BufferedIOBase, FileIO, RawIOBase -from typing import List +from typing import List, Optional from types import MethodType, FunctionType from .. import metaflow_config @@ -260,7 +260,7 @@ def init_task(self): @only_if_not_done @require_mode("w") - def transfer_artifacts(self, other_datastore : "TaskDataStore", names : List[str] =None): + def transfer_artifacts(self, other_datastore : "TaskDataStore", names : Optional[List[str]] = None): """ Copies the blobs from other_datastore to this datastore if the datastore roots are different. 
@@ -316,11 +316,11 @@ def transfer_artifacts(self, other_datastore : "TaskDataStore", names : List[str]
 
         # Load blobs from other datastore in transfer mode
         transfer_blobs = other_datastore._ca_store.load_blobs(
-            missing_shas, _is_transfer=True
+            missing_shas, is_transfer=True
         )
 
         # Save blobs to local datastore in transfer mode
-        self._ca_store.save_blobs(transfer_blobs, _is_transfer=True)
+        self._ca_store.save_blobs(transfer_blobs, is_transfer=True)
 
     @only_if_not_done
     @require_mode("w")
diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py
index 21aa0b40acd..38b8c20e7b7 100644
--- a/metaflow/metaflow_config.py
+++ b/metaflow/metaflow_config.py
@@ -51,6 +51,7 @@
 ###
 # Spin configuration
 ###
+# Essentially a whitelist of decorators that are allowed in Spin steps
 SPIN_ALLOWED_DECORATORS = from_conf(
     "SPIN_ALLOWED_DECORATORS",
     [
@@ -66,6 +67,19 @@
     ],
 )
 
+# Essentially a blacklist of decorators that are not allowed in Spin steps
+# Note: decorators not in either SPIN_ALLOWED_DECORATORS or SPIN_DISALLOWED_DECORATORS
+# are simply ignored in Spin steps
+SPIN_DISALLOWED_DECORATORS = from_conf(
+    "SPIN_DISALLOWED_DECORATORS",
+    [
+        "parallel",
+    ],
+)
+
+# Default value for persist option in spin command
+SPIN_PERSIST = from_conf("SPIN_PERSIST", False)
+
 ###
 # User configuration
 ###
diff --git a/metaflow/runtime.py b/metaflow/runtime.py
index 116ec116c63..3e30bbd694b 100644
--- a/metaflow/runtime.py
+++ b/metaflow/runtime.py
@@ -33,6 +33,7 @@
     MAX_ATTEMPTS,
     UI_URL,
     SPIN_ALLOWED_DECORATORS,
+    SPIN_DISALLOWED_DECORATORS,
 )
 from .metaflow_profile import from_start
 from .plugins import DATASTORES
@@ -201,6 +202,13 @@ def __init__(
         # Create a new run_id for the spin task
         self.run_id = self._metadata.new_run_id()
 
+        # Raise exception if we have a blacklisted decorator
+        for deco in self._step_func.decorators:
+            if deco.name in SPIN_DISALLOWED_DECORATORS:
+                raise MetaflowException(
+                    f"Spinning steps with @{deco.name} decorator is not supported."
+ ) + for deco in self.whitelist_decorators: deco.runtime_init(flow, graph, package, self.run_id) from_start("SpinRuntime: after init decorators") diff --git a/test/unit/spin/spin_test_helpers.py b/test/unit/spin/spin_test_helpers.py index b2c61e37457..4952aad364f 100644 --- a/test/unit/spin/spin_test_helpers.py +++ b/test/unit/spin/spin_test_helpers.py @@ -26,7 +26,7 @@ def run_step(flow_file, run, step_name, **tl_kwargs): flow_path = os.path.join(FLOWS_DIR, flow_file) print(f"FLOWS_DIR: {FLOWS_DIR}") - with Runner(flow_path, cwd=FLOWS_DIR, **tl_kwargs).spin(task.pathspec) as spin: + with Runner(flow_path, cwd=FLOWS_DIR, **tl_kwargs).spin(task.pathspec, persist=True) as spin: print("-" * 50) print(f"Running test for step: {step_name} with task pathspec: {task.pathspec}") assert_artifacts(task, spin.task) diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py index 04e9a852c1e..82187cd6b6e 100644 --- a/test/unit/spin/test_spin.py +++ b/test/unit/spin/test_spin.py @@ -36,6 +36,7 @@ def test_artifacts_module(complex_dag_run): with Runner(flow_path, environment="conda").spin( task.pathspec, artifacts_module=artifacts_path, + persist=True, ) as spin: print("-" * 50) print(f"Running test for step: step_a with task pathspec: {task.pathspec}") @@ -59,6 +60,7 @@ def test_artifacts_module_join_step( with Runner(flow_path, environment="conda").spin( task.pathspec, artifacts_module=str(temp_artifacts_file), + persist=True, ) as spin: print("-" * 50) print(f"Running test for step: step_d with task pathspec: {task.pathspec}") @@ -78,6 +80,7 @@ def test_timeout_decorator_enforcement(simple_config_run): flow_path, cwd=FLOWS_DIR, config_value=[("config", {"timeout": 2})] ).spin( task.pathspec, + persist=True, ): pass @@ -94,6 +97,7 @@ def test_skip_decorators_bypass(simple_config_run): ).spin( task.pathspec, skip_decorators=True, + persist=True, ) as spin: print(f"Running test for step: {step_name} with skip_decorators=True") # Should complete successfully even though sleep(5) > timeout(2) @@ -108,7 +112,7 @@ def test_hidden_artifacts(simple_parameter_run): flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") print(f"Running test for hidden artifacts in {flow_path}: {simple_parameter_run}") - with Runner(flow_path, cwd=FLOWS_DIR).spin(task.pathspec) as spin: + with Runner(flow_path, cwd=FLOWS_DIR).spin(task.pathspec, persist=True) as spin: spin_task = spin.task assert "_graph_info" in spin_task assert "_foreach_stack" in spin_task @@ -121,13 +125,31 @@ def test_card_flow(simple_card_run): flow_path = os.path.join(FLOWS_DIR, "simple_card_flow.py") print(f"Running test for cards in {flow_path}: {simple_card_run}") - with Runner(flow_path, cwd=FLOWS_DIR).spin(task.pathspec) as spin: + with Runner(flow_path, cwd=FLOWS_DIR).spin(task.pathspec, persist=True) as spin: spin_task = spin.task from metaflow.cards import get_cards res = get_cards(spin_task, follow_resumed=False) print(res) + +def test_spin_with_parameters_raises_error(simple_parameter_run): + """Test that passing flow parameters to spin raises an error.""" + step_name = "start" + task = simple_parameter_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") + + with pytest.raises(Exception, match="Unknown argument"): + with Runner(flow_path, cwd=FLOWS_DIR).spin( + task.pathspec, + alpha=1.0, + persist=True, + ): + pass + + +# NOTE: This test has to be the last test because it modifies the metadata +# provider when calling inspect_spin def test_inspect_spin_client_access(simple_parameter_run): 
"""Test accessing spin artifacts using inspect_spin client directly.""" from metaflow import inspect_spin, Task @@ -137,10 +159,11 @@ def test_inspect_spin_client_access(simple_parameter_run): task = simple_parameter_run[step_name].task flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") - with tempfile.TemporaryDirectory() as tmpdir: + with tempfile.TemporaryDirectory() as _: # Run spin to generate artifacts with Runner(flow_path, cwd=FLOWS_DIR).spin( task.pathspec, + persist=True, ) as spin: spin_task = spin.task spin_pathspec = spin_task.pathspec @@ -162,4 +185,4 @@ def test_inspect_spin_client_access(simple_parameter_run): # Verify artifact data assert client_task.artifacts.a.data == 10 assert client_task.artifacts.b.data == 20 - assert client_task.artifacts.alpha.data == 0.05 \ No newline at end of file + assert client_task.artifacts.alpha.data == 0.05 From 387b6ba53313019ea6016f10b4251eedf6c9c288 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Thu, 16 Oct 2025 17:05:32 -0700 Subject: [PATCH 12/21] Raise an exception if attempt is None in TaskMetadataCache --- metaflow/client/filecache.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/metaflow/client/filecache.py b/metaflow/client/filecache.py index bf397e91b0e..35d96c5e72c 100644 --- a/metaflow/client/filecache.py +++ b/metaflow/client/filecache.py @@ -413,7 +413,10 @@ def __init__(self, filecache, ds_type, ds_root, flow_name): def _path(self, run_id, step_name, task_id, attempt): if attempt is None: - return None + raise MetaflowException( + "Attempt number must be specified to use task metadata cache. Raise an issue " + "on Metaflow GitHub if you see this message.", + ) cache_id = self._filecache.task_ds_id( self._ds_type, self._ds_root, From be458a04a926686073469ce2bbeae743cc54de01 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Mon, 20 Oct 2025 13:44:01 -0700 Subject: [PATCH 13/21] Use default user namespace for spin pathspecs --- metaflow/runtime.py | 1 + metaflow/util.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 3e30bbd694b..3dfb01f529d 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -136,6 +136,7 @@ def __init__( if len(parts) == 4: # Complete pathspec: flow/run/step/task_id try: + # If user provides whole pathspec, we do not need to check namespace task = Task(spin_pathspec, _namespace_check=False) except Exception: raise MetaflowException( diff --git a/metaflow/util.py b/metaflow/util.py index db0388e64c8..3e6fcbb8db7 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -278,14 +278,14 @@ def get_latest_task_pathspec(flow_name: str, step_name: str, run_id: str = None) from metaflow.exception import MetaflowNotFound if not run_id: - flow = Flow(flow_name, _namespace_check=False) + flow = Flow(flow_name) run = flow.latest_run if run is None: raise MetaflowNotFound(f"No run found for flow {flow_name}") run_id = run.id try: - task = Step(f"{flow_name}/{run_id}/{step_name}", _namespace_check=False).task + task = Step(f"{flow_name}/{run_id}/{step_name}").task return task except: raise MetaflowNotFound(f"No task found for step {step_name} in run {run_id}") From 72839a622f5f45f12137d51152d39f6377e3eb43 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 21 Oct 2025 11:05:54 -0700 Subject: [PATCH 14/21] Fix for runtimedag parallel tests --- metaflow/client/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py 
index cef8183e1b0..9537c3bcd7e 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1281,13 +1281,13 @@ def parent_task_pathspecs(self) -> Iterator[str]: # Pattern: "A:10,B:13" parent_step_type = graph_info["steps"][steps[0]]["type"] target_depth = current_depth - if parent_step_type == "split-foreach" and current_depth == 1: + if (parent_step_type == "split-foreach" or parent_step_type == "split-parallel") and current_depth == 1: # (Current task, "A:10") and (Parent task, "") pattern = ".*" else: # (Current task, "A:10,B:13,C:21") and (Parent task, "A:10,B:13") # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13") - if parent_step_type == "split-foreach": + if parent_step_type == "split-foreach" or parent_step_type == "split-parallel": target_depth = current_depth - 1 pattern = ",".join(current_path.split(",")[:target_depth]) @@ -1328,7 +1328,7 @@ def child_task_pathspecs(self) -> Iterator[str]: pattern = ".*" else: current_depth = len(current_path.split(",")) - if node_type == "split-foreach": + if node_type == "split-foreach" or node_type == "split-parallel": # Foreach split # (Current task, "A:10,B:13") and (Child task, "A:10,B:13,C:21") # Pattern: "A:10,B:13,.*" From 84f2eb38719d137971f64d3768a736a649d3a008 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 21 Oct 2025 13:02:57 -0700 Subject: [PATCH 15/21] Update typing imports --- metaflow/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/util.py b/metaflow/util.py index 3e6fcbb8db7..a0abc9409fd 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -9,7 +9,7 @@ from functools import wraps from io import BytesIO from itertools import takewhile -from typing import Dict, Any, Tuple, Optional, List +from typing import Dict, Any, Tuple, Optional, List, Generator try: From eb543ccc9e29f588b8d8e57db76cd0cdad287437 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 21 Oct 2025 13:10:01 -0700 Subject: [PATCH 16/21] Run black formatting --- metaflow/cli.py | 11 ++++++++--- metaflow/cli_components/run_cmds.py | 6 +++++- metaflow/client/core.py | 10 ++++++++-- metaflow/datastore/task_datastore.py | 4 +++- metaflow/plugins/cards/card_cli.py | 2 ++ metaflow/plugins/cards/card_modules/basic.py | 4 +++- metaflow/task.py | 8 +++++--- metaflow/util.py | 4 +++- test/unit/spin/conftest.py | 3 ++- test/unit/spin/flows/hello_spin_flow.py | 4 ++-- test/unit/spin/spin_test_helpers.py | 4 +++- test/unit/spin/test_spin.py | 6 +++--- 12 files changed, 47 insertions(+), 19 deletions(-) diff --git a/metaflow/cli.py b/metaflow/cli.py index 338b49953f1..a4a1558304d 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -327,7 +327,7 @@ def version(obj): type=click.Choice(["spin"]), default=None, help="Execution mode for metaflow CLI commands. 
Use 'spin' to enable " - "spin metadata and spin datastore for executions" + "spin metadata and spin datastore for executions", ) @click.pass_context def start( @@ -379,7 +379,7 @@ def start( ctx.obj.check = functools.partial(_check, echo) ctx.obj.top_cli = cli ctx.obj.package_suffixes = package_suffixes.split(",") - ctx.obj.spin_mode = (mode == "spin") + ctx.obj.spin_mode = mode == "spin" ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == datastore][0] @@ -509,7 +509,12 @@ def start( ctx.obj.skip_decorators = False # Override values for spin steps, or if we are in spin mode - if hasattr(ctx, "saved_args") and ctx.saved_args and "spin" in ctx.saved_args[0] or ctx.obj.spin_mode: + if ( + hasattr(ctx, "saved_args") + and ctx.saved_args + and "spin" in ctx.saved_args[0] + or ctx.obj.spin_mode + ): # To minimize side effects for spin, we will only use the following: # - local metadata provider, # - local datastore, diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index 2540e0c3f16..578e6b666a8 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -8,7 +8,11 @@ from ..exception import CommandException from ..graph import FlowGraph from ..metaflow_current import current -from ..metaflow_config import DEFAULT_DECOSPECS, FEAT_ALWAYS_UPLOAD_CODE_PACKAGE, SPIN_PERSIST +from ..metaflow_config import ( + DEFAULT_DECOSPECS, + FEAT_ALWAYS_UPLOAD_CODE_PACKAGE, + SPIN_PERSIST, +) from ..metaflow_profile import from_start from ..package import MetaflowPackage from ..runtime import NativeRuntime, SpinRuntime diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 9537c3bcd7e..70bcaf4c08a 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1281,13 +1281,19 @@ def parent_task_pathspecs(self) -> Iterator[str]: # Pattern: "A:10,B:13" parent_step_type = graph_info["steps"][steps[0]]["type"] target_depth = current_depth - if (parent_step_type == "split-foreach" or parent_step_type == "split-parallel") and current_depth == 1: + if ( + parent_step_type == "split-foreach" + or parent_step_type == "split-parallel" + ) and current_depth == 1: # (Current task, "A:10") and (Parent task, "") pattern = ".*" else: # (Current task, "A:10,B:13,C:21") and (Parent task, "A:10,B:13") # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13") - if parent_step_type == "split-foreach" or parent_step_type == "split-parallel": + if ( + parent_step_type == "split-foreach" + or parent_step_type == "split-parallel" + ): target_depth = current_depth - 1 pattern = ",".join(current_path.split(",")[:target_depth]) diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index 0bee452079e..abeeb8ea5fb 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -260,7 +260,9 @@ def init_task(self): @only_if_not_done @require_mode("w") - def transfer_artifacts(self, other_datastore : "TaskDataStore", names : Optional[List[str]] = None): + def transfer_artifacts( + self, other_datastore: "TaskDataStore", names: Optional[List[str]] = None + ): """ Copies the blobs from other_datastore to this datastore if the datastore roots are different. 
diff --git a/metaflow/plugins/cards/card_cli.py b/metaflow/plugins/cards/card_cli.py index 766f23a70d8..9cb8b4bbb9d 100644 --- a/metaflow/plugins/cards/card_cli.py +++ b/metaflow/plugins/cards/card_cli.py @@ -335,6 +335,7 @@ def list_many_cards( def cli(): pass + @cli.group(help="Commands related to @card decorator.") @click.pass_context def card(ctx): @@ -342,6 +343,7 @@ def card(ctx): # Can work with the Metaflow client. # If we don't set the metadata here than the metaflow client picks the defaults when calling the `Task`/`Run` objects. These defaults can come from the `config.json` file or based on the `METAFLOW_PROFILE` from metaflow import metadata + setting_metadata = "@".join( [ctx.obj.metadata.TYPE, ctx.obj.metadata.default_info()] ) diff --git a/metaflow/plugins/cards/card_modules/basic.py b/metaflow/plugins/cards/card_modules/basic.py index 1081a2b916b..f578a9e03fc 100644 --- a/metaflow/plugins/cards/card_modules/basic.py +++ b/metaflow/plugins/cards/card_modules/basic.py @@ -501,7 +501,9 @@ def render(self): param_ids = [] else: param_ids = [ - p.id for p in self._task.parent.parent["_parameters"].task if p.id != "name" + p.id + for p in self._task.parent.parent["_parameters"].task + if p.id != "name" ] if len(param_ids) > 0: # Extract parameter from the Parameter Task. That is less brittle. diff --git a/metaflow/task.py b/metaflow/task.py index e8160945149..71c6e54aa6d 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -633,9 +633,11 @@ def run_step( decorators = step_func.decorators if self.orig_flow_datastore: # We filter only the whitelisted decorators in case of spin step. - decorators = [] if not whitelist_decorators else [ - deco for deco in decorators if deco.name in whitelist_decorators - ] + decorators = ( + [] + if not whitelist_decorators + else [deco for deco in decorators if deco.name in whitelist_decorators] + ) from_start("MetaflowTask: decorators initialized") node = self.flow._graph[step_name] join_type = None diff --git a/metaflow/util.py b/metaflow/util.py index a0abc9409fd..82e29ef23f1 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -250,7 +250,9 @@ def parse_spin_pathspec(pathspec: str, flow_name: str) -> Tuple: return step_name, parsed_pathspec -def get_latest_task_pathspec(flow_name: str, step_name: str, run_id: str = None) -> "metaflow.Task": +def get_latest_task_pathspec( + flow_name: str, step_name: str, run_id: str = None +) -> "metaflow.Task": """ Returns a task pathspec from the latest run (or specified run) of the flow for the queried step. If the queried step has several tasks, the task pathspec of the first task is returned. 
diff --git a/test/unit/spin/conftest.py b/test/unit/spin/conftest.py index a4f90937543..2c87de284b1 100644 --- a/test/unit/spin/conftest.py +++ b/test/unit/spin/conftest.py @@ -69,7 +69,8 @@ def flow_fixture(request): simple_card_run = pytest.fixture(scope="session")( create_flow_fixture( - "SimpleCardFlow", "simple_card_flow.py", + "SimpleCardFlow", + "simple_card_flow.py", ) ) diff --git a/test/unit/spin/flows/hello_spin_flow.py b/test/unit/spin/flows/hello_spin_flow.py index 2df4a6aeee6..d6c537b75b0 100644 --- a/test/unit/spin/flows/hello_spin_flow.py +++ b/test/unit/spin/flows/hello_spin_flow.py @@ -7,7 +7,7 @@ class HelloSpinFlow(FlowSpec): @step def start(self): chunk_size = 1024 * 1024 # 1 MB - total_size = 1024 * 1024 * 1000 # 1000 MB + total_size = 1024 * 1024 * 1000 # 1000 MB data = bytearray() for _ in range(total_size // chunk_size): @@ -23,4 +23,4 @@ def end(self): if __name__ == "__main__": - HelloSpinFlow() \ No newline at end of file + HelloSpinFlow() diff --git a/test/unit/spin/spin_test_helpers.py b/test/unit/spin/spin_test_helpers.py index 4952aad364f..afc9942edec 100644 --- a/test/unit/spin/spin_test_helpers.py +++ b/test/unit/spin/spin_test_helpers.py @@ -26,7 +26,9 @@ def run_step(flow_file, run, step_name, **tl_kwargs): flow_path = os.path.join(FLOWS_DIR, flow_file) print(f"FLOWS_DIR: {FLOWS_DIR}") - with Runner(flow_path, cwd=FLOWS_DIR, **tl_kwargs).spin(task.pathspec, persist=True) as spin: + with Runner(flow_path, cwd=FLOWS_DIR, **tl_kwargs).spin( + task.pathspec, persist=True + ) as spin: print("-" * 50) print(f"Running test for step: {step_name} with task pathspec: {task.pathspec}") assert_artifacts(task, spin.task) diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py index 82187cd6b6e..bea53813162 100644 --- a/test/unit/spin/test_spin.py +++ b/test/unit/spin/test_spin.py @@ -167,8 +167,8 @@ def test_inspect_spin_client_access(simple_parameter_run): ) as spin: spin_task = spin.task spin_pathspec = spin_task.pathspec - assert spin_task['a'] is not None - assert spin_task['b'] is not None + assert spin_task["a"] is not None + assert spin_task["b"] is not None assert spin_pathspec is not None @@ -180,7 +180,7 @@ def test_inspect_spin_client_access(simple_parameter_run): assert client_task is not None # Verify artifacts - assert hasattr(client_task, 'artifacts') + assert hasattr(client_task, "artifacts") # Verify artifact data assert client_task.artifacts.a.data == 10 From afea14ffdf4a5389015b664dd6ee0a041e4173dc Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 21 Oct 2025 13:57:23 -0700 Subject: [PATCH 17/21] Update docstrings --- metaflow/datastore/content_addressed_store.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/metaflow/datastore/content_addressed_store.py b/metaflow/datastore/content_addressed_store.py index 195f6ef3ca2..9b24a482a89 100644 --- a/metaflow/datastore/content_addressed_store.py +++ b/metaflow/datastore/content_addressed_store.py @@ -60,10 +60,12 @@ def save_blobs(self, blob_iter, raw=False, len_hint=0, is_transfer=False): Parameters ---------- - blob_iter : Iterator over bytes objects to save - raw : bool, optional + blob_iter : Iterator + Iterator over bytes objects to save + raw : bool, default False Whether to save the bytes directly or process them, by default False - len_hint : Hint of the number of blobs that will be produced by the + len_hint : int, default 0 + Hint of the number of blobs that will be produced by the iterator, by default 0 is_transfer : bool, default False If 
True, this indicates we are saving blobs directly from the output of another From e15c14d6550eacff4f5cc68625a2dfc248609ae9 Mon Sep 17 00:00:00 2001 From: Romain Cledat Date: Tue, 21 Oct 2025 23:39:30 -0700 Subject: [PATCH 18/21] Fix typo Also drive-by improvement of some comments --- metaflow/cli_components/run_cmds.py | 2 +- metaflow/decorators.py | 4 ++-- metaflow/flowspec.py | 5 +++-- metaflow/user_decorators/user_step_decorator.py | 11 ++++++++++- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index 578e6b666a8..82272b70568 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -56,7 +56,7 @@ def before_run(obj, tags, decospecs, skip_decorators=False): decorators._attach_decorators(obj.flow, all_decospecs) decorators._init(obj.flow) # Regenerate graph if we attached more decorators - obj.flow.__class__._init_attrs() + obj.flow.__class__._init_graph() obj.graph = obj.flow._graph obj.check(obj.graph, obj.flow, obj.environment, pylint=obj.pylint) diff --git a/metaflow/decorators.py b/metaflow/decorators.py index d6f91cd3066..760508497f0 100644 --- a/metaflow/decorators.py +++ b/metaflow/decorators.py @@ -806,7 +806,7 @@ def _init_step_decorators( "expected %s but got %s" % (deco._flow_cls.__name__, cls.__name__) ) debug.userconf_exec( - "Evaluating flow level decorator %s (post)" % deco.__class__.__name__ + "Evaluating flow level decorator %s (mutate)" % deco.__class__.__name__ ) deco.mutate(mutable_flow) # We reset cached_parameters on the very off chance that the user added @@ -824,7 +824,7 @@ def _init_step_decorators( if isinstance(deco, StepMutator): debug.userconf_exec( - "Evaluating step level decorator %s (post) for %s" + "Evaluating step level decorator %s for %s (mutate)" % (deco.__class__.__name__, step.name) ) deco.mutate( diff --git a/metaflow/flowspec.py b/metaflow/flowspec.py index 3a34a5f680e..a8df867e644 100644 --- a/metaflow/flowspec.py +++ b/metaflow/flowspec.py @@ -307,7 +307,8 @@ def _process_config_decorators(cls, config_options, process_configs=True): % (deco._flow_cls.__name__, cls.__name__) ) debug.userconf_exec( - "Evaluating flow level decorator %s" % deco.__class__.__name__ + "Evaluating flow level decorator %s (pre-mutate)" + % deco.__class__.__name__ ) deco.pre_mutate(mutable_flow) # We reset cached_parameters on the very off chance that the user added @@ -324,7 +325,7 @@ def _process_config_decorators(cls, config_options, process_configs=True): if isinstance(deco, StepMutator): inserted_by_value = [deco.decorator_name] + (deco.inserted_by or []) debug.userconf_exec( - "Evaluating step level decorator %s for %s" + "Evaluating step level decorator %s for %s (pre-mutate)" % (deco.__class__.__name__, step.name) ) deco.pre_mutate( diff --git a/metaflow/user_decorators/user_step_decorator.py b/metaflow/user_decorators/user_step_decorator.py index 0017961009d..bf7318c6589 100644 --- a/metaflow/user_decorators/user_step_decorator.py +++ b/metaflow/user_decorators/user_step_decorator.py @@ -542,7 +542,7 @@ def user_step_decorator(*args, **kwargs): ``` @user_step_decorator - def timing(step_name, flow, inputs): + def timing(step_name, flow, inputs, attributes): start_time = time.time() yield end_time = time.time() @@ -559,6 +559,15 @@ def start(self): ``` Your generator should: + - take 3 or 4 arguments: step_name, flow, inputs, and attributes (optional) + - step_name: the name of the step + - flow: the flow object + - inputs: the inputs to 
the step + - attributes: the kwargs passed in when initializing the decorator. In the + example above, something like `@timing(arg1="foo", arg2=42)` would make + `attributes = {"arg1": "foo", "arg2": 42}`. If you choose to pass arguments + to the decorator when you apply it to the step, your function *must* take + 4 arguments (step_name, flow, inputs, attributes). - yield at most once -- if you do not yield, the step will not execute. - yield: - None From f8c999703002ce106838fc3cebba942a62dba9c2 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Wed, 22 Oct 2025 07:05:32 -0700 Subject: [PATCH 19/21] Retrigger tests --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index cd355c268c6..9349613e8cb 100644 --- a/README.md +++ b/README.md @@ -60,5 +60,3 @@ We'd love to hear from you. Join our community [Slack workspace](http://slack.ou ## Contributing We welcome contributions to Metaflow. Please see our [contribution guide](https://docs.metaflow.org/introduction/contributing-to-metaflow) for more details. - - From 7f2abe34663b06a453bd0ea8d77aa76e87cb140a Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Wed, 22 Oct 2025 07:12:47 -0700 Subject: [PATCH 20/21] Set default root for inspect spin method to be cwd --- metaflow/client/core.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 70bcaf4c08a..686cc526738 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -207,10 +207,15 @@ def default_namespace() -> str: return get_namespace() -def inspect_spin(datastore_root): +def inspect_spin(datastore_root : str = "."): """ Set metadata provider to spin metadata so that users can inspect spin steps, tasks, and artifacts. + + Parameters + ---------- + datastore_root : str, default "." + The root path to the spin datastore. """ metadata_str = f"spin@{datastore_root}" metadata(metadata_str) From 3f492dc9b038996bfd5bab5bf6f5b2fd9afb8882 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Wed, 22 Oct 2025 07:49:52 -0700 Subject: [PATCH 21/21] Apply black formatting --- metaflow/client/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 686cc526738..e6852b71901 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -207,7 +207,7 @@ def default_namespace() -> str: return get_namespace() -def inspect_spin(datastore_root : str = "."): +def inspect_spin(datastore_root: str = "."): """ Set metadata provider to spin metadata so that users can inspect spin steps, tasks, and artifacts.
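To round off the series, a hedged usage sketch of the `inspect_spin` helper touched by the last two patches; the spin task pathspec below is hypothetical and would normally come from `spin.task.pathspec` after a `Runner(...).spin(...)` call, as in the tests earlier in this series:

```python
from metaflow import Task
from metaflow.client.core import inspect_spin

# Point the Metaflow client at the spin metadata/datastore; with the new
# default, "." (the current working directory) is assumed to be the root.
inspect_spin()

# Hypothetical spin task pathspec -- in practice taken from spin.task.pathspec.
task = Task("HelloSpinFlow/spin-1234/start/1", _namespace_check=False)
print(task.artifacts)
```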