Skip to content

Commit dee692e

Browse files
committed
Fix: treat all instances of macro variables as case-insensitive
1 parent dc302eb commit dee692e

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

sqlmesh/core/macros.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,23 @@ def _macro_str_replace(text: str) -> str:
128128
return f"self.template({text}, locals())"
129129

130130

131+
class CaseInsensitiveMapping(dict):
132+
def __init__(self, data: t.Dict[str, t.Any]) -> None:
133+
super().__init__(data)
134+
135+
self._lower = {k.lower(): v for k, v in data.items()}
136+
137+
def __getitem__(self, key: str) -> t.Any:
138+
if key in self:
139+
return super().__getitem__(key)
140+
return self._lower[key.lower()]
141+
142+
def get(self, key: str, default: t.Any = None) -> t.Any:
143+
if key in self:
144+
return super().get(key, default)
145+
return self._lower.get(key.lower(), default)
146+
147+
131148
class MacroDialect(Python):
132149
class Generator(Python.Generator):
133150
TRANSFORMS = {
@@ -222,7 +239,7 @@ def __init__(
222239
for var_name, var_value in value.items()
223240
}
224241

225-
self.locals[k] = value
242+
self.locals[k.lower()] = value
226243

227244
def send(
228245
self, name: str, *args: t.Any, **kwargs: t.Any
@@ -256,14 +273,18 @@ def evaluate_macros(
256273
changed = True
257274
variables = self.variables
258275

259-
if node.name not in self.locals and node.name.lower() not in variables:
276+
# This makes all variables case-insensitive, e.g. @X is the same as @x. We do this
277+
# for consistency, since `variables` and `blueprint_variables` are normalized.
278+
var_name = node.name.lower()
279+
280+
if var_name not in self.locals and var_name not in variables:
260281
if not isinstance(node.parent, StagedFilePath):
261-
raise SQLMeshError(f"Macro variable '{node.name}' is undefined.")
282+
raise SQLMeshError(f"Macro variable '{var_name}' is undefined.")
262283

263284
return node
264285

265286
# Precedence order is locals (e.g. @DEF) > blueprint variables > config variables
266-
value = self.locals.get(node.name, variables.get(node.name.lower()))
287+
value = self.locals.get(var_name, variables.get(var_name))
267288
if isinstance(value, list):
268289
return exp.convert(
269290
tuple(
@@ -313,11 +334,11 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
313334
"""
314335
# We try to convert all variables into sqlglot expressions because they're going to be converted
315336
# into strings; in sql we don't convert strings because that would result in adding quotes
316-
mapping = {
337+
base_mapping = {
317338
k: convert_sql(v, self.dialect)
318339
for k, v in chain(self.variables.items(), self.locals.items(), local_variables.items())
319340
}
320-
return MacroStrTemplate(str(text)).safe_substitute(mapping)
341+
return MacroStrTemplate(str(text)).safe_substitute(CaseInsensitiveMapping(base_mapping))
321342

322343
def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None:
323344
if isinstance(node, MacroDef):
@@ -327,7 +348,7 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] |
327348
args[0] if len(args) == 1 else exp.Tuple(expressions=list(args))
328349
)
329350
else:
330-
self.locals[node.name] = self.transform(node.expression)
351+
self.locals[node.name.lower()] = self.transform(node.expression)
331352
return node
332353

333354
if isinstance(node, (MacroSQL, MacroStrReplace)):

tests/core/test_macros.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,9 @@ def test_macro_with_spaces():
11121112

11131113
for sql, expected in (
11141114
("@x", '"a b"'),
1115+
("@X", '"a b"'),
11151116
("@{x}", '"a b"'),
1117+
("@{X}", '"a b"'),
11161118
("a_@x", '"a_a b"'),
11171119
("a.@x", 'a."a b"'),
11181120
("@y", "'a b'"),
@@ -1121,6 +1123,7 @@ def test_macro_with_spaces():
11211123
("a.@{y}", 'a."a b"'),
11221124
("@z", 'a."b c"'),
11231125
("d.@z", 'd.a."b c"'),
1126+
("@'test_@{X}_suffix'", "'test_a b_suffix'"),
11241127
):
11251128
assert evaluator.transform(parse_one(sql)).sql() == expected
11261129

0 commit comments

Comments
 (0)