Skip to content

Commit 8751ee5

Browse files
authored
Feat(dbt_cli): Add support for --vars (#5205)
1 parent a759712 commit 8751ee5

File tree

8 files changed

+117
-22
lines changed

8 files changed

+117
-22
lines changed

sqlmesh/core/config/loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def load_config_from_paths(
176176
project_root=dbt_project_file.parent,
177177
dbt_profile_name=kwargs.pop("profile", None),
178178
dbt_target_name=kwargs.pop("target", None),
179+
variables=variables,
179180
)
180181
if type(dbt_python_config) != config_type:
181182
dbt_python_config = convert_config_type(dbt_python_config, config_type)

sqlmesh_dbt/cli.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,31 @@
44
from sqlmesh_dbt.operations import DbtOperations, create
55
from sqlmesh_dbt.error import cli_global_error_handler
66
from pathlib import Path
7+
from sqlmesh_dbt.options import YamlParamType
8+
import functools
79

810

9-
def _get_dbt_operations(ctx: click.Context) -> DbtOperations:
10-
if not isinstance(ctx.obj, DbtOperations):
11+
def _get_dbt_operations(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]]) -> DbtOperations:
12+
if not isinstance(ctx.obj, functools.partial):
1113
raise ValueError(f"Unexpected click context object: {type(ctx.obj)}")
12-
return ctx.obj
14+
15+
dbt_operations = ctx.obj(vars=vars)
16+
17+
if not isinstance(dbt_operations, DbtOperations):
18+
raise ValueError(f"Unexpected dbt operations type: {type(dbt_operations)}")
19+
20+
@ctx.call_on_close
21+
def _cleanup() -> None:
22+
dbt_operations.close()
23+
24+
return dbt_operations
25+
26+
27+
vars_option = click.option(
28+
"--vars",
29+
type=YamlParamType(),
30+
help="Supply variables to the project. This argument overrides variables defined in your dbt_project.yml file. This argument should be a YAML string, eg. '{my_variable: my_value}'",
31+
)
1332

1433

1534
select_option = click.option(
@@ -40,10 +59,15 @@ def dbt(
4059
# we dont need to import sqlmesh/load the project for CLI help
4160
return
4261

43-
# TODO: conditionally call create() if there are times we dont want/need to import sqlmesh and load a project
44-
ctx.obj = create(project_dir=Path.cwd(), profile=profile, target=target)
62+
# we have a partially applied function here because subcommands might set extra options like --vars
63+
# that need to be known before we attempt to load the project
64+
ctx.obj = functools.partial(create, project_dir=Path.cwd(), profile=profile, target=target)
4565

4666
if not ctx.invoked_subcommand:
67+
if profile or target:
68+
# trigger a project load to validate the specified profile / target
69+
ctx.obj()
70+
4771
click.echo(
4872
f"No command specified. Run `{ctx.info_name} --help` to see the available commands."
4973
)
@@ -57,19 +81,21 @@ def dbt(
5781
"--full-refresh",
5882
help="If specified, dbt will drop incremental models and fully-recalculate the incremental table from the model definition.",
5983
)
84+
@vars_option
6085
@click.pass_context
61-
def run(ctx: click.Context, **kwargs: t.Any) -> None:
86+
def run(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], **kwargs: t.Any) -> None:
6287
"""Compile SQL and execute against the current target database."""
63-
_get_dbt_operations(ctx).run(**kwargs)
88+
_get_dbt_operations(ctx, vars).run(**kwargs)
6489

6590

6691
@dbt.command(name="list")
6792
@select_option
6893
@exclude_option
94+
@vars_option
6995
@click.pass_context
70-
def list_(ctx: click.Context, **kwargs: t.Any) -> None:
96+
def list_(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], **kwargs: t.Any) -> None:
7197
"""List the resources in your project"""
72-
_get_dbt_operations(ctx).list_(**kwargs)
98+
_get_dbt_operations(ctx, vars).list_(**kwargs)
7399

74100

75101
@dbt.command(name="ls", hidden=True) # hidden alias for list

sqlmesh_dbt/error.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,5 @@ def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> t.Any:
2525
sys.exit(1)
2626
else:
2727
raise
28-
finally:
29-
context_or_obj = args[0]
30-
sqlmesh_context = (
31-
context_or_obj.obj if isinstance(context_or_obj, click.Context) else context_or_obj
32-
)
33-
if sqlmesh_context is not None:
34-
# important to import this only if a context was created
35-
# otherwise something like `sqlmesh_dbt run --help` will trigger this import because it's in the finally: block
36-
from sqlmesh import Context
37-
38-
if isinstance(sqlmesh_context, Context):
39-
sqlmesh_context.close()
4028

4129
return wrapper

sqlmesh_dbt/operations.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,15 @@ def console(self) -> DbtCliConsole:
7676

7777
return console
7878

79+
def close(self) -> None:
80+
self.context.close()
81+
7982

8083
def create(
8184
project_dir: t.Optional[Path] = None,
8285
profile: t.Optional[str] = None,
8386
target: t.Optional[str] = None,
87+
vars: t.Optional[t.Dict[str, t.Any]] = None,
8488
debug: bool = False,
8589
) -> DbtOperations:
8690
with Progress(transient=True) as progress:
@@ -104,7 +108,7 @@ def create(
104108

105109
sqlmesh_context = Context(
106110
paths=[project_dir],
107-
config_loader_kwargs=dict(profile=profile, target=target),
111+
config_loader_kwargs=dict(profile=profile, target=target, variables=vars),
108112
load=True,
109113
)
110114

sqlmesh_dbt/options.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import typing as t
2+
import click
3+
from click.core import Context, Parameter
4+
5+
6+
class YamlParamType(click.ParamType):
7+
name = "yaml"
8+
9+
def convert(
10+
self, value: t.Any, param: t.Optional[Parameter], ctx: t.Optional[Context]
11+
) -> t.Any:
12+
if not isinstance(value, str):
13+
self.fail(f"Input value '{value}' should be a string", param, ctx)
14+
15+
from sqlmesh.utils import yaml
16+
17+
try:
18+
parsed = yaml.load(source=value, render_jinja=False)
19+
except:
20+
self.fail(f"String '{value}' is not valid YAML", param, ctx)
21+
22+
if not isinstance(parsed, dict):
23+
self.fail(f"String '{value}' did not evaluate to a dict, got: {parsed}", param, ctx)
24+
25+
return parsed

tests/dbt/cli/test_list.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,17 @@ def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..
4646
assert "main.orders" not in result.output
4747
assert "main.stg_payments" not in result.output
4848
assert "main.raw_orders" not in result.output
49+
50+
51+
def test_list_with_vars(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
52+
(jaffle_shop_duckdb / "models" / "aliased_model.sql").write_text("""
53+
{{ config(alias='model_' + var('foo')) }}
54+
select 1
55+
""")
56+
57+
result = invoke_cli(["list", "--vars", "foo: bar"])
58+
59+
assert result.exit_code == 0
60+
assert not result.exception
61+
62+
assert "model_bar" in result.output

tests/dbt/cli/test_operations.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,17 @@ def test_create_can_specify_profile_and_target(jaffle_shop_duckdb: Path):
6969

7070
assert dbt_project.context.profile_name == "jaffle_shop"
7171
assert dbt_project.context.target_name == "dev"
72+
73+
74+
def test_create_can_set_project_variables(jaffle_shop_duckdb: Path):
75+
(jaffle_shop_duckdb / "models" / "test_model.sql").write_text("""
76+
select '{{ var('foo') }}' as a
77+
""")
78+
79+
dbt_project = create(vars={"foo": "bar"})
80+
assert dbt_project.context.config.variables["foo"] == "bar"
81+
82+
test_model = dbt_project.context.models['"jaffle_shop"."main"."test_model"']
83+
query = test_model.render_query()
84+
assert query is not None
85+
assert query.sql() == "SELECT 'bar' AS \"a\""

tests/dbt/cli/test_options.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import typing as t
2+
import pytest
3+
from sqlmesh_dbt.options import YamlParamType
4+
from click.exceptions import BadParameter
5+
6+
7+
@pytest.mark.parametrize(
8+
"input,expected",
9+
[
10+
(1, BadParameter("Input value '1' should be a string")),
11+
("", BadParameter("String '' is not valid YAML")),
12+
("['a', 'b']", BadParameter("String.*did not evaluate to a dict, got.*")),
13+
("foo: bar", {"foo": "bar"}),
14+
('{"key": "value", "date": 20180101}', {"key": "value", "date": 20180101}),
15+
("{key: value, date: 20180101}", {"key": "value", "date": 20180101}),
16+
],
17+
)
18+
def test_yaml_param_type(input: str, expected: t.Union[BadParameter, t.Dict[str, t.Any]]):
19+
if isinstance(expected, BadParameter):
20+
with pytest.raises(BadParameter, match=expected.message):
21+
YamlParamType().convert(input, None, None)
22+
else:
23+
assert YamlParamType().convert(input, None, None) == expected

0 commit comments

Comments
 (0)