8 changes: 4 additions & 4 deletions .circleci/continue_config.yml
@@ -310,10 +310,10 @@ workflows:
- athena
- fabric
- gcp-postgres
filters:
branches:
only:
- main
# filters:
# branches:
# only:
# - main
Comment on lines +313 to +316 (Contributor Author):

TODO: Revert

- ui_style
- ui_test
- vscode_test
4 changes: 4 additions & 0 deletions sqlmesh/core/engine_adapter/base.py
@@ -119,6 +119,7 @@ class EngineAdapter:
MAX_IDENTIFIER_LENGTH: t.Optional[int] = None
ATTACH_CORRELATION_ID = True
SUPPORTS_QUERY_EXECUTION_TRACKING = False
SUPPORTS_EXTERNAL_MODEL_FRESHNESS = False

def __init__(
self,
@@ -2927,6 +2928,9 @@ def _check_identifier_length(self, expression: exp.Expression) -> None:
f"Identifier name '{name}' (length {name_length}) exceeds {self.dialect.capitalize()}'s max identifier limit of {self.MAX_IDENTIFIER_LENGTH} characters"
)

def get_external_model_freshness(self, table_names: t.List[TableName]) -> t.List[int]:
raise NotImplementedError()


class EngineAdapterWithIndexSupport(EngineAdapter):
SUPPORTS_INDEXES = True
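The base class pairs the new `SUPPORTS_EXTERNAL_MODEL_FRESHNESS` capability flag with a default implementation that raises, so callers are expected to feature-detect before dispatching. A minimal sketch of that call pattern, assuming a hypothetical `adapter` instance and `tables` list:

```python
# Hedged sketch: check the capability flag before calling, because the base
# implementation raises NotImplementedError on engines without support.
if adapter.SUPPORTS_EXTERNAL_MODEL_FRESHNESS:
    # One epoch-millisecond timestamp per table (see the subclass overrides below).
    freshness_ts = adapter.get_external_model_freshness(tables)
else:
    freshness_ts = None  # caller falls back to always evaluating the model
```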
22 changes: 22 additions & 0 deletions sqlmesh/core/engine_adapter/bigquery.py
@@ -754,6 +754,28 @@ def table_exists(self, table_name: TableName) -> bool:
except NotFound:
return False

def get_external_model_freshness(self, table_names: t.List[TableName]) -> t.List[int]:
from sqlmesh.utils.date import to_timestamp

datasets_to_tables: t.DefaultDict[str, t.List[str]] = defaultdict(list)
for table_name in table_names:
table = exp.to_table(table_name)
datasets_to_tables[table.db].append(table.name)

results = []

for dataset, tables in datasets_to_tables.items():
query = (
f"SELECT TIMESTAMP_MILLIS(last_modified_time) FROM `{dataset}.__TABLES__` WHERE "
)
for i, table_name in enumerate(tables):
query += f"TABLE_ID = '{table_name}'"
if i < len(tables) - 1:
query += " OR "
results.extend(self.fetchall(query))

return [to_timestamp(row[0]) for row in results]

def _get_table(self, table_name: TableName) -> BigQueryTable:
"""
Returns a BigQueryTable object for the given table name.
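For illustration, a hedged sketch of what the BigQuery override issues and returns; the dataset and table names are hypothetical, and `to_timestamp` converts each fetched `TIMESTAMP` into epoch milliseconds:

```python
# Hypothetical input: two external tables in the same dataset.
freshness_ts = adapter.get_external_model_freshness(
    ["analytics.orders", "analytics.customers"]
)
# Tables are grouped by dataset, then one metadata query runs per dataset:
#   SELECT TIMESTAMP_MILLIS(last_modified_time)
#   FROM `analytics.__TABLES__`
#   WHERE TABLE_ID = 'orders' OR TABLE_ID = 'customers'
# Result: one epoch-millisecond value per table, e.g. [1718236800000, 1718150400000]
```

An `IN (...)` predicate would express the same filter as the OR chain; grouping by dataset keeps each query against a single `__TABLES__` view.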
16 changes: 16 additions & 0 deletions sqlmesh/core/engine_adapter/snowflake.py
@@ -54,6 +54,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
SUPPORTS_MANAGED_MODELS = True
CURRENT_CATALOG_EXPRESSION = exp.func("current_database")
SUPPORTS_CREATE_DROP_CATALOG = True
SUPPORTS_EXTERNAL_MODEL_FRESHNESS = True
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA", "TABLE"]
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
@@ -666,3 +667,18 @@ def close(self) -> t.Any:
self._connection_pool.set_attribute(self.SNOWPARK, None)

return super().close()

def get_external_model_freshness(self, table_names: t.List[TableName]) -> t.List[int]:
from sqlmesh.utils.date import to_timestamp

num_tables = len(table_names)

query = "SELECT LAST_ALTERED FROM INFORMATION_SCHEMA.TABLES WHERE"
for i, table_name in enumerate(table_names):
table = exp.to_table(table_name)
query += f"""(TABLE_NAME = '{table.name}' AND TABLE_SCHEMA = '{table.db}' AND TABLE_CATALOG = '{table.catalog}')"""
if i < num_tables - 1:
query += " OR "

result = self.fetchall(query)
return [to_timestamp(row[0]) for row in result]
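Likewise for the Snowflake override, which matches on fully qualified names; the identifiers below are illustrative:

```python
# Hypothetical fully qualified external table.
freshness_ts = adapter.get_external_model_freshness(['"DB"."RAW"."ORDERS"'])
# A single INFORMATION_SCHEMA query is built, with one OR'd predicate per table:
#   SELECT LAST_ALTERED FROM INFORMATION_SCHEMA.TABLES WHERE
#   (TABLE_NAME = 'ORDERS' AND TABLE_SCHEMA = 'RAW' AND TABLE_CATALOG = 'DB')
```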
1 change: 1 addition & 0 deletions sqlmesh/core/plan/evaluator.py
@@ -258,6 +258,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
allow_additive_snapshots=plan.allow_additive_models,
selected_snapshot_ids=stage.selected_snapshot_ids,
selected_models=plan.selected_models,
is_restatement_plan=bool(plan.restatements),
)
if errors:
raise PlanError("Plan application failed.")
74 changes: 71 additions & 3 deletions sqlmesh/core/scheduler.py
@@ -54,6 +54,8 @@

if t.TYPE_CHECKING:
from sqlmesh.core.context import ExecutionContext
from sqlmesh.core._typing import TableName
from sqlmesh.core.engine_adapter import EngineAdapter

logger = logging.getLogger(__name__)
SnapshotToIntervals = t.Dict[Snapshot, Intervals]
@@ -188,6 +190,46 @@ def merged_missing_intervals(
}
return snapshots_to_intervals

def can_skip_evaluation(self, snapshot: Snapshot, snapshots: t.Dict[str, Snapshot]) -> bool:
if not snapshot.last_altered_ts:
return False

from collections import defaultdict

parent_snapshots = {p for p in snapshots.values() if p.name != snapshot.name}
if len(parent_snapshots) != len(snapshot.node.depends_on):
# The mismatch can happen if, e.g., an external model is not registered in the project
return False

adapter_to_parent_snapshots: t.Dict[EngineAdapter, t.List[Snapshot]] = defaultdict(list)

for parent_snapshot in parent_snapshots:
if not parent_snapshot.is_external:
return False

adapter = self.snapshot_evaluator.get_adapter(parent_snapshot.model_gateway)
if not adapter.SUPPORTS_EXTERNAL_MODEL_FRESHNESS:
return False

adapter_to_parent_snapshots[adapter].append(parent_snapshot)

if not adapter_to_parent_snapshots:
return False

external_models_freshness: t.List[int] = []

for adapter, adapter_snapshots in adapter_to_parent_snapshots.items():
table_names: t.List[TableName] = [
exp.to_table(parent_snapshot.name, parent_snapshot.node.dialect)
for parent_snapshot in adapter_snapshots
]
external_models_freshness.extend(adapter.get_external_model_freshness(table_names))

return all(
snapshot.last_altered_ts > external_model_freshness
for external_model_freshness in external_models_freshness
)

def evaluate(
self,
snapshot: Snapshot,
@@ -200,6 +242,7 @@
allow_destructive_snapshots: t.Optional[t.Set[str]] = None,
allow_additive_snapshots: t.Optional[t.Set[str]] = None,
target_table_exists: t.Optional[bool] = None,
is_restatement_plan: bool = False,
**kwargs: t.Any,
) -> t.List[AuditResult]:
"""Evaluate a snapshot and add the processed interval to the state sync.
@@ -251,7 +294,9 @@
**kwargs,
)

self.state_sync.add_interval(snapshot, start, end, is_dev=not is_deployable)
self.state_sync.add_interval(
snapshot, start, end, is_dev=not is_deployable, last_altered_ts=now_timestamp()
)
return audit_results

def run(
@@ -335,6 +380,7 @@ def batch_intervals(
deployability_index: t.Optional[DeployabilityIndex],
environment_naming_info: EnvironmentNamingInfo,
dag: t.Optional[DAG[SnapshotId]] = None,
is_restatement_plan: bool = False,
) -> t.Dict[Snapshot, Intervals]:
dag = dag or snapshots_to_dag(merged_intervals)

@@ -374,6 +420,7 @@
intervals,
context,
environment_naming_info,
is_restatement_plan=is_restatement_plan,
)
unready -= set(intervals)

@@ -422,6 +469,7 @@ def run_merged_intervals(
run_environment_statements: bool = False,
audit_only: bool = False,
auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {},
is_restatement_plan: bool = False,
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
"""Runs precomputed batches of missing intervals.

@@ -455,9 +503,12 @@
snapshot_dag = full_dag.subdag(*selected_snapshot_ids_set)

batched_intervals = self.batch_intervals(
merged_intervals, deployability_index, environment_naming_info, dag=snapshot_dag
merged_intervals,
deployability_index,
environment_naming_info,
dag=snapshot_dag,
is_restatement_plan=is_restatement_plan,
)

self.console.start_evaluation_progress(
batched_intervals,
environment_naming_info,
@@ -542,6 +593,7 @@ def run_node(node: SchedulingUnit) -> None:
allow_additive_snapshots=allow_additive_snapshots,
target_table_exists=snapshot.snapshot_id not in snapshots_to_create,
selected_models=selected_models,
is_restatement_plan=is_restatement_plan,
)

evaluation_duration_ms = now_timestamp() - execution_start_ts
@@ -913,6 +965,7 @@ def _check_ready_intervals(
intervals: Intervals,
context: ExecutionContext,
environment_naming_info: EnvironmentNamingInfo,
is_restatement_plan: bool = False,
) -> Intervals:
"""Checks if the intervals are ready for evaluation for the given snapshot.

@@ -934,13 +987,27 @@
if not (signals and signals.signals_to_kwargs):
return intervals

signal_names = signals.signals_to_kwargs.keys()

if (
is_restatement_plan
and len(signal_names) == 1
and next(iter(signal_names)) == "freshness"
):
# The freshness signal is not checked for restatement plans, giving users
# an escape hatch to force a model to be reevaluated
return intervals

self.console.start_signal_progress(
snapshot,
self.default_catalog,
environment_naming_info or EnvironmentNamingInfo(),
)

for signal_idx, (signal_name, kwargs) in enumerate(signals.signals_to_kwargs.items()):
if is_restatement_plan and signal_name == "freshness":
continue

# Capture intervals before signal check for display
intervals_to_check = merge_intervals(intervals)

Expand All @@ -954,6 +1021,7 @@ def _check_ready_intervals(
python_env=signals.python_env,
dialect=snapshot.model.dialect,
path=snapshot.model._path,
snapshot=snapshot,
kwargs=kwargs,
)
except SQLMeshError as e:
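The skip decision in `can_skip_evaluation` ultimately reduces to a timestamp comparison: a snapshot may be skipped only when it was last evaluated strictly after every external parent table was last modified. A self-contained sketch of that predicate over epoch-millisecond values (the numbers are illustrative):

```python
snapshot_last_altered_ts = 1_718_236_800_000  # when the model last ran
parent_freshness = [1_718_150_400_000, 1_718_100_000_000]  # external parents

# Mirrors the return expression in can_skip_evaluation: skip only if the model
# ran after every upstream table's last modification.
can_skip = all(snapshot_last_altered_ts > ts for ts in parent_freshness)
assert can_skip  # both parents predate the last evaluation
```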
43 changes: 42 additions & 1 deletion sqlmesh/core/signal.py
@@ -1,8 +1,13 @@
from __future__ import annotations


import typing as t
from sqlmesh.utils import UniqueKeyDict, registry_decorator

if t.TYPE_CHECKING:
from sqlmesh.core.context import ExecutionContext
from sqlmesh.core.snapshot.definition import Snapshot
from sqlmesh.utils.date import DatetimeRanges


class signal(registry_decorator):
"""Specifies a function which intervals are ready from a list of scheduled intervals.
@@ -33,3 +38,39 @@ class signal(registry_decorator):


SignalRegistry = UniqueKeyDict[str, signal]


@signal()
def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionContext) -> bool:
deployability_index = context.deployability_index
adapter = context.engine_adapter

if not deployability_index or not adapter.SUPPORTS_EXTERNAL_MODEL_FRESHNESS:
return True

last_altered_ts = (
snapshot.last_altered_ts
if deployability_index.is_deployable(snapshot)
else snapshot.dev_last_altered_ts
)
if not last_altered_ts:
return True

parent_snapshots = {context.snapshots[p.name] for p in snapshot.parents}
if len(parent_snapshots) != len(snapshot.node.depends_on) or not all(
p.is_external for p in parent_snapshots
):
# The mismatch can happen if, e.g., an external model is not registered in the project
return True

# Finding new data means that the upstream dependencies have been altered
# since the last time the model was evaluated
upstream_dep_has_new_data = any(
upstream_last_altered_ts > last_altered_ts
for upstream_last_altered_ts in adapter.get_external_model_freshness(
[p.name for p in parent_snapshots]
)
)

# Returning True is a no-op; returning False nullifies the batch so the model will not be evaluated.
return upstream_dep_has_new_data
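As a usage sketch, a model would opt into this behavior through its `signals` specification; the model name is hypothetical and the MODEL header syntax is assumed from SQLMesh's existing signal mechanism:

```python
# Hedged sketch of a SQL model opting into the freshness signal.
example_model_header = """
MODEL (
  name analytics.daily_orders,
  kind FULL,
  signals [freshness()]
);
"""
# With the _check_ready_intervals change above, restatement plans bypass this
# signal, so restating the model still forces a re-evaluation.
```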