Skip to content

Commit ac5446b

Browse files
authored
Fix: Cache upstream dependencies when building the evaluation DAG in scheduler (#5569)
1 parent c57b048 commit ac5446b

File tree

2 files changed

+118
-13
lines changed

2 files changed

+118
-13
lines changed

sqlmesh/core/scheduler.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,7 @@ def _dag(
659659
}
660660
snapshots_to_create = snapshots_to_create or set()
661661
original_snapshots_to_create = snapshots_to_create.copy()
662+
upstream_dependencies_cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]] = {}
662663

663664
snapshot_dag = snapshot_dag or snapshots_to_dag(batches)
664665
dag = DAG[SchedulingUnit]()
@@ -670,12 +671,15 @@ def _dag(
670671
snapshot = self.snapshots_by_name[snapshot_id.name]
671672
intervals = intervals_per_snapshot.get(snapshot.name, [])
672673

673-
upstream_dependencies: t.List[SchedulingUnit] = []
674+
upstream_dependencies: t.Set[SchedulingUnit] = set()
674675

675676
for p_sid in snapshot.parents:
676-
upstream_dependencies.extend(
677+
upstream_dependencies.update(
677678
self._find_upstream_dependencies(
678-
p_sid, intervals_per_snapshot, original_snapshots_to_create
679+
p_sid,
680+
intervals_per_snapshot,
681+
original_snapshots_to_create,
682+
upstream_dependencies_cache,
679683
)
680684
)
681685

@@ -726,29 +730,43 @@ def _find_upstream_dependencies(
726730
parent_sid: SnapshotId,
727731
intervals_per_snapshot: t.Dict[str, Intervals],
728732
snapshots_to_create: t.Set[SnapshotId],
729-
) -> t.List[SchedulingUnit]:
733+
cache: t.Optional[t.Dict[SnapshotId, t.Set[SchedulingUnit]]] = None,
734+
) -> t.Set[SchedulingUnit]:
735+
cache = cache or {}
730736
if parent_sid not in self.snapshots:
731-
return []
737+
return set()
738+
if parent_sid in cache:
739+
return cache[parent_sid]
732740

733741
p_intervals = intervals_per_snapshot.get(parent_sid.name, [])
734742

743+
parent_node: t.Optional[SchedulingUnit] = None
735744
if p_intervals:
736745
if len(p_intervals) > 1:
737-
return [DummyNode(snapshot_name=parent_sid.name)]
738-
interval = p_intervals[0]
739-
return [EvaluateNode(snapshot_name=parent_sid.name, interval=interval, batch_index=0)]
740-
if parent_sid in snapshots_to_create:
741-
return [CreateNode(snapshot_name=parent_sid.name)]
746+
parent_node = DummyNode(snapshot_name=parent_sid.name)
747+
else:
748+
interval = p_intervals[0]
749+
parent_node = EvaluateNode(
750+
snapshot_name=parent_sid.name, interval=interval, batch_index=0
751+
)
752+
elif parent_sid in snapshots_to_create:
753+
parent_node = CreateNode(snapshot_name=parent_sid.name)
754+
755+
if parent_node is not None:
756+
cache[parent_sid] = {parent_node}
757+
return {parent_node}
758+
742759
# This snapshot has no intervals and doesn't need creation which means
743760
# that it can be a transitive dependency
744-
transitive_deps: t.List[SchedulingUnit] = []
761+
transitive_deps: t.Set[SchedulingUnit] = set()
745762
parent_snapshot = self.snapshots[parent_sid]
746763
for grandparent_sid in parent_snapshot.parents:
747-
transitive_deps.extend(
764+
transitive_deps.update(
748765
self._find_upstream_dependencies(
749-
grandparent_sid, intervals_per_snapshot, snapshots_to_create
766+
grandparent_sid, intervals_per_snapshot, snapshots_to_create, cache
750767
)
751768
)
769+
cache[parent_sid] = transitive_deps
752770
return transitive_deps
753771

754772
def _run_or_audit(

tests/core/test_scheduler.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,3 +1126,90 @@ def test_dag_multiple_chain_transitive_deps(mocker: MockerFixture, make_snapshot
11261126
)
11271127
},
11281128
}
1129+
1130+
1131+
def test_dag_upstream_dependency_caching_with_complex_diamond(mocker: MockerFixture, make_snapshot):
1132+
r"""
1133+
Test that the upstream dependency caching correctly handles a complex diamond dependency graph.
1134+
1135+
Dependency graph:
1136+
A (has intervals)
1137+
/ \
1138+
B C (no intervals - transitive)
1139+
/ \ / \
1140+
D E F (no intervals - transitive)
1141+
\ / \ /
1142+
G H (has intervals - selected)
1143+
1144+
This creates multiple paths from G and H to A. Without caching, A's dependencies would be
1145+
computed multiple times (once for each path). With caching, they should be computed once
1146+
and reused.
1147+
"""
1148+
snapshots = {}
1149+
1150+
for name in ["a", "b", "c", "d", "e", "f", "g", "h"]:
1151+
snapshots[name] = make_snapshot(SqlModel(name=name, query=parse_one("SELECT 1 as id")))
1152+
snapshots[name].categorize_as(SnapshotChangeCategory.BREAKING)
1153+
1154+
# A is the root
1155+
snapshots["b"] = snapshots["b"].model_copy(update={"parents": (snapshots["a"].snapshot_id,)})
1156+
snapshots["c"] = snapshots["c"].model_copy(update={"parents": (snapshots["a"].snapshot_id,)})
1157+
1158+
# Middle layer: D, E, F depend on B and/or C
1159+
snapshots["d"] = snapshots["d"].model_copy(update={"parents": (snapshots["b"].snapshot_id,)})
1160+
snapshots["e"] = snapshots["e"].model_copy(
1161+
update={"parents": (snapshots["b"].snapshot_id, snapshots["c"].snapshot_id)}
1162+
)
1163+
snapshots["f"] = snapshots["f"].model_copy(update={"parents": (snapshots["c"].snapshot_id,)})
1164+
1165+
# Bottom layer: G and H depend on D/E and E/F respectively
1166+
snapshots["g"] = snapshots["g"].model_copy(
1167+
update={"parents": (snapshots["d"].snapshot_id, snapshots["e"].snapshot_id)}
1168+
)
1169+
snapshots["h"] = snapshots["h"].model_copy(
1170+
update={"parents": (snapshots["e"].snapshot_id, snapshots["f"].snapshot_id)}
1171+
)
1172+
1173+
scheduler = Scheduler(
1174+
snapshots=list(snapshots.values()),
1175+
snapshot_evaluator=mocker.Mock(),
1176+
state_sync=mocker.Mock(),
1177+
default_catalog=None,
1178+
)
1179+
1180+
batched_intervals = {
1181+
snapshots["a"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
1182+
snapshots["g"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
1183+
snapshots["h"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
1184+
}
1185+
1186+
full_dag = snapshots_to_dag(snapshots.values())
1187+
dag = scheduler._dag(batched_intervals, snapshot_dag=full_dag)
1188+
1189+
# Verify the DAG structure:
1190+
# 1. A should be evaluated first (no dependencies)
1191+
# 2. Both G and H should depend on A (through transitive dependencies)
1192+
# 3. Transitive nodes (B, C, D, E, F) should not appear as separate evaluation nodes
1193+
expected_a_node = EvaluateNode(
1194+
snapshot_name='"a"',
1195+
interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
1196+
batch_index=0,
1197+
)
1198+
1199+
expected_g_node = EvaluateNode(
1200+
snapshot_name='"g"',
1201+
interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
1202+
batch_index=0,
1203+
)
1204+
1205+
expected_h_node = EvaluateNode(
1206+
snapshot_name='"h"',
1207+
interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
1208+
batch_index=0,
1209+
)
1210+
1211+
assert dag.graph == {
1212+
expected_a_node: set(),
1213+
expected_g_node: {expected_a_node},
1214+
expected_h_node: {expected_a_node},
1215+
}

0 commit comments

Comments
 (0)