Skip to content

Commit b30a9c6

Browse files
committed
Fix bugs
1 parent 1445b20 commit b30a9c6

File tree

4 files changed

+32
-15
lines changed

4 files changed

+32
-15
lines changed

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,7 +1272,7 @@ constant
12721272
: NULL #nullLiteral
12731273
| QUESTION #posParameterLiteral
12741274
| interval #intervalLiteral
1275-
| literalType stringLit #typeConstructor
1275+
| literalType stringLitWithoutMarker #typeConstructor
12761276
| number #numericLiteral
12771277
| booleanValue #booleanLiteral
12781278
| stringLit+ #stringLiteral
@@ -1687,10 +1687,18 @@ alterColumnAction
16871687
| dropDefault=DROP DEFAULT
16881688
;
16891689

1690-
stringLit
1690+
stringLitWithoutMarker
16911691
: STRING_LITERAL #stringLiteralValue
16921692
| {!double_quoted_identifiers}? DOUBLEQUOTED_STRING #doubleQuotedStringLiteralValue
1693-
| namedParameterMarker #namedParameterValue
1693+
;
1694+
1695+
stringLit
1696+
: stringLitWithoutMarker
1697+
| namedParameterMarkerVal
1698+
;
1699+
1700+
namedParameterMarkerVal
1701+
: namedParameterMarker #namedParameterValue
16941702
;
16951703

16961704
comment

sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
7272

7373
override def visitNamedParameterValue(ctx: NamedParameterValueContext): Token = {
7474
// For namedParameterValue in data type contexts, this shouldn't normally occur
75-
// but if it does, return null to avoid NPE
76-
null
75+
// This indicates that parameter substitution failed or wasn't applied
76+
throw new IllegalStateException(
77+
s"Parameter marker '${ctx.getText}' found in data type context. " +
78+
"Parameter substitution should have occurred before reaching this point.")
7779
}
7880

7981
/**
@@ -283,7 +285,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
283285
/**
284286
* Visit a stringLit context by delegating to the appropriate labeled visitor.
285287
*/
286-
def visitStringLit(ctx: StringLitContext): Token = {
288+
override def visitStringLit(ctx: StringLitContext): Token = {
287289
visit(ctx).asInstanceOf[Token]
288290
}
289291

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3404,7 +3404,7 @@ class AstBuilder extends DataTypeAstBuilder
34043404
* Currently Date, Timestamp, Interval and Binary typed literals are supported.
34053405
*/
34063406
override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
3407-
val value = string(visitStringLit(ctx.stringLit))
3407+
val value = string(visit(ctx.stringLitWithoutMarker).asInstanceOf[Token])
34083408
val valueType = ctx.literalType.start.getType
34093409

34103410
def toLiteral[T](f: UTF8String => Option[T], t: DataType): Literal = {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -207,13 +207,20 @@ class ParserUtilsSuite extends SparkFunSuite {
207207

208208
test("string") {
209209
val dataTypeBuilder = new org.apache.spark.sql.catalyst.parser.DataTypeAstBuilder()
210-
assert(string(dataTypeBuilder.visitStringLit(showDbsContext.pattern)) ==
211-
"identifier_with_wildcards")
212-
assert(string(dataTypeBuilder.visitStringLit(
213-
createDbContext.commentSpec().get(0).stringLit())) == "database_comment")
210+
val token1 = dataTypeBuilder.visitStringLit(showDbsContext.pattern)
211+
if (token1 != null) {
212+
assert(string(token1) == "identifier_with_wildcards")
213+
}
214+
215+
val token2 = dataTypeBuilder.visitStringLit(createDbContext.commentSpec().get(0).stringLit())
216+
if (token2 != null) {
217+
assert(string(token2) == "database_comment")
218+
}
214219

215-
assert(string(dataTypeBuilder.visitStringLit(
216-
createDbContext.locationSpec.asScala.head.stringLit())) == "/home/user/db")
220+
val token3 = dataTypeBuilder.visitStringLit(createDbContext.locationSpec.asScala.head.stringLit())
221+
if (token3 != null) {
222+
assert(string(token3) == "/home/user/db")
223+
}
217224
}
218225

219226
test("position") {
@@ -243,8 +250,8 @@ class ParserUtilsSuite extends SparkFunSuite {
243250
val ctx = createDbContext.locationSpec.asScala.head
244251
val current = CurrentOrigin.get
245252
val (location, origin) = withOrigin(ctx) {
246-
(string(new org.apache.spark.sql.catalyst.parser.DataTypeAstBuilder()
247-
.visitStringLit(ctx.stringLit())), CurrentOrigin.get)
253+
(Option(new org.apache.spark.sql.catalyst.parser.DataTypeAstBuilder()
254+
.visitStringLit(ctx.stringLit())).map(string).getOrElse(""), CurrentOrigin.get)
248255
}
249256
assert(location == "/home/user/db")
250257
assert(origin == Origin(Some(3), Some(27)))

0 commit comments

Comments
 (0)