Skip to content

Commit cd319bf

Browse files
authored
fix: move get_asset_key_str to translator and define translator from config (#51)
1 parent 71f8f7a commit cd319bf

File tree

11 files changed

+278
-49
lines changed

11 files changed

+278
-49
lines changed

README.md

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ from dagster import (
2424
AssetExecutionContext,
2525
Definitions,
2626
)
27-
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource, SQLMeshDagsterTranslator
27+
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource
2828

2929
sqlmesh_config = SQLMeshContextConfig(path="/home/foo/sqlmesh_project", gateway="name-of-your-gateway")
3030

31-
@sqlmesh_assets(environment="dev", config=sqlmesh_config, translator=SQLMeshDagsterTranslator())
31+
@sqlmesh_assets(environment="dev", config=sqlmesh_config)
3232
def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
3333
yield from sqlmesh.run(context)
3434

@@ -40,6 +40,34 @@ defs = Definitions(
4040
)
4141
```
4242

43+
## Advanced Usage
44+
45+
### Custom Translator
46+
47+
The translator is centrally configured and ensures consistency across all components. You can customize the translator by specifying a custom class in the config:
48+
49+
```python
50+
from dagster_sqlmesh import SQLMeshDagsterTranslator
51+
52+
class CustomSQLMeshTranslator(SQLMeshDagsterTranslator):
53+
def get_asset_key_str(self, fqn: str) -> str:
54+
# Custom asset key generation logic
55+
return f"custom_prefix__{super().get_asset_key_str(fqn)}"
56+
57+
# Configure with custom translator
58+
sqlmesh_config = SQLMeshContextConfig(
59+
path="/home/foo/sqlmesh_project",
60+
gateway="name-of-your-gateway",
61+
translator_class_name="your_module.CustomSQLMeshTranslator"
62+
)
63+
64+
@sqlmesh_assets(environment="dev", config=sqlmesh_config)
65+
def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
66+
yield from sqlmesh.run(context)
67+
```
68+
69+
This approach ensures that both the `SQLMeshResource` and the `@sqlmesh_assets` decorator use the same translator instance, preventing inconsistencies. The translator is created using `config.get_translator()` and passed to all components that need it, including the `DagsterSQLMeshEventHandler`.
70+
4371

4472
## Contributing
4573

dagster_sqlmesh/asset.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
ContextFactory,
1111
DagsterSQLMeshController,
1212
)
13-
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
1413
from dagster_sqlmesh.types import SQLMeshMultiAssetOptions
1514

1615
logger = logging.getLogger(__name__)
@@ -20,20 +19,18 @@ def sqlmesh_to_multi_asset_options(
2019
environment: str,
2120
config: SQLMeshContextConfig,
2221
context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs),
23-
dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None,
2422
) -> SQLMeshMultiAssetOptions:
2523
"""Converts sqlmesh project into a SQLMeshMultiAssetOptions object which is
2624
an intermediate representation of the SQLMesh project that can be used to
2725
create a dagster multi_asset definition."""
2826
controller = DagsterSQLMeshController.setup_with_config(
2927
config=config, context_factory=context_factory
3028
)
31-
if not dagster_sqlmesh_translator:
32-
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
29+
translator = config.get_translator()
3330

3431
conversion = controller.to_asset_outs(
3532
environment,
36-
translator=dagster_sqlmesh_translator,
33+
translator=translator,
3734
)
3835
return conversion
3936

@@ -74,7 +71,6 @@ def sqlmesh_assets(
7471
config: SQLMeshContextConfig,
7572
context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs),
7673
name: str | None = None,
77-
dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None,
7874
compute_kind: str = "sqlmesh",
7975
op_tags: t.Mapping[str, t.Any] | None = None,
8076
required_resource_keys: set[str] | None = None,
@@ -86,7 +82,6 @@ def sqlmesh_assets(
8682
environment=environment,
8783
config=config,
8884
context_factory=context_factory,
89-
dagster_sqlmesh_translator=dagster_sqlmesh_translator,
9085
)
9186

9287
return sqlmesh_asset_from_multi_asset_options(

dagster_sqlmesh/config.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
1+
import inspect
2+
import typing as t
13
from dataclasses import dataclass
24
from pathlib import Path
3-
from typing import Any
45

56
from dagster import Config
67
from pydantic import Field
78
from sqlmesh.core.config import Config as MeshConfig
89
from sqlmesh.core.config.loader import load_configs
910

11+
if t.TYPE_CHECKING:
12+
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
13+
1014

1115
@dataclass
1216
class ConfigOverride:
13-
config_as_dict: dict[str, Any]
17+
config_as_dict: dict[str, t.Any]
1418

15-
def dict(self) -> dict[str, Any]:
19+
def dict(self) -> dict[str, t.Any]:
1620
return self.config_as_dict
1721

1822

@@ -22,11 +26,56 @@ class SQLMeshContextConfig(Config):
2226
sqlmesh project define all the configuration in it's own directory which
2327
also ensures that configuration is consistent if running sqlmesh locally vs
2428
running via dagster.
29+
30+
The config also manages the translator class used for converting SQLMesh
31+
models to Dagster assets. You can specify a custom translator by setting
32+
the translator_class_name field to the fully qualified class name.
2533
"""
2634

2735
path: str
2836
gateway: str
29-
config_override: dict[str, Any] | None = Field(default_factory=lambda: None)
37+
config_override: dict[str, t.Any] | None = Field(default_factory=lambda: None)
38+
translator_class_name: str = Field(
39+
default="dagster_sqlmesh.translator.SQLMeshDagsterTranslator",
40+
description="Fully qualified class name of the SQLMesh Dagster translator to use"
41+
)
42+
43+
def get_translator(self) -> "SQLMeshDagsterTranslator":
44+
"""Get a translator instance using the configured class name.
45+
46+
Imports and validates the translator class, then creates a new instance.
47+
The class must inherit from SQLMeshDagsterTranslator.
48+
49+
Returns:
50+
SQLMeshDagsterTranslator: A new instance of the configured translator class
51+
52+
Raises:
53+
ValueError: If the imported object is not a class or does not inherit
54+
from SQLMeshDagsterTranslator
55+
"""
56+
from importlib import import_module
57+
58+
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
59+
60+
module_name, class_name = self.translator_class_name.rsplit(".", 1)
61+
module = import_module(module_name)
62+
translator_class = getattr(module, class_name)
63+
64+
# Validate that the imported class inherits from SQLMeshDagsterTranslator
65+
if not inspect.isclass(translator_class):
66+
raise ValueError(
67+
f"'{self.translator_class_name}' is not a class. "
68+
f"Expected a class that inherits from SQLMeshDagsterTranslator."
69+
)
70+
71+
if not issubclass(translator_class, SQLMeshDagsterTranslator):
72+
raise ValueError(
73+
f"Translator class '{self.translator_class_name}' must inherit from "
74+
f"SQLMeshDagsterTranslator. Found class that inherits from: "
75+
f"{[base.__name__ for base in translator_class.__bases__]}"
76+
)
77+
78+
return translator_class()
3079

3180
@property
3281
def sqlmesh_config(self) -> MeshConfig:

dagster_sqlmesh/controller/dagster.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
SQLMeshModelDep,
1313
SQLMeshMultiAssetOptions,
1414
)
15-
from dagster_sqlmesh.utils import get_asset_key_str
1615

1716
logger = logging.getLogger(__name__)
1817

@@ -53,7 +52,7 @@ def to_asset_outs(
5352

5453
internal_asset_deps.add(dep_asset_key_str)
5554
else:
56-
table = get_asset_key_str(dep.fqn)
55+
table = translator.get_asset_key_str(dep.fqn)
5756
key = translator.get_asset_key(
5857
context, dep.fqn
5958
).to_user_string()
@@ -62,7 +61,7 @@ def to_asset_outs(
6261
# create an external dep
6362
deps_map[table] = translator.create_asset_dep(key=key)
6463

65-
model_key = get_asset_key_str(model.fqn)
64+
model_key = translator.get_asset_key_str(model.fqn)
6665
asset_outs[model_key] = translator.create_asset_out(
6766
model_key=model_key,
6867
asset_key=asset_key_str,

dagster_sqlmesh/resource.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
ContextFactory,
2626
)
2727
from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
28-
from dagster_sqlmesh.utils import get_asset_key_str
28+
29+
if t.TYPE_CHECKING:
30+
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
2931

3032
logger = logging.getLogger(__name__)
3133

@@ -329,6 +331,7 @@ def __init__(
329331
models_map: dict[str, Model],
330332
dag: DAG[t.Any],
331333
prefix: str,
334+
translator: "SQLMeshDagsterTranslator",
332335
is_testing: bool = False,
333336
materializations_enabled: bool = True,
334337
) -> None:
@@ -341,6 +344,7 @@ def __init__(
341344
models_map: A mapping of model names to their SQLMesh model instances.
342345
dag: The directed acyclic graph representing the SQLMesh models.
343346
prefix: A prefix to use for all asset keys generated by this handler.
347+
translator: The SQLMesh Dagster translator instance.
344348
is_testing: Whether the handler is being used in a testing context.
345349
materializations_enabled: Whether the handler is to generate
346350
materializations, this should be disabled if you with to run a
@@ -351,6 +355,7 @@ def __init__(
351355
self._prefix = prefix
352356
self._context = context
353357
self._logger = context.log
358+
self._translator = translator
354359
self._tracker = MaterializationTracker(
355360
sorted_dag=dag.sorted[:], logger=self._logger
356361
)
@@ -382,7 +387,7 @@ def notify_success(
382387
# If the model is not in models_map, we can skip any notification
383388
if model:
384389
# Passing model.fqn to get internal unique asset key
385-
output_key = get_asset_key_str(model.fqn)
390+
output_key = self._translator.get_asset_key_str(model.fqn)
386391
if self._is_testing:
387392
asset_key = dg.AssetKey(["testing", output_key])
388393
self._logger.warning(
@@ -491,15 +496,15 @@ def report_event(self, event: console.ConsoleEvent) -> None:
491496
log_context.info(
492497
"Snapshot progress complete",
493498
{
494-
"asset_key": get_asset_key_str(snapshot.model.name),
499+
"asset_key": self._translator.get_asset_key_str(snapshot.model.name),
495500
},
496501
)
497502
self._tracker.update_run(snapshot)
498503
else:
499504
log_context.info(
500505
"Snapshot progress update",
501506
{
502-
"asset_key": get_asset_key_str(snapshot.model.name),
507+
"asset_key": self._translator.get_asset_key_str(snapshot.model.name),
503508
"progress": f"{done}/{expected}",
504509
"duration_ms": duration_ms,
505510
},
@@ -687,11 +692,13 @@ def create_event_handler(
687692
is_testing: bool,
688693
materializations_enabled: bool,
689694
) -> DagsterSQLMeshEventHandler:
695+
translator = self.config.get_translator()
690696
return DagsterSQLMeshEventHandler(
691697
context=context,
692698
dag=dag,
693699
models_map=models_map,
694700
prefix=prefix,
701+
translator=translator,
695702
is_testing=is_testing,
696703
materializations_enabled=materializations_enabled,
697704
)
@@ -701,7 +708,7 @@ def _get_selected_models_from_context(
701708
) -> tuple[set[str], dict[str, Model], list[str] | None]:
702709
models_map = models.copy()
703710
try:
704-
selected_output_names = set(context.selected_output_names)
711+
selected_output_names = set(context.op_execution_context.selected_output_names)
705712
except (DagsterInvalidPropertyError, AttributeError) as e:
706713
# Special case for direct execution context when testing. This is related to:
707714
# https://github.com/dagster-io/dagster/issues/23633
@@ -711,10 +718,11 @@ def _get_selected_models_from_context(
711718
else:
712719
raise e
713720

721+
translator = self.config.get_translator()
714722
select_models: list[str] = []
715723
models_map = {}
716724
for key, model in models.items():
717-
if get_asset_key_str(model.fqn) in selected_output_names:
725+
if translator.get_asset_key_str(model.fqn) in selected_output_names:
718726
models_map[key] = model
719727
select_models.append(model.name)
720728
return (

dagster_sqlmesh/test_asset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from dagster_sqlmesh.asset import SQLMeshDagsterTranslator
21
from dagster_sqlmesh.conftest import SQLMeshTestContext
32

43

54
def test_sqlmesh_context_to_asset_outs(sample_sqlmesh_test_context: SQLMeshTestContext):
65
controller = sample_sqlmesh_test_context.create_controller()
7-
translator = SQLMeshDagsterTranslator()
8-
outs = controller.to_asset_outs("dev", translator)
6+
translator = sample_sqlmesh_test_context.context_config.get_translator()
7+
outs = controller.to_asset_outs("dev", translator=translator)
98
assert len(list(outs.deps)) == 1
109
assert len(outs.outs) == 10

dagster_sqlmesh/test_config.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import pytest
2+
3+
from dagster_sqlmesh.config import SQLMeshContextConfig
4+
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
5+
6+
7+
def test_get_translator_with_valid_class():
8+
"""Test that get_translator works with the default translator class."""
9+
config = SQLMeshContextConfig(path="/tmp/test", gateway="local")
10+
translator = config.get_translator()
11+
assert isinstance(translator, SQLMeshDagsterTranslator)
12+
13+
14+
def test_get_translator_with_non_class():
15+
"""Test that get_translator raises ValueError when pointing to a non-class."""
16+
config = SQLMeshContextConfig(
17+
path="/tmp/test",
18+
gateway="local",
19+
translator_class_name="sys.version"
20+
)
21+
22+
with pytest.raises(ValueError, match="is not a class"):
23+
config.get_translator()
24+
25+
26+
def test_get_translator_with_invalid_inheritance():
27+
"""Test that get_translator raises ValueError when class doesn't inherit from SQLMeshDagsterTranslator."""
28+
config = SQLMeshContextConfig(
29+
path="/tmp/test",
30+
gateway="local",
31+
translator_class_name="builtins.dict"
32+
)
33+
34+
with pytest.raises(ValueError, match="must inherit from SQLMeshDagsterTranslator"):
35+
config.get_translator()
36+
37+
38+
def test_get_translator_with_nonexistent_class():
39+
"""Test that get_translator raises AttributeError when class doesn't exist."""
40+
config = SQLMeshContextConfig(
41+
path="/tmp/test",
42+
gateway="local",
43+
translator_class_name="dagster_sqlmesh.translator.NonexistentClass"
44+
)
45+
46+
with pytest.raises(AttributeError):
47+
config.get_translator()
48+
49+
50+
class MockValidTranslator(SQLMeshDagsterTranslator):
51+
"""A mock translator for testing custom inheritance."""
52+
pass
53+
54+
55+
def test_get_translator_with_valid_custom_class():
56+
"""Test that get_translator works with custom classes that inherit from SQLMeshDagsterTranslator."""
57+
config = SQLMeshContextConfig(
58+
path="/tmp/test",
59+
gateway="local",
60+
translator_class_name=f"{__name__}.MockValidTranslator"
61+
)
62+
63+
translator = config.get_translator()
64+
assert isinstance(translator, SQLMeshDagsterTranslator)
65+
assert isinstance(translator, MockValidTranslator)

dagster_sqlmesh/testing/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ def create_event_handler(self, *args: t.Any, **kwargs: t.Any) -> DagsterSQLMeshE
8181
Returns:
8282
DagsterSQLMeshEventHandler: The created event handler.
8383
"""
84+
# Ensure translator is passed to the event handler factory
85+
kwargs['translator'] = self.config.get_translator()
8486
return self._event_handler_factory(*args, **kwargs)
8587

8688

0 commit comments

Comments
 (0)