Skip to content

Commit 597b9ce

Browse files
committed
Cleanup
1 parent d745345 commit 597b9ce

File tree

16 files changed

+209
-191
lines changed

16 files changed

+209
-191
lines changed

src/zenml/config/compiler.py

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,6 @@
4242
from zenml.exceptions import StackValidationError
4343
from zenml.models import PipelineSnapshotBase
4444
from zenml.pipelines.run_utils import get_default_run_name
45-
from zenml.steps.step_invocation import StepInvocation
4645
from zenml.utils import pydantic_utils, secret_utils, settings_utils
4746

4847
if TYPE_CHECKING:
@@ -127,12 +126,18 @@ def compile(
127126
merge=False,
128127
)
129128

129+
# If we're compiling a dynamic pipeline, the steps are only templates
130+
# and might not have all inputs defined, so we skip the input
131+
# validation.
132+
skip_input_validation = pipeline.is_dynamic
133+
130134
steps = {
131135
invocation_id: self._compile_step_invocation(
132136
invocation=invocation,
133137
stack=stack,
134138
step_config=(run_configuration.steps or {}).get(invocation_id),
135139
pipeline_configuration=pipeline.configuration,
140+
skip_input_validation=skip_input_validation,
136141
)
137142
for invocation_id, invocation in self._get_sorted_invocations(
138143
pipeline=pipeline
@@ -465,6 +470,7 @@ def _compile_step_invocation(
465470
stack: "Stack",
466471
step_config: Optional["StepConfigurationUpdate"],
467472
pipeline_configuration: "PipelineConfiguration",
473+
skip_input_validation: bool = False,
468474
) -> Step:
469475
"""Compiles a ZenML step.
470476
@@ -473,6 +479,7 @@ def _compile_step_invocation(
473479
stack: The stack on which the pipeline will be run.
474480
step_config: Run configuration for the step.
475481
pipeline_configuration: Configuration for the pipeline.
482+
skip_input_validation: If True, will skip the input validation.
476483
477484
Returns:
478485
The compiled step.
@@ -487,6 +494,10 @@ def _compile_step_invocation(
487494
step_config, runtime_parameters=invocation.parameters
488495
)
489496

497+
# Apply the dynamic configuration (which happened while executing the
498+
# pipeline function) after all other step-specific configurations.
499+
step._apply_dynamic_configuration()
500+
490501
convert_component_shortcut_settings_keys(
491502
step.configuration.settings, stack=stack
492503
)
@@ -509,7 +520,8 @@ def _compile_step_invocation(
509520
set(step_config.parameters or {}) if step_config else set()
510521
)
511522
step_configuration_overrides = invocation.finalize(
512-
parameters_to_ignore=parameters_to_ignore
523+
parameters_to_ignore=parameters_to_ignore,
524+
skip_input_validation=skip_input_validation,
513525
)
514526
full_step_config = (
515527
step_configuration_overrides.apply_pipeline_configuration(
@@ -535,9 +547,13 @@ def _get_sorted_invocations(
535547
pipeline: The pipeline of which to sort the invocations
536548
537549
Returns:
538-
The sorted steps.
550+
The sorted step invocations.
539551
"""
540552
if pipeline.is_dynamic:
553+
# In dynamic pipelines, we require the static invocations to be
554+
# sorted the same way they were passed in `pipeline.depends_on`, as
555+
# we index this list later to figure out the correct template for
556+
# each step invocation.
541557
return list(pipeline.invocations.items())
542558

543559
from zenml.orchestrators.dag_runner import reverse_dag

src/zenml/config/step_configurations.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -261,6 +261,7 @@ class PartialStepConfiguration(StepConfigurationUpdate):
261261
"""Class representing a partial step configuration."""
262262

263263
name: str
264+
# TODO: maybe move to spec?
264265
template: Optional[str] = None
265266
parameters: Dict[str, Any] = {}
266267
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}

src/zenml/models/v2/core/pipeline_run.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -178,9 +178,9 @@ class PipelineRunUpdate(BaseUpdate):
178178
max_length=STR_FIELD_MAX_LENGTH,
179179
)
180180
end_time: Optional[datetime] = None
181-
completed: Optional[bool] = Field(
181+
finished: Optional[bool] = Field(
182182
default=None,
183-
title="Whether the pipeline run is completed.",
183+
title="Whether the pipeline run is finished.",
184184
)
185185
orchestrator_run_id: Optional[str] = None
186186
# TODO: we should maybe have a different update model here, the upper

src/zenml/orchestrators/base_orchestrator.py

Lines changed: 41 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
# permissions and limitations under the License.
1414
"""Base orchestrator class."""
1515

16-
import time
1716
from abc import ABC, abstractmethod
1817
from typing import (
1918
TYPE_CHECKING,
@@ -40,7 +39,6 @@
4039
HookExecutionException,
4140
IllegalOperationError,
4241
RunMonitoringError,
43-
RunStoppedException,
4442
)
4543
from zenml.hooks.hook_validators import load_and_run_hook
4644
from zenml.logger import get_logger
@@ -50,7 +48,6 @@
5048
publish_pipeline_run_status_update,
5149
publish_schedule_metadata,
5250
)
53-
from zenml.orchestrators.step_launcher import StepLauncher
5451
from zenml.orchestrators.utils import get_config_environment_vars
5552
from zenml.stack import Flavor, Stack, StackComponent, StackComponentConfig
5653
from zenml.steps.step_context import RunContext, get_or_create_run_context
@@ -274,16 +271,6 @@ def run(
274271
# in the orchestrator environment
275272
base_environment.update(secrets)
276273

277-
is_dynamic = True
278-
if is_dynamic:
279-
submission_result = self.submit_dynamic_pipeline(
280-
snapshot=snapshot,
281-
stack=stack,
282-
environment=base_environment,
283-
placeholder_run=placeholder_run,
284-
)
285-
return
286-
287274
prevent_client_side_caching = handle_bool_env_var(
288275
ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING, default=False
289276
)
@@ -292,6 +279,7 @@ def run(
292279
placeholder_run
293280
and self.config.supports_client_side_caching
294281
and not snapshot.schedule
282+
and not snapshot.is_dynamic
295283
and not prevent_client_side_caching
296284
):
297285
from zenml.orchestrators import cache_utils
@@ -310,22 +298,10 @@ def run(
310298
else:
311299
logger.debug("Skipping client-side caching.")
312300

313-
step_environments = {}
314-
for invocation_id, step in snapshot.step_configurations.items():
315-
from zenml.utils.env_utils import get_step_environment
316-
317-
step_environment = get_step_environment(
318-
step_config=step.config,
319-
stack=stack,
320-
)
321-
322-
combined_environment = base_environment.copy()
323-
combined_environment.update(step_environment)
324-
step_environments[invocation_id] = combined_environment
325-
326301
try:
327302
if (
328-
getattr(self.submit_pipeline, "__func__", None)
303+
not snapshot.is_dynamic
304+
and getattr(self.submit_pipeline, "__func__", None)
329305
is BaseOrchestrator.submit_pipeline
330306
):
331307
logger.warning(
@@ -357,13 +333,37 @@ def run(
357333
f"run metadata: {e}"
358334
)
359335
else:
360-
submission_result = self.submit_pipeline(
361-
snapshot=snapshot,
362-
stack=stack,
363-
base_environment=base_environment,
364-
step_environments=step_environments,
365-
placeholder_run=placeholder_run,
366-
)
336+
if snapshot.is_dynamic:
337+
submission_result = self.submit_dynamic_pipeline(
338+
snapshot=snapshot,
339+
stack=stack,
340+
environment=base_environment,
341+
placeholder_run=placeholder_run,
342+
)
343+
else:
344+
step_environments = {}
345+
for (
346+
invocation_id,
347+
step,
348+
) in snapshot.step_configurations.items():
349+
from zenml.utils.env_utils import get_step_environment
350+
351+
step_environment = get_step_environment(
352+
step_config=step.config,
353+
stack=stack,
354+
)
355+
356+
combined_environment = base_environment.copy()
357+
combined_environment.update(step_environment)
358+
step_environments[invocation_id] = combined_environment
359+
360+
submission_result = self.submit_pipeline(
361+
snapshot=snapshot,
362+
stack=stack,
363+
base_environment=base_environment,
364+
step_environments=step_environments,
365+
placeholder_run=placeholder_run,
366+
)
367367
if placeholder_run:
368368
publish_pipeline_run_status_update(
369369
pipeline_run_id=placeholder_run.id,
@@ -427,54 +427,14 @@ def run_step(
427427
RunStoppedException: If the run was stopped.
428428
BaseException: If the step failed all retries.
429429
"""
430+
from zenml.pipelines.dynamic.runner import _run_step_sync
430431

431-
def _launch_step() -> None:
432-
assert self._active_snapshot
433-
434-
launcher = StepLauncher(
435-
snapshot=self._active_snapshot,
436-
step=step,
437-
orchestrator_run_id=self.get_orchestrator_run_id(),
438-
)
439-
launcher.launch()
440-
441-
if self.config.handles_step_retries:
442-
_launch_step()
443-
else:
444-
# The orchestrator subclass doesn't handle step retries, so we
445-
# handle it in-process instead
446-
retries = 0
447-
retry_config = step.config.retry
448-
max_retries = retry_config.max_retries if retry_config else 0
449-
delay = retry_config.delay if retry_config else 0
450-
backoff = retry_config.backoff if retry_config else 1
451-
452-
while retries <= max_retries:
453-
try:
454-
_launch_step()
455-
except RunStoppedException:
456-
# Don't retry if the run was stopped
457-
raise
458-
except BaseException:
459-
retries += 1
460-
if retries <= max_retries:
461-
logger.info(
462-
"Sleeping for %d seconds before retrying step `%s`.",
463-
delay,
464-
step.config.name,
465-
)
466-
time.sleep(delay)
467-
delay *= backoff
468-
else:
469-
if max_retries > 0:
470-
logger.error(
471-
"Failed to run step `%s` after %d retries.",
472-
step.config.name,
473-
max_retries,
474-
)
475-
raise
476-
else:
477-
break
432+
_run_step_sync(
433+
snapshot=self._active_snapshot,
434+
step=step,
435+
orchestrator_run_id=self.get_orchestrator_run_id(),
436+
retry=not self.config.handles_step_retries,
437+
)
478438

479439
@property
480440
def supports_dynamic_pipelines(self) -> bool:

src/zenml/orchestrators/publish_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ def publish_successful_pipeline_run(
131131
run_update=PipelineRunUpdate(
132132
status=ExecutionStatus.COMPLETED,
133133
end_time=utc_now(),
134-
completed=True,
134+
finished=True,
135135
),
136136
)
137137

@@ -152,7 +152,7 @@ def publish_failed_pipeline_run(
152152
run_update=PipelineRunUpdate(
153153
status=ExecutionStatus.FAILED,
154154
end_time=utc_now(),
155-
completed=True,
155+
finished=True,
156156
),
157157
)
158158

src/zenml/orchestrators/step_launcher.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,6 @@ def __init__(
107107
snapshot: PipelineSnapshotResponse,
108108
step: Step,
109109
orchestrator_run_id: str,
110-
dynamic: bool = False,
111110
):
112111
"""Initializes the launcher.
113112
@@ -122,7 +121,6 @@ def __init__(
122121
self._snapshot = snapshot
123122
self._step = step
124123
self._orchestrator_run_id = orchestrator_run_id
125-
self._dynamic = dynamic
126124

127125
if not snapshot.stack:
128126
raise RuntimeError(
@@ -306,9 +304,10 @@ def launch(self) -> StepRunResponse:
306304
pipeline_run=pipeline_run,
307305
stack=self._stack,
308306
)
307+
dynamic_config = self._step if self._snapshot.is_dynamic else None
309308
step_run_request = request_factory.create_request(
310309
invocation_id=self._invocation_id,
311-
dynamic_config=self._step if self._dynamic else None,
310+
dynamic_config=dynamic_config,
312311
)
313312
step_run_request.logs = logs_model
314313

@@ -461,19 +460,19 @@ def _run_step(
461460
step_run_info=step_run_info,
462461
)
463462
else:
463+
from zenml.pipelines.dynamic.runner import (
464+
should_run_in_process,
465+
)
466+
464467
should_run_out_of_process = (
465468
self._snapshot.is_dynamic
466469
and self._step.config.in_process is False
467470
)
468471

469-
if (
470-
should_run_out_of_process
471-
and self._stack.orchestrator.supports_dynamic_out_of_process_steps
472+
if should_run_in_process(
473+
self._step,
474+
self._snapshot.pipeline_configuration.docker_settings,
472475
):
473-
self._run_step_with_dynamic_orchestrator(
474-
step_run_info=step_run_info
475-
)
476-
else:
477476
if should_run_out_of_process:
478477
logger.warning(
479478
"The %s does not support running dynamic out of "
@@ -490,6 +489,10 @@ def _run_step(
490489
input_artifacts=step_run.regular_inputs,
491490
output_artifact_uris=output_artifact_uris,
492491
)
492+
else:
493+
self._run_step_with_dynamic_orchestrator(
494+
step_run_info=step_run_info
495+
)
493496
except: # noqa: E722
494497
output_utils.remove_artifact_dirs(
495498
artifact_uris=list(output_artifact_uris.values())

0 commit comments

Comments
 (0)