diff --git a/docs/examples/sqlmesh_cli_crash_course.md b/docs/examples/sqlmesh_cli_crash_course.md index 0bf5780f12..c7fdf0a5f2 100644 --- a/docs/examples/sqlmesh_cli_crash_course.md +++ b/docs/examples/sqlmesh_cli_crash_course.md @@ -675,6 +675,8 @@ This is a great way to catch SQL issues before wasting runtime in your data ware ```bash sqlmesh lint + # or apply fixes automatically + sqlmesh lint --fix ``` === "Tobiko Cloud" diff --git a/docs/guides/linter.md b/docs/guides/linter.md index 22cc5077b8..101eb4d213 100644 --- a/docs/guides/linter.md +++ b/docs/guides/linter.md @@ -100,7 +100,7 @@ Place a rule's code in the project's `linter/` directory. SQLMesh will load all If the rule is specified in the project's [configuration file](#applying-linting-rules), SQLMesh will run it when: - A plan is created during `sqlmesh plan` -- The command `sqlmesh lint` is ran +- The command `sqlmesh lint` is run. Add `--fix` to automatically apply available fixes and fail if errors remain. SQLMesh will error if a model violates the rule, informing you which model(s) violated the rule. In this example, `full_model.sql` violated the `NoMissingOwner` rule, essentially halting execution: diff --git a/docs/reference/cli.md b/docs/reference/cli.md index b6877962ab..4d25d8df94 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -636,6 +636,7 @@ Usage: sqlmesh lint [OPTIONS] Options: --model TEXT A model to lint. Multiple models can be linted. If no models are specified, every model will be linted. + --fix Apply fixes for lint errors. Fails if errors remain after fixes are applied. --help Show this message and exit. ``` diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index 8982efc9f8..23de79dd30 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -1201,15 +1201,21 @@ def environments(obj: Context) -> None: multiple=True, help="A model to lint. Multiple models can be linted. 
If no models are specified, every model will be linted.", ) +@click.option( + "--fix", + is_flag=True, + help="Apply fixes for lint errors. Fails if errors remain after fixes are applied.", +) @click.pass_obj @error_handler @cli_analytics def lint( obj: Context, models: t.Iterator[str], + fix: bool, ) -> None: """Run the linter for the target model(s).""" - obj.lint_models(models) + obj.lint_models(models, fix=fix) @cli.group(no_args_is_help=True) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index c0d9b21ff8..9983b005e9 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -78,6 +78,7 @@ from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements from sqlmesh.core.loader import Loader from sqlmesh.core.linter.definition import AnnotatedRuleViolation, Linter +from sqlmesh.core.linter.rule import TextEdit, Position from sqlmesh.core.linter.rules import BUILTIN_RULES from sqlmesh.core.macros import ExecutableOrMacro, macro from sqlmesh.core.metric import Metric, rewrite @@ -3099,6 +3100,7 @@ def lint_models( self, models: t.Optional[t.Iterable[t.Union[str, Model]]] = None, raise_on_error: bool = True, + fix: bool = False, ) -> t.List[AnnotatedRuleViolation]: found_error = False @@ -3116,6 +3118,11 @@ def lint_models( found_error = True all_violations.extend(violations) + if fix: + self._apply_fixes(all_violations) + self.refresh() + return self.lint_models(models, raise_on_error=raise_on_error, fix=False) + if raise_on_error and found_error: raise LinterError( "Linter detected errors in the code. Please fix them before proceeding." 
def _apply_fixes(self, violations: t.List[AnnotatedRuleViolation]) -> None:
    """Write the automatic fixes attached to lint violations to disk.

    Each fix can create files (``fix.create_files``) and replace ranges of
    text in existing files (``fix.edits``). The final content of every
    affected file is computed in memory first and nothing is written until
    all edits of all files have been validated, so a bad edit cannot leave
    the project half-fixed.

    Edit positions are zero-based ``(line, character)`` pairs. Only line
    feeds delimit lines (files are read with universal newlines, so CRLF and
    lone CR are covered as well). ``str.splitlines`` is deliberately not
    used: it also breaks on form feeds, U+2028 and similar characters, which
    shifts every following line number.
    NOTE(review): this assumes rule positions follow LSP conventions (line
    feeds as line breaks, a character past the end of a line means the end
    of that line) -- confirm against the rule implementations.

    An edit that overlaps another edit located further down in the same file
    (exact duplicates reported by several violations included) is skipped:
    applying both would corrupt the text. ``lint_models`` lints again after
    the fixes were applied, so a violation left unfixed is still reported.

    Args:
        violations: Violations whose ``fixes`` should be applied.

    Raises:
        IndexError: If an edit refers to a line that does not exist in the
            file it targets.
        ValueError: If an edit's range ends before it starts.
    """

    def _patched(path: Path, content: str, edits: t.List[TextEdit]) -> str:
        """Return ``content`` with all non-overlapping ``edits`` applied."""
        # Offset of the first character of every line. The text after the
        # last line feed (possibly empty) counts as a line of its own.
        line_starts = [0]
        for line in content.split("\n")[:-1]:
            line_starts.append(line_starts[-1] + len(line) + 1)

        def _offset(pos: Position) -> int:
            if not 0 <= pos.line < len(line_starts):
                raise IndexError(
                    f"Lint fix for '{path}' refers to line {pos.line}, "
                    f"but the file only has {len(line_starts)} line(s)."
                )
            line_start = line_starts[pos.line]
            if pos.line + 1 < len(line_starts):
                line_end = line_starts[pos.line + 1] - 1  # exclude the line feed
            else:
                line_end = len(content)
            # Keep the offset inside its own line instead of letting it
            # spill into a neighbouring line.
            return min(max(line_start + pos.character, line_start), line_end)

        spans = []
        for edit in edits:
            start = _offset(edit.range.start)
            end = _offset(edit.range.end)
            if end < start:
                raise ValueError(
                    f"Lint fix for '{path}' has a range that ends before it starts."
                )
            spans.append((start, end, edit.new_text))

        # Work from the bottom of the file upwards so that the offsets of the
        # edits still to be applied stay valid. The sort is stable, so edits
        # with identical ranges keep their original relative order.
        spans.sort(key=lambda span: span[:2], reverse=True)

        pieces = []
        boundary = len(content)  # start offset of the last applied edit
        for start, end, new_text in spans:
            if end > boundary:
                # Overlaps an edit that was already applied further down.
                continue
            pieces.append(content[end:boundary])
            pieces.append(new_text)
            boundary = start
        pieces.append(content[:boundary])
        return "".join(reversed(pieces))

    created: t.Dict[Path, str] = {}
    edits_by_file: t.Dict[Path, t.List[TextEdit]] = {}
    for violation in violations:
        for fix in violation.fixes:
            for create in fix.create_files:
                created[create.path] = create.text
            for edit in fix.edits:
                edits_by_file.setdefault(edit.path, []).append(edit)

    # Phase 1: compute (and thereby validate) every edited file in memory.
    edited: t.Dict[Path, str] = {}
    for path, edits in edits_by_file.items():
        if path in created:
            # Edits may target a file that is created by the same batch.
            original = created[path]
        else:
            original = path.read_text(encoding="utf-8")
        edited[path] = _patched(path, original, edits)

    # Phase 2: everything is valid, so it is safe to touch the disk.
    for path, text in created.items():
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(text, encoding="utf-8")
    for path, text in edited.items():
        path.write_text(text, encoding="utf-8")
def test_lint_fix_unfixable_error(runner, tmp_path):
    """Check `lint --fix` when one violation is rewritten and another one remains.

    Expects exit code 1, "nomissingaudits" in the output, and no `SELECT *`
    left in the model file.
    """
    create_example_project(tmp_path)

    # NOTE(review): the reviewed diff had its whitespace collapsed, so the
    # indentation inside the two string literals below could not be recovered
    # exactly -- confirm both against the original patch.
    linter_config = """linter:
    enabled: True
    rules: ["noselectstar", "nomissingaudits"]
"""
    with (tmp_path / "config.yaml").open("a", encoding="utf-8") as config_file:
        config_file.write(linter_config)

    # Turn the explicit column list of the example model into `SELECT *`.
    model_path = tmp_path / "models" / "incremental_model.sql"
    starred_sql = model_path.read_text(encoding="utf-8").replace(
        "SELECT\n id,\n item_id,\n event_date,\nFROM",
        "SELECT *\nFROM",
    )
    model_path.write_text(starred_sql, encoding="utf-8")

    result = runner.invoke(cli, ["--paths", tmp_path, "lint", "--fix"])

    assert result.exit_code == 1
    assert "nomissingaudits" in result.output
    assert "SELECT *" not in model_path.read_text(encoding="utf-8")