Skip to content

Commit 1445b20

Browse files
committed
Pre-parser
1 parent fa90e98 commit 1445b20

File tree

7 files changed

+126
-41
lines changed

7 files changed

+126
-41
lines changed

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ clusterBySpec
493493
bucketSpec
494494
: CLUSTERED BY identifierList
495495
(SORTED BY orderedIdentifierList)?
496-
INTO INTEGER_VALUE BUCKETS
496+
INTO integerValue BUCKETS
497497
;
498498

499499
skewSpec
@@ -583,7 +583,7 @@ ctes
583583
;
584584

585585
namedQuery
586-
: name=errorCapturingIdentifier (columnAliases=identifierList)? (MAX RECURSION LEVEL INTEGER_VALUE)? AS? LEFT_PAREN query RIGHT_PAREN
586+
: name=errorCapturingIdentifier (columnAliases=identifierList)? (MAX RECURSION LEVEL integerValue)? AS? LEFT_PAREN query RIGHT_PAREN
587587
;
588588

589589
tableProvider
@@ -970,13 +970,13 @@ joinCriteria
970970
;
971971

972972
sample
973-
: TABLESAMPLE LEFT_PAREN sampleMethod? RIGHT_PAREN (REPEATABLE LEFT_PAREN seed=INTEGER_VALUE RIGHT_PAREN)?
973+
: TABLESAMPLE LEFT_PAREN sampleMethod? RIGHT_PAREN (REPEATABLE LEFT_PAREN seed=integerValue RIGHT_PAREN)?
974974
;
975975

976-
sampleMethod
977-
: negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile
976+
sampleMethod
977+
: negativeSign=MINUS? (integerValue | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile
978978
| expression ROWS #sampleByRows
979-
| sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE
979+
| sampleType=BUCKET numerator=integerValue OUT OF denominator=integerValue
980980
(ON (identifier | qualifiedName LEFT_PAREN RIGHT_PAREN))? #sampleByBucket
981981
| bytes=expression #sampleByBytes
982982
;
@@ -1249,13 +1249,13 @@ jsonPathBracketedIdentifier
12491249
jsonPathFirstPart
12501250
: jsonPathIdentifier
12511251
| jsonPathBracketedIdentifier
1252-
| LEFT_BRACKET INTEGER_VALUE RIGHT_BRACKET
1252+
| LEFT_BRACKET integerValue RIGHT_BRACKET
12531253
;
12541254

12551255
jsonPathParts
12561256
: DOT jsonPathIdentifier
12571257
| jsonPathBracketedIdentifier
1258-
| LEFT_BRACKET INTEGER_VALUE RIGHT_BRACKET
1258+
| LEFT_BRACKET integerValue RIGHT_BRACKET
12591259
| LEFT_BRACKET identifier RIGHT_BRACKET
12601260
;
12611261

@@ -1271,14 +1271,17 @@ literalType
12711271
constant
12721272
: NULL #nullLiteral
12731273
| QUESTION #posParameterLiteral
1274-
| COLON identifier #namedParameterLiteral
12751274
| interval #intervalLiteral
12761275
| literalType stringLit #typeConstructor
12771276
| number #numericLiteral
12781277
| booleanValue #booleanLiteral
12791278
| stringLit+ #stringLiteral
1279+
| namedParameterMarker #namedParameterLiteral
12801280
;
12811281

1282+
namedParameterMarker
1283+
: COLON identifier
1284+
;
12821285
comparisonOperator
12831286
: EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ
12841287
;
@@ -1344,15 +1347,15 @@ collateClause
13441347

13451348
nonTrivialPrimitiveType
13461349
: STRING collateClause?
1347-
| (CHARACTER | CHAR) (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)?
1348-
| VARCHAR (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)?
1350+
| (CHARACTER | CHAR) (LEFT_PAREN length=integerValue RIGHT_PAREN)?
1351+
| VARCHAR (LEFT_PAREN length=integerValue RIGHT_PAREN)?
13491352
| (DECIMAL | DEC | NUMERIC)
1350-
(LEFT_PAREN precision=INTEGER_VALUE (COMMA scale=INTEGER_VALUE)? RIGHT_PAREN)?
1353+
(LEFT_PAREN precision=integerValue (COMMA scale=integerValue)? RIGHT_PAREN)?
13511354
| INTERVAL
13521355
(fromYearMonth=(YEAR | MONTH) (TO to=MONTH)? |
13531356
fromDayTime=(DAY | HOUR | MINUTE | SECOND) (TO to=(HOUR | MINUTE | SECOND))?)?
13541357
| TIMESTAMP (WITHOUT TIME ZONE)?
1355-
| TIME (LEFT_PAREN precision=INTEGER_VALUE RIGHT_PAREN)? (WITHOUT TIME ZONE)?
1358+
| TIME (LEFT_PAREN precision=integerValue RIGHT_PAREN)? (WITHOUT TIME ZONE)?
13561359
;
13571360

13581361
trivialPrimitiveType
@@ -1373,7 +1376,7 @@ trivialPrimitiveType
13731376
primitiveType
13741377
: nonTrivialPrimitiveType
13751378
| trivialPrimitiveType
1376-
| unsupportedType=identifier (LEFT_PAREN INTEGER_VALUE(COMMA INTEGER_VALUE)* RIGHT_PAREN)?
1379+
| unsupportedType=identifier (LEFT_PAREN integerValue(COMMA integerValue)* RIGHT_PAREN)?
13771380
;
13781381

13791382
dataType
@@ -1454,7 +1457,7 @@ sequenceGeneratorOption
14541457
;
14551458

14561459
sequenceGeneratorStartOrStep
1457-
: MINUS? INTEGER_VALUE
1460+
: MINUS? integerValue
14581461
| MINUS? BIGINT_LITERAL
14591462
;
14601463

@@ -1606,6 +1609,11 @@ number
16061609
| MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral
16071610
;
16081611

1612+
integerValue
1613+
: INTEGER_VALUE #integerVal
1614+
| namedParameterMarker #namedParameterIntegerValue
1615+
;
1616+
16091617
columnConstraintDefinition
16101618
: (CONSTRAINT name=errorCapturingIdentifier)? columnConstraint constraintCharacteristic*
16111619
;
@@ -1680,8 +1688,9 @@ alterColumnAction
16801688
;
16811689

16821690
stringLit
1683-
: STRING_LITERAL
1684-
| {!double_quoted_identifiers}? DOUBLEQUOTED_STRING
1691+
: STRING_LITERAL #stringLiteralValue
1692+
| {!double_quoted_identifiers}? DOUBLEQUOTED_STRING #doubleQuotedStringLiteralValue
1693+
| namedParameterMarker #namedParameterValue
16851694
;
16861695

16871696
comment

sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,37 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
4545
withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
4646
}
4747

48-
override def visitStringLit(ctx: StringLitContext): Token = {
48+
override def visitStringLiteralValue(ctx: StringLiteralValueContext): Token = {
4949
if (ctx != null) {
50-
if (ctx.STRING_LITERAL != null) {
51-
ctx.STRING_LITERAL.getSymbol
52-
} else {
53-
ctx.DOUBLEQUOTED_STRING.getSymbol
54-
}
50+
ctx.STRING_LITERAL.getSymbol
51+
} else {
52+
null
53+
}
54+
}
55+
56+
override def visitDoubleQuotedStringLiteralValue(
57+
ctx: DoubleQuotedStringLiteralValueContext): Token = {
58+
if (ctx != null) {
59+
ctx.DOUBLEQUOTED_STRING.getSymbol
60+
} else {
61+
null
62+
}
63+
}
64+
65+
override def visitIntegerVal(ctx: IntegerValContext): Token = {
66+
if (ctx != null) {
67+
ctx.INTEGER_VALUE.getSymbol
5568
} else {
5669
null
5770
}
5871
}
5972

73+
override def visitNamedParameterValue(ctx: NamedParameterValueContext): Token = {
74+
// For namedParameterValue in data type contexts, this shouldn't normally occur
75+
// but if it does, return null to avoid NPE
76+
null
77+
}
78+
6079
/**
6180
* Create a multi-part identifier.
6281
*/
@@ -138,7 +157,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
138157
}
139158
} else {
140159
val badType = typeCtx.unsupportedType.getText
141-
val params = typeCtx.INTEGER_VALUE().asScala.toList
160+
val params = typeCtx.integerValue().asScala.toList
142161
val dtStr =
143162
if (params.nonEmpty) s"$badType(${params.mkString(",")})"
144163
else badType
@@ -258,7 +277,14 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
258277
* Create a comment string.
259278
*/
260279
override def visitCommentSpec(ctx: CommentSpecContext): String = withOrigin(ctx) {
261-
string(visitStringLit(ctx.stringLit))
280+
string(visit(ctx.stringLit).asInstanceOf[Token])
281+
}
282+
283+
/**
284+
* Visit a stringLit context by delegating to the appropriate labeled visitor.
285+
*/
286+
def visitStringLit(ctx: StringLitContext): Token = {
287+
visit(ctx).asInstanceOf[Token]
262288
}
263289

264290
/**

sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParmsAstBuilder.scala

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class SubstituteParmsAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
5454
*/
5555
override def visitNamedParameterLiteral(
5656
ctx: NamedParameterLiteralContext): String = withOrigin(ctx) {
57-
val paramName = ctx.identifier().getText
57+
val paramName = ctx.namedParameterMarker().identifier().getText
5858
namedParams += paramName
5959

6060
// Calculate the location of the entire parameter (including the colon)
@@ -81,6 +81,40 @@ class SubstituteParmsAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
8181
"?"
8282
}
8383

84+
/**
85+
* Handle namedParameterValue context for parameter markers in string literal contexts.
86+
* This handles the namedParameterMarker case added to the stringLit grammar rule.
87+
*/
88+
override def visitNamedParameterValue(
89+
ctx: NamedParameterValueContext): String = withOrigin(ctx) {
90+
val paramName = ctx.namedParameterMarker().identifier().getText
91+
namedParams += paramName
92+
93+
// Calculate the location of the entire parameter (including the colon)
94+
val startIndex = ctx.getStart.getStartIndex
95+
val stopIndex = ctx.getStop.getStopIndex + 1
96+
namedParamLocations(paramName) = ParameterLocation(startIndex, stopIndex)
97+
98+
paramName
99+
}
100+
101+
/**
102+
* Handle namedParameterIntegerValue context for parameter markers in integer value contexts.
103+
* This handles the namedParameterMarker case added to the integerValue grammar rule.
104+
*/
105+
override def visitNamedParameterIntegerValue(
106+
ctx: NamedParameterIntegerValueContext): String = withOrigin(ctx) {
107+
val paramName = ctx.namedParameterMarker().identifier().getText
108+
namedParams += paramName
109+
110+
// Calculate the location of the entire parameter (including the colon)
111+
val startIndex = ctx.getStart.getStartIndex
112+
val stopIndex = ctx.getStop.getStopIndex + 1
113+
namedParamLocations(paramName) = ParameterLocation(startIndex, stopIndex)
114+
115+
paramName
116+
}
117+
84118
/**
85119
* Override visit to ensure we traverse all children to find parameters.
86120
*/
@@ -93,6 +127,10 @@ class SubstituteParmsAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
93127
visitNamedParameterLiteral(ctx)
94128
case ctx: PosParameterLiteralContext =>
95129
visitPosParameterLiteral(ctx)
130+
case ctx: NamedParameterValueContext =>
131+
visitNamedParameterValue(ctx)
132+
case ctx: NamedParameterIntegerValueContext =>
133+
visitNamedParameterIntegerValue(ctx)
96134
case ruleNode: RuleNode =>
97135
// Continue traversing children for rule nodes
98136
visitChildren(ruleNode)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -713,12 +713,12 @@ class AstBuilder extends DataTypeAstBuilder
713713
private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = {
714714
val ctes = ctx.namedQuery.asScala.map { nCtx =>
715715
val namedQuery = visitNamedQuery(nCtx)
716-
val rowLevelLimit: Option[Int] = if (nCtx.INTEGER_VALUE() != null) {
716+
val rowLevelLimit: Option[Int] = if (nCtx.integerValue() != null) {
717717
if (ctx.RECURSIVE() == null) {
718718
operationNotAllowed("Cannot specify MAX RECURSION LEVEL when the CTE is not marked as " +
719719
"RECURSIVE", ctx)
720720
}
721-
Some(nCtx.INTEGER_VALUE().getText().toInt)
721+
Some(nCtx.integerValue().getText().toInt)
722722
} else {
723723
None
724724
}
@@ -2195,7 +2195,8 @@ class AstBuilder extends DataTypeAstBuilder
21952195
Limit(expression(ctx.expression), query)
21962196

21972197
case ctx: SampleByPercentileContext =>
2198-
val fraction = ctx.percentage.getText.toDouble
2198+
val fraction = if (ctx.DECIMAL_VALUE() != null) { ctx.DECIMAL_VALUE().getText.toDouble }
2199+
else { ctx.integerValue().getText.toDouble }
21992200
val sign = if (ctx.negativeSign == null) 1 else -1
22002201
sample(sign * fraction / 100.0d, seed)
22012202

@@ -2252,7 +2253,7 @@ class AstBuilder extends DataTypeAstBuilder
22522253

22532254
override def visitVersion(ctx: VersionContext): Option[String] = {
22542255
if (ctx != null) {
2255-
if (ctx.INTEGER_VALUE != null) {
2256+
if (ctx.INTEGER_VALUE() != null) {
22562257
Some(ctx.INTEGER_VALUE().getText)
22572258
} else {
22582259
Option(string(visitStringLit(ctx.stringLit())))
@@ -4182,7 +4183,7 @@ class AstBuilder extends DataTypeAstBuilder
41824183
*/
41834184
override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) {
41844185
BucketSpec(
4185-
ctx.INTEGER_VALUE.getText.toInt,
4186+
ctx.integerValue().getText.toInt,
41864187
visitIdentifierList(ctx.identifierList),
41874188
Option(ctx.orderedIdentifierList)
41884189
.toSeq
@@ -6362,7 +6363,15 @@ class AstBuilder extends DataTypeAstBuilder
63626363
* */
63636364
override def visitNamedParameterLiteral(
63646365
ctx: NamedParameterLiteralContext): Expression = withOrigin(ctx) {
6365-
NamedParameter(ctx.identifier().getText)
6366+
NamedParameter(ctx.namedParameterMarker().identifier().getText)
6367+
}
6368+
6369+
/**
6370+
* Create a named parameter in integer value context.
6371+
* */
6372+
override def visitNamedParameterIntegerValue(
6373+
ctx: NamedParameterIntegerValueContext): Expression = withOrigin(ctx) {
6374+
NamedParameter(ctx.namedParameterMarker().identifier().getText)
63666375
}
63676376

63686377
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,14 @@ class ParserUtilsSuite extends SparkFunSuite {
206206
}
207207

208208
test("string") {
209-
assert(string(showDbsContext.pattern.STRING_LITERAL()) == "identifier_with_wildcards")
210-
assert(string(createDbContext.commentSpec().get(0).stringLit().STRING_LITERAL()) ==
211-
"database_comment")
212-
213-
assert(string(createDbContext.locationSpec.asScala.head.stringLit().STRING_LITERAL()) ==
214-
"/home/user/db")
209+
val dataTypeBuilder = new org.apache.spark.sql.catalyst.parser.DataTypeAstBuilder()
210+
assert(string(dataTypeBuilder.visitStringLit(showDbsContext.pattern)) ==
211+
"identifier_with_wildcards")
212+
assert(string(dataTypeBuilder.visitStringLit(
213+
createDbContext.commentSpec().get(0).stringLit())) == "database_comment")
214+
215+
assert(string(dataTypeBuilder.visitStringLit(
216+
createDbContext.locationSpec.asScala.head.stringLit())) == "/home/user/db")
215217
}
216218

217219
test("position") {
@@ -241,7 +243,8 @@ class ParserUtilsSuite extends SparkFunSuite {
241243
val ctx = createDbContext.locationSpec.asScala.head
242244
val current = CurrentOrigin.get
243245
val (location, origin) = withOrigin(ctx) {
244-
(string(ctx.stringLit().STRING_LITERAL), CurrentOrigin.get)
246+
(string(new org.apache.spark.sql.catalyst.parser.DataTypeAstBuilder()
247+
.visitStringLit(ctx.stringLit())), CurrentOrigin.get)
245248
}
246249
assert(location == "/home/user/db")
247250
assert(origin == Origin(Some(3), Some(27)))

sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ import org.apache.spark.sql
4040
import org.apache.spark.sql.{Artifact, DataSourceRegistration, Encoder, Encoders, ExperimentalMethods, Row, SparkSessionBuilder, SparkSessionCompanion, SparkSessionExtensions, SparkSessionExtensionsProvider, UDTFRegistration}
4141
import org.apache.spark.sql.artifact.ArtifactManager
4242
import org.apache.spark.sql.catalyst._
43-
import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation}
43+
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
4444
import org.apache.spark.sql.catalyst.encoders._
4545
import org.apache.spark.sql.catalyst.expressions.AttributeReference
4646
import org.apache.spark.sql.catalyst.parser.{NamedParameterContext, ParserInterface, PositionalParameterContext, ThreadLocalParameterContext}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class SparkSqlParser extends AbstractSqlParser {
5555

5656
protected override def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
5757
// Step 1: Check if we have a parameterized query context and substitute parameters
58-
val paramSubstituted =
58+
val paramSubstituted =
5959
org.apache.spark.sql.catalyst.parser.ThreadLocalParameterContext.get() match {
6060
case Some(context) =>
6161
substituteParametersIfNeeded(command, context)

0 commit comments

Comments
 (0)