Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ from dagster import (
AssetExecutionContext,
Definitions,
)
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource, SQLMeshDagsterTranslator
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource

sqlmesh_config = SQLMeshContextConfig(path="/home/foo/sqlmesh_project", gateway="name-of-your-gateway")

@sqlmesh_assets(environment="dev", config=sqlmesh_config, translator=SQLMeshDagsterTranslator())
@sqlmesh_assets(environment="dev", config=sqlmesh_config)
def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
yield from sqlmesh.run(context)

Expand All @@ -40,6 +40,34 @@ defs = Definitions(
)
```

## Advanced Usage

### Custom Translator

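The translator is configured in one place, on `SQLMeshContextConfig`, so every component that needs a translator resolves it from the same setting. To customize it, subclass `SQLMeshDagsterTranslator` and set `translator_class_name` in the config to the fully qualified name of your subclass: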

```python
from dagster_sqlmesh import SQLMeshDagsterTranslator

class CustomSQLMeshTranslator(SQLMeshDagsterTranslator):
def get_asset_key_str(self, fqn: str) -> str:
# Custom asset key generation logic
return f"custom_prefix__{super().get_asset_key_str(fqn)}"

# Configure with custom translator
sqlmesh_config = SQLMeshContextConfig(
path="/home/foo/sqlmesh_project",
gateway="name-of-your-gateway",
translator_class_name="your_module.CustomSQLMeshTranslator"
)

@sqlmesh_assets(environment="dev", config=sqlmesh_config)
def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
yield from sqlmesh.run(context)
```

This approach ensures that both the `SQLMeshResource` and the `@sqlmesh_assets` decorator use the same translator class, preventing inconsistencies. Each component that needs a translator, including the `DagsterSQLMeshEventHandler`, obtains one from `config.get_translator()`, which creates a new instance of the configured class on every call.


## Contributing

Expand Down
9 changes: 2 additions & 7 deletions dagster_sqlmesh/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
ContextFactory,
DagsterSQLMeshController,
)
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
from dagster_sqlmesh.types import SQLMeshMultiAssetOptions

logger = logging.getLogger(__name__)
Expand All @@ -20,20 +19,18 @@ def sqlmesh_to_multi_asset_options(
environment: str,
config: SQLMeshContextConfig,
context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs),
dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None,
) -> SQLMeshMultiAssetOptions:
"""Converts sqlmesh project into a SQLMeshMultiAssetOptions object which is
an intermediate representation of the SQLMesh project that can be used to
create a dagster multi_asset definition."""
controller = DagsterSQLMeshController.setup_with_config(
config=config, context_factory=context_factory
)
if not dagster_sqlmesh_translator:
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
translator = config.get_translator()

conversion = controller.to_asset_outs(
environment,
translator=dagster_sqlmesh_translator,
translator=translator,
)
return conversion

Expand Down Expand Up @@ -74,7 +71,6 @@ def sqlmesh_assets(
config: SQLMeshContextConfig,
context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs),
name: str | None = None,
dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None,
compute_kind: str = "sqlmesh",
op_tags: t.Mapping[str, t.Any] | None = None,
required_resource_keys: set[str] | None = None,
Expand All @@ -86,7 +82,6 @@ def sqlmesh_assets(
environment=environment,
config=config,
context_factory=context_factory,
dagster_sqlmesh_translator=dagster_sqlmesh_translator,
)

return sqlmesh_asset_from_multi_asset_options(
Expand Down
57 changes: 53 additions & 4 deletions dagster_sqlmesh/config.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import inspect
import typing as t
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from dagster import Config
from pydantic import Field
from sqlmesh.core.config import Config as MeshConfig
from sqlmesh.core.config.loader import load_configs

if t.TYPE_CHECKING:
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator


@dataclass
class ConfigOverride:
config_as_dict: dict[str, Any]
config_as_dict: dict[str, t.Any]

def dict(self) -> dict[str, Any]:
def dict(self) -> dict[str, t.Any]:
return self.config_as_dict


Expand All @@ -22,11 +26,56 @@ class SQLMeshContextConfig(Config):
sqlmesh project define all the configuration in its own directory which
also ensures that configuration is consistent if running sqlmesh locally vs
running via dagster.

The config also manages the translator class used for converting SQLMesh
models to Dagster assets. You can specify a custom translator by setting
the translator_class_name field to the fully qualified class name.
"""

path: str
gateway: str
config_override: dict[str, Any] | None = Field(default_factory=lambda: None)
config_override: dict[str, t.Any] | None = Field(default_factory=lambda: None)
translator_class_name: str = Field(
default="dagster_sqlmesh.translator.SQLMeshDagsterTranslator",
description="Fully qualified class name of the SQLMesh Dagster translator to use"
)

def get_translator(self) -> "SQLMeshDagsterTranslator":
    """Get a translator instance using the configured class name.

    Splits ``translator_class_name`` into a module path and a class name,
    imports the module, validates the resolved attribute and returns a new
    instance of it. A fresh instance is created on every call.

    Returns:
        SQLMeshDagsterTranslator: A new instance of the configured
            translator class.

    Raises:
        ValueError: If ``translator_class_name`` has no module part (it is
            not of the form ``package.module.ClassName``), if the resolved
            object is not a class, or if the class does not inherit from
            SQLMeshDagsterTranslator.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the class name.
    """
    from importlib import import_module

    # Validate the shape of the name first so a misconfiguration such as
    # "MyTranslator" yields an actionable message rather than the opaque
    # "not enough values to unpack" error from tuple-unpacking a split.
    module_name, _, class_name = self.translator_class_name.rpartition(".")
    if not module_name:
        raise ValueError(
            f"'{self.translator_class_name}' is not a fully qualified class "
            "name. Expected the form 'package.module.ClassName'."
        )

    # Imported here rather than at module level; the module-level import of
    # this name is guarded by TYPE_CHECKING.
    from dagster_sqlmesh.translator import SQLMeshDagsterTranslator

    module = import_module(module_name)
    translator_class = getattr(module, class_name)

    # issubclass() requires a class as its first argument, so reject
    # non-class attributes (functions, constants, ...) before calling it.
    if not inspect.isclass(translator_class):
        raise ValueError(
            f"'{self.translator_class_name}' is not a class. "
            "Expected a class that inherits from SQLMeshDagsterTranslator."
        )

    if not issubclass(translator_class, SQLMeshDagsterTranslator):
        raise ValueError(
            f"Translator class '{self.translator_class_name}' must inherit from "
            f"SQLMeshDagsterTranslator. Found class that inherits from: "
            f"{[base.__name__ for base in translator_class.__bases__]}"
        )

    return translator_class()

@property
def sqlmesh_config(self) -> MeshConfig:
Expand Down
5 changes: 2 additions & 3 deletions dagster_sqlmesh/controller/dagster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
SQLMeshModelDep,
SQLMeshMultiAssetOptions,
)
from dagster_sqlmesh.utils import get_asset_key_str

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,7 +52,7 @@ def to_asset_outs(

internal_asset_deps.add(dep_asset_key_str)
else:
table = get_asset_key_str(dep.fqn)
table = translator.get_asset_key_str(dep.fqn)
key = translator.get_asset_key(
context, dep.fqn
).to_user_string()
Expand All @@ -62,7 +61,7 @@ def to_asset_outs(
# create an external dep
deps_map[table] = translator.create_asset_dep(key=key)

model_key = get_asset_key_str(model.fqn)
model_key = translator.get_asset_key_str(model.fqn)
asset_outs[model_key] = translator.create_asset_out(
model_key=model_key,
asset_key=asset_key_str,
Expand Down
20 changes: 14 additions & 6 deletions dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
ContextFactory,
)
from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
from dagster_sqlmesh.utils import get_asset_key_str

if t.TYPE_CHECKING:
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -329,6 +331,7 @@ def __init__(
models_map: dict[str, Model],
dag: DAG[t.Any],
prefix: str,
translator: "SQLMeshDagsterTranslator",
is_testing: bool = False,
materializations_enabled: bool = True,
) -> None:
Expand All @@ -341,6 +344,7 @@ def __init__(
models_map: A mapping of model names to their SQLMesh model instances.
dag: The directed acyclic graph representing the SQLMesh models.
prefix: A prefix to use for all asset keys generated by this handler.
translator: The SQLMesh Dagster translator instance.
is_testing: Whether the handler is being used in a testing context.
materializations_enabled: Whether the handler is to generate
materializations, this should be disabled if you wish to run a
Expand All @@ -351,6 +355,7 @@ def __init__(
self._prefix = prefix
self._context = context
self._logger = context.log
self._translator = translator
self._tracker = MaterializationTracker(
sorted_dag=dag.sorted[:], logger=self._logger
)
Expand Down Expand Up @@ -382,7 +387,7 @@ def notify_success(
# If the model is not in models_map, we can skip any notification
if model:
# Passing model.fqn to get internal unique asset key
output_key = get_asset_key_str(model.fqn)
output_key = self._translator.get_asset_key_str(model.fqn)
if self._is_testing:
asset_key = dg.AssetKey(["testing", output_key])
self._logger.warning(
Expand Down Expand Up @@ -491,15 +496,15 @@ def report_event(self, event: console.ConsoleEvent) -> None:
log_context.info(
"Snapshot progress complete",
{
"asset_key": get_asset_key_str(snapshot.model.name),
"asset_key": self._translator.get_asset_key_str(snapshot.model.name),
},
)
self._tracker.update_run(snapshot)
else:
log_context.info(
"Snapshot progress update",
{
"asset_key": get_asset_key_str(snapshot.model.name),
"asset_key": self._translator.get_asset_key_str(snapshot.model.name),
"progress": f"{done}/{expected}",
"duration_ms": duration_ms,
},
Expand Down Expand Up @@ -687,11 +692,13 @@ def create_event_handler(
is_testing: bool,
materializations_enabled: bool,
) -> DagsterSQLMeshEventHandler:
translator = self.config.get_translator()
return DagsterSQLMeshEventHandler(
context=context,
dag=dag,
models_map=models_map,
prefix=prefix,
translator=translator,
is_testing=is_testing,
materializations_enabled=materializations_enabled,
)
Expand All @@ -701,7 +708,7 @@ def _get_selected_models_from_context(
) -> tuple[set[str], dict[str, Model], list[str] | None]:
models_map = models.copy()
try:
selected_output_names = set(context.selected_output_names)
selected_output_names = set(context.op_execution_context.selected_output_names)
except (DagsterInvalidPropertyError, AttributeError) as e:
# Special case for direct execution context when testing. This is related to:
# https://github.com/dagster-io/dagster/issues/23633
Expand All @@ -711,10 +718,11 @@ def _get_selected_models_from_context(
else:
raise e

translator = self.config.get_translator()
select_models: list[str] = []
models_map = {}
for key, model in models.items():
if get_asset_key_str(model.fqn) in selected_output_names:
if translator.get_asset_key_str(model.fqn) in selected_output_names:
models_map[key] = model
select_models.append(model.name)
return (
Expand Down
5 changes: 2 additions & 3 deletions dagster_sqlmesh/test_asset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from dagster_sqlmesh.asset import SQLMeshDagsterTranslator
from dagster_sqlmesh.conftest import SQLMeshTestContext


def test_sqlmesh_context_to_asset_outs(sample_sqlmesh_test_context: SQLMeshTestContext):
controller = sample_sqlmesh_test_context.create_controller()
translator = SQLMeshDagsterTranslator()
outs = controller.to_asset_outs("dev", translator)
translator = sample_sqlmesh_test_context.context_config.get_translator()
outs = controller.to_asset_outs("dev", translator=translator)
assert len(list(outs.deps)) == 1
assert len(outs.outs) == 10
65 changes: 65 additions & 0 deletions dagster_sqlmesh/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import pytest

from dagster_sqlmesh.config import SQLMeshContextConfig
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator


def test_get_translator_with_valid_class():
    """The default configuration resolves to a SQLMeshDagsterTranslator."""
    default_config = SQLMeshContextConfig(path="/tmp/test", gateway="local")
    assert isinstance(default_config.get_translator(), SQLMeshDagsterTranslator)


def test_get_translator_with_non_class():
    """A dotted path that resolves to a non-class object is rejected with ValueError."""
    settings = {
        "path": "/tmp/test",
        "gateway": "local",
        # sys.version is a plain string attribute, not a class.
        "translator_class_name": "sys.version",
    }
    non_class_config = SQLMeshContextConfig(**settings)

    with pytest.raises(ValueError, match="is not a class"):
        non_class_config.get_translator()


def test_get_translator_with_invalid_inheritance():
    """A class that is not a SQLMeshDagsterTranslator subclass is rejected with ValueError."""
    settings = {
        "path": "/tmp/test",
        "gateway": "local",
        # dict is a real class, but unrelated to SQLMeshDagsterTranslator.
        "translator_class_name": "builtins.dict",
    }
    unrelated_class_config = SQLMeshContextConfig(**settings)

    with pytest.raises(ValueError, match="must inherit from SQLMeshDagsterTranslator"):
        unrelated_class_config.get_translator()


def test_get_translator_with_nonexistent_class():
    """A class name missing from an importable module surfaces as AttributeError."""
    settings = {
        "path": "/tmp/test",
        "gateway": "local",
        "translator_class_name": "dagster_sqlmesh.translator.NonexistentClass",
    }
    missing_class_config = SQLMeshContextConfig(**settings)

    with pytest.raises(AttributeError):
        missing_class_config.get_translator()


class MockValidTranslator(SQLMeshDagsterTranslator):
    """Minimal SQLMeshDagsterTranslator subclass used to test custom translator loading."""


def test_get_translator_with_valid_custom_class():
    """A custom SQLMeshDagsterTranslator subclass is imported and instantiated."""
    # Refer to the mock class by its fully qualified name within this module.
    qualified_name = f"{__name__}.MockValidTranslator"
    custom_config = SQLMeshContextConfig(
        path="/tmp/test", gateway="local", translator_class_name=qualified_name
    )

    resolved = custom_config.get_translator()
    for expected_type in (SQLMeshDagsterTranslator, MockValidTranslator):
        assert isinstance(resolved, expected_type)
2 changes: 2 additions & 0 deletions dagster_sqlmesh/testing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def create_event_handler(self, *args: t.Any, **kwargs: t.Any) -> DagsterSQLMeshE
Returns:
DagsterSQLMeshEventHandler: The created event handler.
"""
# Ensure translator is passed to the event handler factory
kwargs['translator'] = self.config.get_translator()
return self._event_handler_factory(*args, **kwargs)


Expand Down
Loading