Skip to content

Commit 64db71b

Browse files
committed
Feat: Skip evaluation if upstream external model has not changed
1 parent 37523dc commit 64db71b

File tree

10 files changed

+271
-4
lines changed

10 files changed

+271
-4
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class EngineAdapter:
119119
MAX_IDENTIFIER_LENGTH: t.Optional[int] = None
120120
ATTACH_CORRELATION_ID = True
121121
SUPPORTS_QUERY_EXECUTION_TRACKING = False
122+
SUPPORTS_EXTERNAL_MODEL_FRESHNESS = False
122123

123124
def __init__(
124125
self,
@@ -2873,6 +2874,9 @@ def _check_identifier_length(self, expression: exp.Expression) -> None:
28732874
f"Identifier name '{name}' (length {name_length}) exceeds {self.dialect.capitalize()}'s max identifier limit of {self.MAX_IDENTIFIER_LENGTH} characters"
28742875
)
28752876

2877+
def get_external_model_freshness(self, table_names: t.List[TableName]) -> t.List[int]:
    """Return the last-modified timestamp for each of the given external tables.

    Adapters that set ``SUPPORTS_EXTERNAL_MODEL_FRESHNESS = True`` must override
    this method; the base implementation only signals lack of support.

    Args:
        table_names: The (possibly qualified) names of the external tables to check.

    Returns:
        One timestamp per table (epoch milliseconds) in overriding adapters.

    Raises:
        NotImplementedError: Always, in the base adapter.
    """
    # Include the adapter class in the message so a misconfigured caller can
    # tell which engine lacks freshness support.
    raise NotImplementedError(
        f"{type(self).__name__} does not support external model freshness checks"
    )
28762880

28772881
class EngineAdapterWithIndexSupport(EngineAdapter):
28782882
SUPPORTS_INDEXES = True

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
6767
MAX_TABLE_COMMENT_LENGTH = 1024
6868
MAX_COLUMN_COMMENT_LENGTH = 1024
6969
SUPPORTS_QUERY_EXECUTION_TRACKING = True
70+
SUPPORTS_EXTERNAL_MODEL_FRESHNESS = True
7071
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"]
7172

7273
SCHEMA_DIFFER_KWARGS = {
@@ -753,6 +754,28 @@ def table_exists(self, table_name: TableName) -> bool:
753754
except NotFound:
754755
return False
755756

757+
def get_external_model_freshness(self, table_names: t.List[TableName]) -> t.List[int]:
    """Return the last-modified timestamps of the given external BigQuery tables.

    Tables are grouped by dataset so a single ``__TABLES__`` metadata query is
    issued per dataset rather than one per table.

    Args:
        table_names: Names of the external tables; each is expected to be
            qualified with its dataset (``table.db``).

    Returns:
        Last-modified timestamps (epoch milliseconds), one per row returned by
        the engine. NOTE(review): tables absent from ``__TABLES__`` are silently
        dropped and result order is not guaranteed to match ``table_names`` —
        callers must not rely on positional alignment.
    """
    from sqlmesh.utils.date import to_timestamp

    datasets_to_tables: t.DefaultDict[str, t.List[str]] = defaultdict(list)
    for table_name in table_names:
        table = exp.to_table(table_name)
        datasets_to_tables[table.db].append(table.name)

    results = []

    for dataset, tables in datasets_to_tables.items():
        # Build an IN list instead of a hand-rolled OR chain. Table names are
        # interpolated into the SQL string, so escape single quotes to keep the
        # literal well-formed (identifiers normally can't contain quotes, but
        # this keeps the query robust against unexpected input).
        table_list = ", ".join(
            "'{}'".format(name.replace("'", "\\'")) for name in tables
        )
        query = (
            "SELECT TIMESTAMP_MILLIS(last_modified_time) "
            f"FROM `{dataset}.__TABLES__` WHERE TABLE_ID IN ({table_list})"
        )
        results.extend(self.fetchall(query))

    return [to_timestamp(row[0]) for row in results]
756779
def _get_table(self, table_name: TableName) -> BigQueryTable:
757780
"""
758781
Returns a BigQueryTable object for the given table name.

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
5454
SUPPORTS_MANAGED_MODELS = True
5555
CURRENT_CATALOG_EXPRESSION = exp.func("current_database")
5656
SUPPORTS_CREATE_DROP_CATALOG = True
57+
SUPPORTS_EXTERNAL_MODEL_FRESHNESS = True
5758
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA", "TABLE"]
5859
SCHEMA_DIFFER_KWARGS = {
5960
"parameterized_type_defaults": {
@@ -665,3 +666,18 @@ def close(self) -> t.Any:
665666
self._connection_pool.set_attribute(self.SNOWPARK, None)
666667

667668
return super().close()
669+
670+
def get_external_model_freshness(self, table_names: t.List[TableName]) -> t.List[int]:
    """Return the last-altered timestamps of the given external Snowflake tables.

    Queries ``INFORMATION_SCHEMA.TABLES`` once with one predicate per table.

    Args:
        table_names: Fully qualified names (catalog.schema.table) of the
            external tables to check.

    Returns:
        ``LAST_ALTERED`` values converted to epoch milliseconds, one per row
        returned by the engine. NOTE(review): result order is not guaranteed to
        match ``table_names`` — callers must not rely on positional alignment.
    """
    from sqlmesh.utils.date import to_timestamp

    if not table_names:
        # Guard: without this we would execute "... WHERE" with no predicates,
        # which is invalid SQL.
        return []

    predicates = []
    for table_name in table_names:
        table = exp.to_table(table_name)
        predicates.append(
            f"(TABLE_NAME = '{table.name}' AND TABLE_SCHEMA = '{table.db}' "
            f"AND TABLE_CATALOG = '{table.catalog}')"
        )

    query = "SELECT LAST_ALTERED FROM INFORMATION_SCHEMA.TABLES WHERE " + " OR ".join(
        predicates
    )
    result = self.fetchall(query)
    return [to_timestamp(row[0]) for row in result]

sqlmesh/core/plan/evaluator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
257257
allow_destructive_snapshots=plan.allow_destructive_models,
258258
allow_additive_snapshots=plan.allow_additive_models,
259259
selected_snapshot_ids=stage.selected_snapshot_ids,
260+
is_restatement_plan=bool(plan.restatements),
260261
)
261262
if errors:
262263
raise PlanError("Plan application failed.")

sqlmesh/core/scheduler.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454

5555
if t.TYPE_CHECKING:
5656
from sqlmesh.core.context import ExecutionContext
57+
from sqlmesh.core._typing import TableName
58+
from sqlmesh.core.engine_adapter import EngineAdapter
5759

5860
logger = logging.getLogger(__name__)
5961
SnapshotToIntervals = t.Dict[Snapshot, Intervals]
@@ -188,6 +190,46 @@ def merged_missing_intervals(
188190
}
189191
return snapshots_to_intervals
190192

193+
def can_skip_evaluation(self, snapshot: Snapshot, snapshots: t.Dict[str, Snapshot]) -> bool:
    """Return True if evaluating ``snapshot`` can safely be skipped.

    Evaluation is skippable only when every upstream dependency is an external
    model whose physical table has not been modified since this snapshot's
    table was last written (``last_altered_ts``).

    The check is deliberately conservative: any condition that cannot be
    verified — no recorded write timestamp, unregistered parents, non-external
    parents, adapters without freshness support, or missing freshness rows —
    results in False (i.e. evaluate as usual).

    Args:
        snapshot: The snapshot being considered for evaluation.
        snapshots: The snapshot's parents keyed by name (may include the
            snapshot itself).
    """
    if not snapshot.last_altered_ts:
        # We have never recorded when this snapshot's table was written, so
        # there is nothing to compare upstream freshness against.
        return False

    from collections import defaultdict

    parent_snapshots = {p for p in snapshots.values() if p.name != snapshot.name}
    if len(parent_snapshots) != len(snapshot.node.depends_on):
        # The mismatch can happen if e.g. an external model is not registered in the project
        return False

    adapter_to_parent_snapshots: t.Dict[EngineAdapter, t.List[Snapshot]] = defaultdict(list)

    for parent_snapshot in parent_snapshots:
        if not parent_snapshot.is_external:
            return False

        adapter = self.snapshot_evaluator.get_adapter(parent_snapshot.model_gateway)
        if not adapter.SUPPORTS_EXTERNAL_MODEL_FRESHNESS:
            return False

        adapter_to_parent_snapshots[adapter].append(parent_snapshot)

    if not adapter_to_parent_snapshots:
        # No parents at all — nothing proves the inputs are unchanged.
        return False

    external_models_freshness: t.List[int] = []

    for adapter, adapter_snapshots in adapter_to_parent_snapshots.items():
        table_names: t.List[TableName] = [
            exp.to_table(parent_snapshot.name, parent_snapshot.node.dialect)
            for parent_snapshot in adapter_snapshots
        ]
        external_models_freshness.extend(adapter.get_external_model_freshness(table_names))

    if len(external_models_freshness) != len(parent_snapshots):
        # An adapter returned a different number of freshness rows than tables
        # queried — e.g. a table missing from the engine's metadata. Without a
        # value for every parent we cannot prove the inputs are unchanged, so
        # do not skip. (An empty result would otherwise make the `all(...)`
        # below vacuously True and incorrectly skip evaluation.)
        return False

    return all(
        snapshot.last_altered_ts > external_model_freshness
        for external_model_freshness in external_models_freshness
    )
191233
def evaluate(
192234
self,
193235
snapshot: Snapshot,
@@ -200,6 +242,7 @@ def evaluate(
200242
allow_destructive_snapshots: t.Optional[t.Set[str]] = None,
201243
allow_additive_snapshots: t.Optional[t.Set[str]] = None,
202244
target_table_exists: t.Optional[bool] = None,
245+
is_restatement_plan: bool = False,
203246
**kwargs: t.Any,
204247
) -> t.List[AuditResult]:
205248
"""Evaluate a snapshot and add the processed interval to the state sync.
@@ -224,6 +267,14 @@ def evaluate(
224267

225268
snapshots = parent_snapshots_by_name(snapshot, self.snapshots)
226269

270+
if not is_restatement_plan and self.can_skip_evaluation(snapshot, snapshots):
271+
logger.info(f"""
272+
Skipping evaluation for snapshot {snapshot.name} as it depends on external models
273+
that have not been updated since the last run.
274+
""")
275+
276+
return []
277+
227278
is_deployable = deployability_index.is_deployable(snapshot)
228279

229280
wap_id = self.snapshot_evaluator.evaluate(
@@ -251,7 +302,9 @@ def evaluate(
251302
**kwargs,
252303
)
253304

254-
self.state_sync.add_interval(snapshot, start, end, is_dev=not is_deployable)
305+
self.state_sync.add_interval(
306+
snapshot, start, end, is_dev=not is_deployable, last_altered_ts=now_timestamp()
307+
)
255308
return audit_results
256309

257310
def run(
@@ -421,6 +474,7 @@ def run_merged_intervals(
421474
run_environment_statements: bool = False,
422475
audit_only: bool = False,
423476
auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {},
477+
is_restatement_plan: bool = False,
424478
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
425479
"""Runs precomputed batches of missing intervals.
426480
@@ -526,6 +580,7 @@ def run_node(node: SchedulingUnit) -> None:
526580
allow_destructive_snapshots=allow_destructive_snapshots,
527581
allow_additive_snapshots=allow_additive_snapshots,
528582
target_table_exists=snapshot.snapshot_id not in snapshots_to_create,
583+
is_restatement_plan=is_restatement_plan,
529584
)
530585

531586
evaluation_duration_ms = now_timestamp() - execution_start_ts

sqlmesh/core/snapshot/definition.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class SnapshotIntervals(PydanticModel):
185185
intervals: Intervals = []
186186
dev_intervals: Intervals = []
187187
pending_restatement_intervals: Intervals = []
188+
last_altered_ts: t.Optional[int] = None
188189

189190
@property
190191
def snapshot_id(self) -> t.Optional[SnapshotId]:
@@ -652,6 +653,9 @@ class Snapshot(PydanticModel, SnapshotInfoMixin):
652653
dev_table_suffix: str = "dev"
653654
table_naming_convention: TableNamingConvention = TableNamingConvention.default
654655
forward_only: bool = False
656+
# Physical table last modified timestamp, not to be confused with the "updated_ts" field
657+
# which is for the snapshot record itself
658+
last_altered_ts: t.Optional[int] = None
655659

656660
@field_validator("ttl")
657661
@classmethod
@@ -690,6 +694,12 @@ def hydrate_with_intervals_by_version(
690694
)
691695
for interval in snapshot_intervals:
692696
snapshot.merge_intervals(interval)
697+
698+
if interval.last_altered_ts:
699+
snapshot.last_altered_ts = max(
700+
snapshot.last_altered_ts or -1, interval.last_altered_ts
701+
)
702+
693703
result.append(snapshot)
694704

695705
return result

sqlmesh/core/state_sync/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ def add_interval(
477477
start: TimeLike,
478478
end: TimeLike,
479479
is_dev: bool = False,
480+
last_altered_ts: t.Optional[int] = None,
480481
) -> None:
481482
"""Add an interval to a snapshot and sync it to the store.
482483
@@ -485,6 +486,7 @@ def add_interval(
485486
start: The start of the interval to add.
486487
end: The end of the interval to add.
487488
is_dev: Indicates whether the given interval is being added while in development mode
489+
last_altered_ts: The timestamp of the last modification of the physical table
488490
"""
489491
start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False, expand=False)
490492
if not snapshot.version:
@@ -497,6 +499,7 @@ def add_interval(
497499
dev_version=snapshot.dev_version,
498500
intervals=intervals if not is_dev else [],
499501
dev_intervals=intervals if is_dev else [],
502+
last_altered_ts=last_altered_ts,
500503
)
501504
self.add_snapshots_intervals([snapshot_intervals])
502505

sqlmesh/core/state_sync/db/interval.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
"is_removed": exp.DataType.build("boolean"),
6161
"is_compacted": exp.DataType.build("boolean"),
6262
"is_pending_restatement": exp.DataType.build("boolean"),
63+
"last_altered_ts": exp.DataType.build("bigint"),
6364
}
6465

6566
def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None:
@@ -215,13 +216,23 @@ def _push_snapshot_intervals(
215216
for start_ts, end_ts in snapshot.intervals:
216217
new_intervals.append(
217218
_interval_to_df(
218-
snapshot, start_ts, end_ts, is_dev=False, is_compacted=is_compacted
219+
snapshot,
220+
start_ts,
221+
end_ts,
222+
is_dev=False,
223+
is_compacted=is_compacted,
224+
last_altered_ts=snapshot.last_altered_ts,
219225
)
220226
)
221227
for start_ts, end_ts in snapshot.dev_intervals:
222228
new_intervals.append(
223229
_interval_to_df(
224-
snapshot, start_ts, end_ts, is_dev=True, is_compacted=is_compacted
230+
snapshot,
231+
start_ts,
232+
end_ts,
233+
is_dev=True,
234+
is_compacted=is_compacted,
235+
last_altered_ts=snapshot.last_altered_ts,
225236
)
226237
)
227238

@@ -236,6 +247,7 @@ def _push_snapshot_intervals(
236247
is_dev=False,
237248
is_compacted=is_compacted,
238249
is_pending_restatement=True,
250+
last_altered_ts=snapshot.last_altered_ts,
239251
)
240252
)
241253

@@ -284,6 +296,7 @@ def _get_snapshot_intervals(
284296
is_dev,
285297
is_removed,
286298
is_pending_restatement,
299+
last_altered_ts,
287300
) in rows:
288301
interval_ids.add(interval_id)
289302
merge_key = (name, version, dev_version, identifier)
@@ -296,6 +309,12 @@ def _get_snapshot_intervals(
296309
identifier=identifier,
297310
version=version,
298311
dev_version=dev_version,
312+
last_altered_ts=last_altered_ts,
313+
)
314+
315+
if last_altered_ts:
316+
intervals[merge_key].last_altered_ts = max(
317+
intervals[merge_key].last_altered_ts or 0, last_altered_ts
299318
)
300319

301320
if pending_restatement_interval_merge_key not in intervals:
@@ -340,6 +359,7 @@ def _get_snapshot_intervals_query(self, uncompacted_only: bool) -> exp.Select:
340359
"is_dev",
341360
"is_removed",
342361
"is_pending_restatement",
362+
"last_altered_ts",
343363
)
344364
.from_(exp.to_table(self.intervals_table).as_("intervals"))
345365
.order_by(
@@ -458,6 +478,7 @@ def _interval_to_df(
458478
is_removed: bool = False,
459479
is_compacted: bool = False,
460480
is_pending_restatement: bool = False,
481+
last_altered_ts: t.Optional[int] = None,
461482
) -> t.Dict[str, t.Any]:
462483
return {
463484
"id": random_id(),
@@ -472,4 +493,5 @@ def _interval_to_df(
472493
"is_removed": is_removed,
473494
"is_compacted": is_compacted,
474495
"is_pending_restatement": is_pending_restatement,
496+
"last_altered_ts": last_altered_ts,
475497
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
"""Add the `last_altered_ts` column to the intervals table."""

from sqlglot import exp


def migrate(state_sync, **kwargs):  # type: ignore
    """Add a nullable BIGINT ``last_altered_ts`` column to the ``_intervals`` state table.

    The column records when a snapshot's physical table was last written, which
    is used to skip evaluation when upstream external models are unchanged.

    Args:
        state_sync: The state sync whose engine adapter and schema are used to
            locate and alter the intervals table.
    """
    engine_adapter = state_sync.engine_adapter
    schema = state_sync.schema
    intervals_table = "_intervals"
    if schema:
        intervals_table = f"{schema}.{intervals_table}"

    alter_table_exp = exp.Alter(
        this=exp.to_table(intervals_table),
        kind="TABLE",
        actions=[
            exp.ColumnDef(
                this=exp.to_column("last_altered_ts"),
                # Build the type through the adapter's dialect so engines that
                # spell BIGINT differently get the right DDL.
                kind=exp.DataType.build("BIGINT", dialect=engine_adapter.dialect),
            )
        ],
    )
    engine_adapter.execute(alter_table_exp)

0 commit comments

Comments
 (0)