Skip to content

Commit 99f7e6b

Browse files
committed
fix: move get_asset_key_str to translator and define translator from config
1 parent 71f8f7a commit 99f7e6b

File tree

12 files changed

+289
-60
lines changed

12 files changed

+289
-60
lines changed

README.md

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ from dagster import (
2424
AssetExecutionContext,
2525
Definitions,
2626
)
27-
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource, SQLMeshDagsterTranslator
27+
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource
2828

2929
sqlmesh_config = SQLMeshContextConfig(path="/home/foo/sqlmesh_project", gateway="name-of-your-gateway")
3030

31-
@sqlmesh_assets(environment="dev", config=sqlmesh_config, translator=SQLMeshDagsterTranslator())
31+
@sqlmesh_assets(environment="dev", config=sqlmesh_config)
3232
def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
3333
yield from sqlmesh.run(context)
3434

@@ -40,6 +40,34 @@ defs = Definitions(
4040
)
4141
```
4242

43+
## Advanced Usage
44+
45+
### Custom Translator
46+
47+
The translator is centrally configured and ensures consistency across all components. You can customize the translator by specifying a custom class in the config:
48+
49+
```python
50+
from dagster_sqlmesh import SQLMeshDagsterTranslator
51+
52+
class CustomSQLMeshTranslator(SQLMeshDagsterTranslator):
53+
def get_asset_key_str(self, fqn: str) -> str:
54+
# Custom asset key generation logic
55+
return f"custom_prefix__{super().get_asset_key_str(fqn)}"
56+
57+
# Configure with custom translator
58+
sqlmesh_config = SQLMeshContextConfig(
59+
path="/home/foo/sqlmesh_project",
60+
gateway="name-of-your-gateway",
61+
translator_class_name="your_module.CustomSQLMeshTranslator"
62+
)
63+
64+
@sqlmesh_assets(environment="dev", config=sqlmesh_config)
65+
def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
66+
yield from sqlmesh.run(context)
67+
```
68+
69+
This approach ensures that both the `SQLMeshResource` and the `@sqlmesh_assets` decorator use the same translator instance, preventing inconsistencies. The translator is created using `config.get_translator()` and passed to all components that need it, including the `DagsterSQLMeshEventHandler`.
70+
4371

4472
## Contributing
4573

dagster_sqlmesh/asset.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
ContextFactory,
1111
DagsterSQLMeshController,
1212
)
13-
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
1413
from dagster_sqlmesh.types import SQLMeshMultiAssetOptions
1514

1615
logger = logging.getLogger(__name__)
@@ -20,20 +19,18 @@ def sqlmesh_to_multi_asset_options(
2019
environment: str,
2120
config: SQLMeshContextConfig,
2221
context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs),
23-
dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None,
2422
) -> SQLMeshMultiAssetOptions:
2523
"""Converts sqlmesh project into a SQLMeshMultiAssetOptions object which is
2624
an intermediate representation of the SQLMesh project that can be used to
2725
create a dagster multi_asset definition."""
2826
controller = DagsterSQLMeshController.setup_with_config(
2927
config=config, context_factory=context_factory
3028
)
31-
if not dagster_sqlmesh_translator:
32-
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
29+
dagster_sqlmesh_translator = config.get_translator()
3330

3431
conversion = controller.to_asset_outs(
3532
environment,
36-
translator=dagster_sqlmesh_translator,
33+
dagster_sqlmesh_translator=dagster_sqlmesh_translator,
3734
)
3835
return conversion
3936

@@ -74,7 +71,6 @@ def sqlmesh_assets(
7471
config: SQLMeshContextConfig,
7572
context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs),
7673
name: str | None = None,
77-
dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None,
7874
compute_kind: str = "sqlmesh",
7975
op_tags: t.Mapping[str, t.Any] | None = None,
8076
required_resource_keys: set[str] | None = None,
@@ -86,7 +82,6 @@ def sqlmesh_assets(
8682
environment=environment,
8783
config=config,
8884
context_factory=context_factory,
89-
dagster_sqlmesh_translator=dagster_sqlmesh_translator,
9085
)
9186

9287
return sqlmesh_asset_from_multi_asset_options(

dagster_sqlmesh/config.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
import inspect
12
from dataclasses import dataclass
23
from pathlib import Path
3-
from typing import Any
4+
from typing import TYPE_CHECKING, Any
45

56
from dagster import Config
67
from pydantic import Field
78
from sqlmesh.core.config import Config as MeshConfig
89
from sqlmesh.core.config.loader import load_configs
910

11+
if TYPE_CHECKING:
12+
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
13+
1014

1115
@dataclass
1216
class ConfigOverride:
@@ -22,11 +26,56 @@ class SQLMeshContextConfig(Config):
2226
sqlmesh project define all the configuration in it's own directory which
2327
also ensures that configuration is consistent if running sqlmesh locally vs
2428
running via dagster.
29+
30+
The config also manages the translator class used for converting SQLMesh
31+
models to Dagster assets. You can specify a custom translator by setting
32+
the translator_class_name field to the fully qualified class name.
2533
"""
2634

2735
path: str
2836
gateway: str
2937
config_override: dict[str, Any] | None = Field(default_factory=lambda: None)
38+
translator_class_name: str = Field(
39+
default="dagster_sqlmesh.translator.SQLMeshDagsterTranslator",
40+
description="Fully qualified class name of the SQLMesh Dagster translator to use"
41+
)
42+
43+
def get_translator(self) -> "SQLMeshDagsterTranslator":
44+
"""Get a translator instance using the configured class name.
45+
46+
Imports and validates the translator class, then creates a new instance.
47+
The class must inherit from SQLMeshDagsterTranslator.
48+
49+
Returns:
50+
SQLMeshDagsterTranslator: A new instance of the configured translator class
51+
52+
Raises:
53+
ValueError: If the imported object is not a class or does not inherit
54+
from SQLMeshDagsterTranslator
55+
"""
56+
from importlib import import_module
57+
58+
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
59+
60+
module_name, class_name = self.translator_class_name.rsplit(".", 1)
61+
module = import_module(module_name)
62+
translator_class = getattr(module, class_name)
63+
64+
# Validate that the imported class inherits from SQLMeshDagsterTranslator
65+
if not inspect.isclass(translator_class):
66+
raise ValueError(
67+
f"'{self.translator_class_name}' is not a class. "
68+
f"Expected a class that inherits from SQLMeshDagsterTranslator."
69+
)
70+
71+
if not issubclass(translator_class, SQLMeshDagsterTranslator):
72+
raise ValueError(
73+
f"Translator class '{self.translator_class_name}' must inherit from "
74+
f"SQLMeshDagsterTranslator. Found class that inherits from: "
75+
f"{[base.__name__ for base in translator_class.__bases__]}"
76+
)
77+
78+
return translator_class()
3079

3180
@property
3281
def sqlmesh_config(self) -> MeshConfig:

dagster_sqlmesh/controller/dagster.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
SQLMeshModelDep,
1313
SQLMeshMultiAssetOptions,
1414
)
15-
from dagster_sqlmesh.utils import get_asset_key_str
1615

1716
logger = logging.getLogger(__name__)
1817

@@ -23,7 +22,7 @@ class DagsterSQLMeshController(SQLMeshController[ContextCls]):
2322
def to_asset_outs(
2423
self,
2524
environment: str,
26-
translator: SQLMeshDagsterTranslator,
25+
dagster_sqlmesh_translator: SQLMeshDagsterTranslator,
2726
) -> SQLMeshMultiAssetOptions:
2827
"""Loads all the asset outs of the current sqlmesh environment. If a
2928
cache is provided, it will be tried first to load the asset outs."""
@@ -36,42 +35,42 @@ def to_asset_outs(
3635
context = instance.context
3736

3837
for model, deps in instance.non_external_models_dag():
39-
asset_key = translator.get_asset_key(context=context, fqn=model.fqn)
38+
asset_key = dagster_sqlmesh_translator.get_asset_key(context=context, fqn=model.fqn)
4039
asset_key_str = asset_key.to_user_string()
4140
model_deps = [
4241
SQLMeshModelDep(fqn=dep, model=context.get_model(dep))
4342
for dep in deps
4443
]
4544
internal_asset_deps: set[str] = set()
46-
asset_tags = translator.get_tags(context, model)
45+
asset_tags = dagster_sqlmesh_translator.get_tags(context, model)
4746

4847
for dep in model_deps:
4948
if dep.model:
50-
dep_asset_key_str = translator.get_asset_key(
49+
dep_asset_key_str = dagster_sqlmesh_translator.get_asset_key(
5150
context, dep.model.fqn
5251
).to_user_string()
5352

5453
internal_asset_deps.add(dep_asset_key_str)
5554
else:
56-
table = get_asset_key_str(dep.fqn)
57-
key = translator.get_asset_key(
55+
table = dagster_sqlmesh_translator.get_asset_key_str(dep.fqn)
56+
key = dagster_sqlmesh_translator.get_asset_key(
5857
context, dep.fqn
5958
).to_user_string()
6059
internal_asset_deps.add(key)
6160

6261
# create an external dep
63-
deps_map[table] = translator.create_asset_dep(key=key)
62+
deps_map[table] = dagster_sqlmesh_translator.create_asset_dep(key=key)
6463

65-
model_key = get_asset_key_str(model.fqn)
66-
asset_outs[model_key] = translator.create_asset_out(
64+
model_key = dagster_sqlmesh_translator.get_asset_key_str(model.fqn)
65+
asset_outs[model_key] = dagster_sqlmesh_translator.create_asset_out(
6766
model_key=model_key,
6867
asset_key=asset_key_str,
6968
tags=asset_tags,
7069
is_required=False,
71-
group_name=translator.get_group_name(context, model),
70+
group_name=dagster_sqlmesh_translator.get_group_name(context, model),
7271
kinds={
7372
"sqlmesh",
74-
translator.get_context_dialect(context).lower(),
73+
dagster_sqlmesh_translator.get_context_dialect(context).lower(),
7574
},
7675
)
7776
internal_asset_deps_map[model_key] = internal_asset_deps

dagster_sqlmesh/resource.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
ContextFactory,
2626
)
2727
from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
28-
from dagster_sqlmesh.utils import get_asset_key_str
28+
29+
if t.TYPE_CHECKING:
30+
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
2931

3032
logger = logging.getLogger(__name__)
3133

@@ -329,6 +331,7 @@ def __init__(
329331
models_map: dict[str, Model],
330332
dag: DAG[t.Any],
331333
prefix: str,
334+
dagster_sqlmesh_translator: "SQLMeshDagsterTranslator",
332335
is_testing: bool = False,
333336
materializations_enabled: bool = True,
334337
) -> None:
@@ -341,6 +344,7 @@ def __init__(
341344
models_map: A mapping of model names to their SQLMesh model instances.
342345
dag: The directed acyclic graph representing the SQLMesh models.
343346
prefix: A prefix to use for all asset keys generated by this handler.
347+
dagster_sqlmesh_translator: The SQLMesh Dagster translator instance.
344348
is_testing: Whether the handler is being used in a testing context.
345349
materializations_enabled: Whether the handler is to generate
346350
materializations, this should be disabled if you with to run a
@@ -351,6 +355,7 @@ def __init__(
351355
self._prefix = prefix
352356
self._context = context
353357
self._logger = context.log
358+
self._translator = dagster_sqlmesh_translator
354359
self._tracker = MaterializationTracker(
355360
sorted_dag=dag.sorted[:], logger=self._logger
356361
)
@@ -382,7 +387,7 @@ def notify_success(
382387
# If the model is not in models_map, we can skip any notification
383388
if model:
384389
# Passing model.fqn to get internal unique asset key
385-
output_key = get_asset_key_str(model.fqn)
390+
output_key = self._translator.get_asset_key_str(model.fqn)
386391
if self._is_testing:
387392
asset_key = dg.AssetKey(["testing", output_key])
388393
self._logger.warning(
@@ -491,15 +496,15 @@ def report_event(self, event: console.ConsoleEvent) -> None:
491496
log_context.info(
492497
"Snapshot progress complete",
493498
{
494-
"asset_key": get_asset_key_str(snapshot.model.name),
499+
"asset_key": self._translator.get_asset_key_str(snapshot.model.name),
495500
},
496501
)
497502
self._tracker.update_run(snapshot)
498503
else:
499504
log_context.info(
500505
"Snapshot progress update",
501506
{
502-
"asset_key": get_asset_key_str(snapshot.model.name),
507+
"asset_key": self._translator.get_asset_key_str(snapshot.model.name),
503508
"progress": f"{done}/{expected}",
504509
"duration_ms": duration_ms,
505510
},
@@ -687,11 +692,13 @@ def create_event_handler(
687692
is_testing: bool,
688693
materializations_enabled: bool,
689694
) -> DagsterSQLMeshEventHandler:
695+
translator = self.config.get_translator()
690696
return DagsterSQLMeshEventHandler(
691697
context=context,
692698
dag=dag,
693699
models_map=models_map,
694700
prefix=prefix,
701+
dagster_sqlmesh_translator=translator,
695702
is_testing=is_testing,
696703
materializations_enabled=materializations_enabled,
697704
)
@@ -701,7 +708,7 @@ def _get_selected_models_from_context(
701708
) -> tuple[set[str], dict[str, Model], list[str] | None]:
702709
models_map = models.copy()
703710
try:
704-
selected_output_names = set(context.selected_output_names)
711+
selected_output_names = set(context.op_execution_context.selected_output_names)
705712
except (DagsterInvalidPropertyError, AttributeError) as e:
706713
# Special case for direct execution context when testing. This is related to:
707714
# https://github.com/dagster-io/dagster/issues/23633
@@ -711,10 +718,11 @@ def _get_selected_models_from_context(
711718
else:
712719
raise e
713720

721+
translator = self.config.get_translator()
714722
select_models: list[str] = []
715723
models_map = {}
716724
for key, model in models.items():
717-
if get_asset_key_str(model.fqn) in selected_output_names:
725+
if translator.get_asset_key_str(model.fqn) in selected_output_names:
718726
models_map[key] = model
719727
select_models.append(model.name)
720728
return (

dagster_sqlmesh/test_asset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from dagster_sqlmesh.asset import SQLMeshDagsterTranslator
21
from dagster_sqlmesh.conftest import SQLMeshTestContext
32

43

54
def test_sqlmesh_context_to_asset_outs(sample_sqlmesh_test_context: SQLMeshTestContext):
65
controller = sample_sqlmesh_test_context.create_controller()
7-
translator = SQLMeshDagsterTranslator()
8-
outs = controller.to_asset_outs("dev", translator)
6+
translator = sample_sqlmesh_test_context.context_config.get_translator()
7+
outs = controller.to_asset_outs("dev", dagster_sqlmesh_translator=translator)
98
assert len(list(outs.deps)) == 1
109
assert len(outs.outs) == 10

0 commit comments

Comments
 (0)