Skip to content

Commit 368c5dd

Browse files
authored
Feat(sqlmesh_dbt): Select based on dbt name, not sqlmesh name (#5420)
1 parent 0bda998 commit 368c5dd

File tree

9 files changed

+293
-30
lines changed

9 files changed

+293
-30
lines changed

sqlmesh/core/context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
from sqlmesh.core.reference import ReferenceGraph
9494
from sqlmesh.core.scheduler import Scheduler, CompletionStatus
9595
from sqlmesh.core.schema_loader import create_external_models_file
96-
from sqlmesh.core.selector import Selector
96+
from sqlmesh.core.selector import Selector, NativeSelector
9797
from sqlmesh.core.snapshot import (
9898
DeployabilityIndex,
9999
Snapshot,
@@ -368,6 +368,7 @@ def __init__(
368368
load: bool = True,
369369
users: t.Optional[t.List[User]] = None,
370370
config_loader_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
371+
selector: t.Optional[t.Type[Selector]] = None,
371372
):
372373
self.configs = (
373374
config
@@ -390,6 +391,7 @@ def __init__(
390391
self._engine_adapter: t.Optional[EngineAdapter] = None
391392
self._linters: t.Dict[str, Linter] = {}
392393
self._loaded: bool = False
394+
self._selector_cls = selector or NativeSelector
393395

394396
self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items())))
395397

@@ -2893,7 +2895,7 @@ def _new_state_sync(self) -> StateSync:
28932895
def _new_selector(
28942896
self, models: t.Optional[UniqueKeyDict[str, Model]] = None, dag: t.Optional[DAG[str]] = None
28952897
) -> Selector:
2896-
return Selector(
2898+
return self._selector_cls(
28972899
self.state_reader,
28982900
models=models or self._models,
28992901
context_path=self.path,

sqlmesh/core/selector.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import fnmatch
44
import typing as t
55
from pathlib import Path
6+
from itertools import zip_longest
7+
import abc
68

79
from sqlglot import exp
810
from sqlglot.errors import ParseError
@@ -26,7 +28,7 @@
2628
from sqlmesh.core.state_sync import StateReader
2729

2830

29-
class Selector:
31+
class Selector(abc.ABC):
3032
def __init__(
3133
self,
3234
state_reader: StateReader,
@@ -167,13 +169,13 @@ def get_model(fqn: str) -> t.Optional[Model]:
167169
def expand_model_selections(
168170
self, model_selections: t.Iterable[str], models: t.Optional[t.Dict[str, Model]] = None
169171
) -> t.Set[str]:
170-
"""Expands a set of model selections into a set of model names.
172+
"""Expands a set of model selections into a set of model fqns that can be looked up in the Context.
171173
172174
Args:
173175
model_selections: A set of model selections.
174176
175177
Returns:
176-
A set of model names.
178+
A set of model fqns.
177179
"""
178180

179181
node = parse(" | ".join(f"({s})" for s in model_selections))
@@ -194,10 +196,9 @@ def evaluate(node: exp.Expression) -> t.Set[str]:
194196
return {
195197
fqn
196198
for fqn, model in all_models.items()
197-
if fnmatch.fnmatchcase(model.name, node.this)
199+
if fnmatch.fnmatchcase(self._model_name(model), node.this)
198200
}
199-
fqn = normalize_model_name(pattern, self._default_catalog, self._dialect)
200-
return {fqn} if fqn in all_models else set()
201+
return self._pattern_to_model_fqns(pattern, all_models)
201202
if isinstance(node, exp.And):
202203
return evaluate(node.left) & evaluate(node.right)
203204
if isinstance(node, exp.Or):
@@ -241,6 +242,70 @@ def evaluate(node: exp.Expression) -> t.Set[str]:
241242

242243
return evaluate(node)
243244

245+
@abc.abstractmethod
def _model_name(self, model: Model) -> str:
    """Given a model, return the name that a selector pattern containing wildcards should be fnmatch'd on.

    expand_model_selections compares the returned name against the wildcard pattern using
    fnmatch.fnmatchcase, so subclasses decide which naming scheme wildcards apply to.

    Args:
        model: The model to produce a matchable name for.

    Returns:
        The name to match the wildcard pattern against.
    """
    pass
249+
250+
@abc.abstractmethod
def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]:
    """Given a pattern, return the keys of the matching models from :all_models.

    Args:
        pattern: The selector pattern to resolve.
        all_models: The candidate models, keyed by the value that should be returned on a match.

    Returns:
        The keys of :all_models that the pattern matches; an empty set if nothing matches.
    """
    pass
254+
255+
256+
class NativeSelector(Selector):
257+
"""Implementation of selectors that matches objects based on SQLMesh native names"""
258+
259+
def _model_name(self, model: Model) -> str:
    """Return the SQLMesh model name unchanged; native selectors match against it directly."""
    return model.name
261+
262+
def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]:
    """Resolve a pattern to at most one model key by normalizing it into a SQLMesh fqn.

    The pattern is normalized with this selector's default catalog and dialect and then looked
    up in :all_models.

    Returns:
        A single-element set holding the normalized fqn if it is a key of :all_models,
        otherwise an empty set.
    """
    normalized = normalize_model_name(pattern, self._default_catalog, self._dialect)
    if normalized in all_models:
        return {normalized}
    return set()
265+
266+
267+
class DbtSelector(Selector):
268+
"""Implementation of selectors that matches objects based on the DBT names instead of the SQLMesh native names"""
269+
270+
def _model_name(self, model: Model) -> str:
    """Return the model's dbt fqn, which is what wildcard patterns are matched against.

    Raises:
        SQLMeshError: If the model has no dbt fqn populated.
    """
    dbt_name = model.dbt_fqn
    if not dbt_name:
        raise SQLMeshError("dbt node information must be populated to use dbt selectors")
    return dbt_name
274+
275+
def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]:
    """Return the keys of the models whose dbt fqn contains the pattern as a run of whole components.

    Both the pattern and each model's dbt fqn are split on "." and compared component by
    component rather than as substrings:

      - "staging.customers" matches "jaffle_shop.staging.customers"
      - "staging.customers" does not match "jaffle_shop.customers.staging"
      - "aging" does not match "jaffle_shop.staging.customers"

    The run may begin at any component of the dbt fqn and does not have to reach its end, so
    "staging" also matches "jaffle_shop.staging.customers".

    Args:
        pattern: The dot-separated selector pattern.
        all_models: The candidate models, keyed by the value that should be returned on a match.

    Returns:
        The keys of :all_models whose dbt fqn matches. Models without a dbt fqn never match.
    """
    pattern_components = pattern.split(".")
    width = len(pattern_components)

    matches: t.Set[str] = set()
    for fqn, model in all_models.items():
        dbt_fqn = model.dbt_fqn
        if not dbt_fqn:
            continue

        dbt_fqn_components = dbt_fqn.split(".")
        # Try every possible starting position instead of anchoring on the first occurrence of the
        # pattern's first component; otherwise an fqn that repeats a component (for example
        # "shop.shop.orders" against the pattern "shop.orders") is rejected even though the
        # pattern matches further along. The range is empty when the pattern is longer than the fqn.
        if any(
            dbt_fqn_components[start : start + width] == pattern_components
            for start in range(len(dbt_fqn_components) - width + 1)
        ):
            matches.add(fqn)
    return matches
308+
244309

245310
class SelectorDialect(Dialect):
246311
IDENTIFIERS_CAN_START_WITH_DIGIT = True

sqlmesh_dbt/operations.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _plan_builder_options(
185185
options.update(
186186
dict(
187187
# Add every selected model as a restatement to force them to get repopulated from scratch
188-
restate_models=list(self.context.models)
188+
restate_models=[m.dbt_fqn for m in self.context.models.values() if m.dbt_fqn]
189189
if not select_models
190190
else select_models,
191191
# by default in SQLMesh, restatements only operate on what has been committed to state.
@@ -231,6 +231,7 @@ def create(
231231
from sqlmesh.core.console import set_console
232232
from sqlmesh_dbt.console import DbtCliConsole
233233
from sqlmesh.utils.errors import SQLMeshError
234+
from sqlmesh.core.selector import DbtSelector
234235

235236
# clear any existing handlers set up by click/rich as defaults so that once SQLMesh logging config is applied,
236237
# we don't get duplicate messages logged from things like console.log_warning()
@@ -250,6 +251,8 @@ def create(
250251
paths=[project_dir],
251252
config_loader_kwargs=dict(profile=profile, target=target, variables=vars),
252253
load=True,
254+
# DbtSelector selects based on dbt model fqns rather than SQLMesh model names
255+
selector=DbtSelector,
253256
)
254257

255258
dbt_loader = sqlmesh_context._loaders[0]

tests/core/test_selector.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sqlmesh.core.environment import Environment
1313
from sqlmesh.core.model import Model, SqlModel
1414
from sqlmesh.core.model.common import ParsableSql
15-
from sqlmesh.core.selector import Selector
15+
from sqlmesh.core.selector import NativeSelector
1616
from sqlmesh.core.snapshot import SnapshotChangeCategory
1717
from sqlmesh.utils import UniqueKeyDict
1818
from sqlmesh.utils.date import now_timestamp
@@ -88,7 +88,7 @@ def test_select_models(mocker: MockerFixture, make_snapshot, default_catalog: t.
8888
local_models[modified_model_v2.fqn] = modified_model_v2.copy(
8989
update={"mapping_schema": added_model_schema}
9090
)
91-
selector = Selector(state_reader_mock, local_models, default_catalog=default_catalog)
91+
selector = NativeSelector(state_reader_mock, local_models, default_catalog=default_catalog)
9292

9393
_assert_models_equal(
9494
selector.select_models(["db.added_model"], env_name),
@@ -243,7 +243,7 @@ def test_select_models_expired_environment(mocker: MockerFixture, make_snapshot)
243243

244244
local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
245245
local_models[modified_model_v2.fqn] = modified_model_v2
246-
selector = Selector(state_reader_mock, local_models)
246+
selector = NativeSelector(state_reader_mock, local_models)
247247

248248
_assert_models_equal(
249249
selector.select_models(["*.modified_model"], env_name, fallback_env_name="prod"),
@@ -305,7 +305,7 @@ def test_select_change_schema(mocker: MockerFixture, make_snapshot):
305305
local_child = child.copy(update={"mapping_schema": {'"db"': {'"parent"': {"b": "INT"}}}})
306306
local_models[local_child.fqn] = local_child
307307

308-
selector = Selector(state_reader_mock, local_models)
308+
selector = NativeSelector(state_reader_mock, local_models)
309309

310310
selected = selector.select_models(["db.parent"], env_name)
311311
assert selected[local_child.fqn].render_query() != child.render_query()
@@ -339,7 +339,7 @@ def test_select_models_missing_env(mocker: MockerFixture, make_snapshot):
339339
local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
340340
local_models[model.fqn] = model
341341

342-
selector = Selector(state_reader_mock, local_models)
342+
selector = NativeSelector(state_reader_mock, local_models)
343343

344344
assert selector.select_models([model.name], "missing_env").keys() == {model.fqn}
345345
assert not selector.select_models(["missing"], "missing_env")
@@ -563,7 +563,7 @@ def test_expand_model_selections(
563563
)
564564
models[model.fqn] = model
565565

566-
selector = Selector(mocker.Mock(), models)
566+
selector = NativeSelector(mocker.Mock(), models)
567567
assert selector.expand_model_selections(selections) == output
568568

569569

@@ -576,7 +576,7 @@ def test_model_selection_normalized(mocker: MockerFixture, make_snapshot):
576576
dialect="bigquery",
577577
)
578578
models[model.fqn] = model
579-
selector = Selector(mocker.Mock(), models, dialect="bigquery")
579+
selector = NativeSelector(mocker.Mock(), models, dialect="bigquery")
580580
assert selector.expand_model_selections(["db.test_Model"]) == {'"db"."test_Model"'}
581581

582582

@@ -624,7 +624,7 @@ def test_expand_git_selection(
624624
git_client_mock.list_uncommitted_changed_files.return_value = []
625625
git_client_mock.list_committed_changed_files.return_value = [model_a._path, model_c._path]
626626

627-
selector = Selector(mocker.Mock(), models)
627+
selector = NativeSelector(mocker.Mock(), models)
628628
selector._git_client = git_client_mock
629629

630630
assert selector.expand_model_selections(expressions) == expected_fqns
@@ -658,7 +658,7 @@ def test_select_models_with_external_parent(mocker: MockerFixture):
658658
local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
659659
local_models[added_model.fqn] = added_model
660660

661-
selector = Selector(state_reader_mock, local_models, default_catalog=default_catalog)
661+
selector = NativeSelector(state_reader_mock, local_models, default_catalog=default_catalog)
662662

663663
expanded_selections = selector.expand_model_selections(["+*added_model*"])
664664
assert expanded_selections == {added_model.fqn}
@@ -699,7 +699,7 @@ def test_select_models_local_tags_take_precedence_over_remote(
699699
local_models[local_existing.fqn] = local_existing
700700
local_models[local_new.fqn] = local_new
701701

702-
selector = Selector(state_reader_mock, local_models)
702+
selector = NativeSelector(state_reader_mock, local_models)
703703

704704
selected = selector.select_models(["tag:a"], env_name)
705705

tests/dbt/cli/test_list.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_list(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
1919

2020

2121
def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
22-
result = invoke_cli(["list", "--select", "main.raw_customers+"])
22+
result = invoke_cli(["list", "--select", "raw_customers+"])
2323

2424
assert result.exit_code == 0
2525
assert not result.exception
@@ -34,7 +34,7 @@ def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Resul
3434

3535
def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
3636
# single exclude
37-
result = invoke_cli(["list", "--select", "main.raw_customers+", "--exclude", "main.orders"])
37+
result = invoke_cli(["list", "--select", "raw_customers+", "--exclude", "orders"])
3838

3939
assert result.exit_code == 0
4040
assert not result.exception
@@ -49,8 +49,8 @@ def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..
4949

5050
# multiple exclude
5151
for args in (
52-
["--select", "main.stg_orders+", "--exclude", "main.customers", "--exclude", "main.orders"],
53-
["--select", "main.stg_orders+", "--exclude", "main.customers main.orders"],
52+
["--select", "stg_orders+", "--exclude", "customers", "--exclude", "orders"],
53+
["--select", "stg_orders+", "--exclude", "customers orders"],
5454
):
5555
result = invoke_cli(["list", *args])
5656
assert result.exit_code == 0

tests/dbt/cli/test_operations.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
138138
assert plan.selected_models_to_backfill is None
139139
assert {s.name for s in plan.snapshots} == {k for k in operations.context.snapshots}
140140

141-
plan = operations.run(select=["main.stg_orders+"])
141+
plan = operations.run(select=["stg_orders+"])
142142
assert plan.environment.name == "prod"
143143
assert console.no_prompts is True
144144
assert console.no_diff is True
@@ -155,7 +155,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
155155
plan.selected_models_to_backfill | {standalone_audit_name}
156156
)
157157

158-
plan = operations.run(select=["main.stg_orders+"], exclude=["main.customers"])
158+
plan = operations.run(select=["stg_orders+"], exclude=["customers"])
159159
assert plan.environment.name == "prod"
160160
assert console.no_prompts is True
161161
assert console.no_diff is True
@@ -171,7 +171,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
171171
plan.selected_models_to_backfill | {standalone_audit_name}
172172
)
173173

174-
plan = operations.run(exclude=["main.customers"])
174+
plan = operations.run(exclude=["customers"])
175175
assert plan.environment.name == "prod"
176176
assert console.no_prompts is True
177177
assert console.no_diff is True
@@ -238,7 +238,7 @@ def test_run_option_mapping_dev(jaffle_shop_duckdb: Path):
238238
assert plan.skip_backfill is True
239239
assert plan.selected_models_to_backfill == {'"jaffle_shop"."main"."new_model"'}
240240

241-
plan = operations.run(environment="dev", select=["main.stg_orders+"])
241+
plan = operations.run(environment="dev", select=["stg_orders+"])
242242
assert plan.environment.name == "dev"
243243
assert console.no_prompts is True
244244
assert console.no_diff is True
@@ -325,7 +325,7 @@ def test_run_option_full_refresh_with_selector(jaffle_shop_duckdb: Path):
325325
console = PlanCapturingConsole()
326326
operations.context.console = console
327327

328-
plan = operations.run(select=["main.stg_customers"], full_refresh=True)
328+
plan = operations.run(select=["stg_customers"], full_refresh=True)
329329
assert len(plan.restatements) == 1
330330
assert list(plan.restatements)[0].name == '"jaffle_shop"."main"."stg_customers"'
331331

tests/dbt/cli/test_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_run_with_selectors(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[...
2727
assert result.exit_code == 0
2828
assert "main.orders" in result.output
2929

30-
result = invoke_cli(["run", "--select", "main.raw_customers+", "--exclude", "main.orders"])
30+
result = invoke_cli(["run", "--select", "raw_customers+", "--exclude", "orders"])
3131

3232
assert result.exit_code == 0
3333
assert not result.exception

0 commit comments

Comments
 (0)