@@ -128,6 +128,23 @@ def _macro_str_replace(text: str) -> str:
128
128
return f"self.template({ text } , locals())"
129
129
130
130
131
+ class CaseInsensitiveMapping (dict ):
132
+ def __init__ (self , data : t .Dict [str , t .Any ]) -> None :
133
+ super ().__init__ (data )
134
+
135
+ self ._lower = {k .lower (): v for k , v in data .items ()}
136
+
137
+ def __getitem__ (self , key : str ) -> t .Any :
138
+ if key in self :
139
+ return super ().__getitem__ (key )
140
+ return self ._lower [key .lower ()]
141
+
142
+ def get (self , key : str , default : t .Any = None ) -> t .Any :
143
+ if key in self :
144
+ return super ().get (key , default )
145
+ return self ._lower .get (key .lower (), default )
146
+
147
+
131
148
class MacroDialect (Python ):
132
149
class Generator (Python .Generator ):
133
150
TRANSFORMS = {
@@ -222,7 +239,7 @@ def __init__(
222
239
for var_name , var_value in value .items ()
223
240
}
224
241
225
- self .locals [k ] = value
242
+ self .locals [k . lower () ] = value
226
243
227
244
def send (
228
245
self , name : str , * args : t .Any , ** kwargs : t .Any
@@ -256,14 +273,18 @@ def evaluate_macros(
256
273
changed = True
257
274
variables = self .variables
258
275
259
- if node .name not in self .locals and node .name .lower () not in variables :
276
+ # This makes all variables case-insensitive, e.g. @X is the same as @x. We do this
277
+ # for consistency, since `variables` and `blueprint_variables` are normalized.
278
+ var_name = node .name .lower ()
279
+
280
+ if var_name not in self .locals and var_name not in variables :
260
281
if not isinstance (node .parent , StagedFilePath ):
261
- raise SQLMeshError (f"Macro variable '{ node . name } ' is undefined." )
282
+ raise SQLMeshError (f"Macro variable '{ var_name } ' is undefined." )
262
283
263
284
return node
264
285
265
286
# Precedence order is locals (e.g. @DEF) > blueprint variables > config variables
266
- value = self .locals .get (node . name , variables .get (node . name . lower () ))
287
+ value = self .locals .get (var_name , variables .get (var_name ))
267
288
if isinstance (value , list ):
268
289
return exp .convert (
269
290
tuple (
@@ -313,11 +334,11 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
313
334
"""
314
335
# We try to convert all variables into sqlglot expressions because they're going to be converted
315
336
# into strings; in sql we don't convert strings because that would result in adding quotes
316
- mapping = {
337
+ base_mapping = {
317
338
k : convert_sql (v , self .dialect )
318
339
for k , v in chain (self .variables .items (), self .locals .items (), local_variables .items ())
319
340
}
320
- return MacroStrTemplate (str (text )).safe_substitute (mapping )
341
+ return MacroStrTemplate (str (text )).safe_substitute (CaseInsensitiveMapping ( base_mapping ) )
321
342
322
343
def evaluate (self , node : MacroFunc ) -> exp .Expression | t .List [exp .Expression ] | None :
323
344
if isinstance (node , MacroDef ):
@@ -327,7 +348,7 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] |
327
348
args [0 ] if len (args ) == 1 else exp .Tuple (expressions = list (args ))
328
349
)
329
350
else :
330
- self .locals [node .name ] = self .transform (node .expression )
351
+ self .locals [node .name . lower () ] = self .transform (node .expression )
331
352
return node
332
353
333
354
if isinstance (node , (MacroSQL , MacroStrReplace )):
0 commit comments