Skip to content

Commit 50f6d93

Browse files
yaooqinn authored and cloud-fan committed
[SPARK-29870][SQL] Unify the logic of multi-units interval string to CalendarInterval
### What changes were proposed in this pull request? We currently have two different implementations for converting multi-unit interval strings to CalendarInterval values. One is used to convert interval string literals to CalendarInterval. This approach delegates the interval string back to the Spark SQL parser, which handles the string as a `singleInterval` -> `multiUnitsInterval` and eventually calls `IntervalUtils.fromUnitStrings`. The other is used in `Cast`, which eventually calls `IntervalUtils.stringToInterval`. This approach is ~10 times faster than the first one. We should unify the two for better performance and simpler logic. This PR uses the second approach. ### Why are the changes needed? We should unify these two implementations for better performance and simpler logic. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing unit tests should continue to pass. Closes #26491 from yaooqinn/SPARK-29870. Authored-by: Kent Yao <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 5cebe58 commit 50f6d93

File tree

19 files changed

+169
-222
lines changed

19 files changed

+169
-222
lines changed

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,6 @@ singleTableSchema
7979
: colTypeList EOF
8080
;
8181

82-
singleInterval
83-
: INTERVAL? multiUnitsInterval EOF
84-
;
85-
8682
statement
8783
: query #statementDefault
8884
| ctes? dmlStatementNoWith #dmlStatement

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
477477
// IntervalConverter
478478
private[this] def castToInterval(from: DataType): Any => Any = from match {
479479
case StringType =>
480-
buildCast[UTF8String](_, s => IntervalUtils.stringToInterval(s))
480+
buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s))
481481
}
482482

483483
// LongConverter
@@ -1234,7 +1234,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
12341234
case StringType =>
12351235
val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
12361236
(c, evPrim, evNull) =>
1237-
code"""$evPrim = $util.stringToInterval($c);
1237+
code"""$evPrim = $util.safeStringToInterval($c);
12381238
if(${evPrim} == null) {
12391239
${evNull} = true;
12401240
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2525
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
2626
import org.apache.spark.sql.catalyst.util.IntervalUtils
2727
import org.apache.spark.sql.types._
28+
import org.apache.spark.unsafe.types.UTF8String
2829

2930
case class TimeWindow(
3031
timeColumn: Expression,
@@ -103,7 +104,7 @@ object TimeWindow {
103104
* precision.
104105
*/
105106
private def getIntervalInMicroSeconds(interval: String): Long = {
106-
val cal = IntervalUtils.fromString(interval)
107+
val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
107108
if (cal.months != 0) {
108109
throw new IllegalArgumentException(
109110
s"Intervals greater than a month is not supported ($interval).")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
102102
withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
103103
}
104104

105-
override def visitSingleInterval(ctx: SingleIntervalContext): CalendarInterval = {
106-
withOrigin(ctx)(visitMultiUnitsInterval(ctx.multiUnitsInterval))
107-
}
108-
109105
/* ********************************************************************************************
110106
* Plan parsing
111107
* ******************************************************************************************** */
@@ -1870,7 +1866,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
18701866
toLiteral(stringToTimestamp(_, zoneId), TimestampType)
18711867
case "INTERVAL" =>
18721868
val interval = try {
1873-
IntervalUtils.fromString(value)
1869+
IntervalUtils.stringToInterval(UTF8String.fromString(value))
18741870
} catch {
18751871
case e: IllegalArgumentException =>
18761872
val ex = new ParseException("Cannot parse the INTERVAL value: " + value, ctx)
@@ -2069,22 +2065,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
20692065
*/
20702066
override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = {
20712067
withOrigin(ctx) {
2072-
val units = ctx.intervalUnit().asScala.map { unit =>
2073-
val u = unit.getText.toLowerCase(Locale.ROOT)
2074-
// Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/...
2075-
if (u.endsWith("s")) u.substring(0, u.length - 1) else u
2076-
}.map(IntervalUtils.IntervalUnit.withName).toArray
2077-
2078-
val values = ctx.intervalValue().asScala.map { value =>
2079-
if (value.STRING() != null) {
2080-
string(value.STRING())
2081-
} else {
2082-
value.getText
2083-
}
2084-
}.toArray
2085-
2068+
val units = ctx.intervalUnit().asScala
2069+
val values = ctx.intervalValue().asScala
20862070
try {
2087-
IntervalUtils.fromUnitStrings(units, values)
2071+
assert(units.length == values.length)
2072+
val kvs = units.indices.map { i =>
2073+
val u = units(i).getText
2074+
val v = if (values(i).STRING() != null) {
2075+
string(values(i).STRING())
2076+
} else {
2077+
values(i).getText
2078+
}
2079+
UTF8String.fromString(" " + v + " " + u)
2080+
}
2081+
IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*))
20882082
} catch {
20892083
case i: IllegalArgumentException =>
20902084
val e = new ParseException(i.getMessage, ctx)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,12 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2929
import org.apache.spark.sql.catalyst.trees.Origin
3030
import org.apache.spark.sql.internal.SQLConf
3131
import org.apache.spark.sql.types.{DataType, StructType}
32-
import org.apache.spark.unsafe.types.CalendarInterval
3332

3433
/**
3534
* Base SQL parsing infrastructure.
3635
*/
3736
abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Logging {
3837

39-
/**
40-
* Creates [[CalendarInterval]] for a given SQL String. Throws [[ParseException]] if the SQL
41-
* string is not a valid interval format.
42-
*/
43-
def parseInterval(sqlText: String): CalendarInterval = parse(sqlText) { parser =>
44-
astBuilder.visitSingleInterval(parser.singleInterval())
45-
}
46-
4738
/** Creates/Resolves DataType for a given SQL string. */
4839
override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
4940
astBuilder.visitSingleDataType(parser.singleDataType())

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala

Lines changed: 41 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import java.util.concurrent.TimeUnit
2222

2323
import scala.util.control.NonFatal
2424

25-
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
2625
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
2726
import org.apache.spark.sql.types.Decimal
2827
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -102,34 +101,6 @@ object IntervalUtils {
102101
Decimal(result, 18, 6)
103102
}
104103

105-
/**
106-
* Converts a string to [[CalendarInterval]] case-insensitively.
107-
*
108-
* @throws IllegalArgumentException if the input string is not in valid interval format.
109-
*/
110-
def fromString(str: String): CalendarInterval = {
111-
if (str == null) throw new IllegalArgumentException("Interval string cannot be null")
112-
try {
113-
CatalystSqlParser.parseInterval(str)
114-
} catch {
115-
case e: ParseException =>
116-
val ex = new IllegalArgumentException(s"Invalid interval string: $str\n" + e.message)
117-
ex.setStackTrace(e.getStackTrace)
118-
throw ex
119-
}
120-
}
121-
122-
/**
123-
* A safe version of `fromString`. It returns null for invalid input string.
124-
*/
125-
def safeFromString(str: String): CalendarInterval = {
126-
try {
127-
fromString(str)
128-
} catch {
129-
case _: IllegalArgumentException => null
130-
}
131-
}
132-
133104
private def toLongWithRange(
134105
fieldName: IntervalUnit,
135106
s: String,
@@ -251,46 +222,6 @@ object IntervalUtils {
251222
}
252223
}
253224

254-
def fromUnitStrings(units: Array[IntervalUnit], values: Array[String]): CalendarInterval = {
255-
assert(units.length == values.length)
256-
var months: Int = 0
257-
var days: Int = 0
258-
var microseconds: Long = 0
259-
var i = 0
260-
while (i < units.length) {
261-
try {
262-
units(i) match {
263-
case YEAR =>
264-
months = Math.addExact(months, Math.multiplyExact(values(i).toInt, 12))
265-
case MONTH =>
266-
months = Math.addExact(months, values(i).toInt)
267-
case WEEK =>
268-
days = Math.addExact(days, Math.multiplyExact(values(i).toInt, 7))
269-
case DAY =>
270-
days = Math.addExact(days, values(i).toInt)
271-
case HOUR =>
272-
val hoursUs = Math.multiplyExact(values(i).toLong, MICROS_PER_HOUR)
273-
microseconds = Math.addExact(microseconds, hoursUs)
274-
case MINUTE =>
275-
val minutesUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MINUTE)
276-
microseconds = Math.addExact(microseconds, minutesUs)
277-
case SECOND =>
278-
microseconds = Math.addExact(microseconds, parseSecondNano(values(i)))
279-
case MILLISECOND =>
280-
val millisUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MILLIS)
281-
microseconds = Math.addExact(microseconds, millisUs)
282-
case MICROSECOND =>
283-
microseconds = Math.addExact(microseconds, values(i).toLong)
284-
}
285-
} catch {
286-
case e: Exception =>
287-
throw new IllegalArgumentException(s"Error parsing interval string: ${e.getMessage}", e)
288-
}
289-
i += 1
290-
}
291-
new CalendarInterval(months, days, microseconds)
292-
}
293-
294225
// Parses a string with nanoseconds, truncates the result and returns microseconds
295226
private def parseNanos(nanosStr: String, isNegative: Boolean): Long = {
296227
if (nanosStr != null) {
@@ -306,30 +237,6 @@ object IntervalUtils {
306237
}
307238
}
308239

309-
/**
310-
* Parse second_nano string in ss.nnnnnnnnn format to microseconds
311-
*/
312-
private def parseSecondNano(secondNano: String): Long = {
313-
def parseSeconds(secondsStr: String): Long = {
314-
toLongWithRange(
315-
SECOND,
316-
secondsStr,
317-
Long.MinValue / MICROS_PER_SECOND,
318-
Long.MaxValue / MICROS_PER_SECOND) * MICROS_PER_SECOND
319-
}
320-
321-
secondNano.split("\\.") match {
322-
case Array(secondsStr) => parseSeconds(secondsStr)
323-
case Array("", nanosStr) => parseNanos(nanosStr, false)
324-
case Array(secondsStr, nanosStr) =>
325-
val seconds = parseSeconds(secondsStr)
326-
Math.addExact(seconds, parseNanos(nanosStr, seconds < 0))
327-
case _ =>
328-
throw new IllegalArgumentException(
329-
"Interval string does not match second-nano format of ss.nnnnnnnnn")
330-
}
331-
}
332-
333240
/**
334241
* Gets interval duration
335242
*
@@ -558,18 +465,37 @@ object IntervalUtils {
558465
private final val millisStr = unitToUtf8(MILLISECOND)
559466
private final val microsStr = unitToUtf8(MICROSECOND)
560467

468+
/**
469+
* A safe version of `stringToInterval`. It returns null for invalid input string.
470+
*/
471+
def safeStringToInterval(input: UTF8String): CalendarInterval = {
472+
try {
473+
stringToInterval(input)
474+
} catch {
475+
case _: IllegalArgumentException => null
476+
}
477+
}
478+
479+
/**
480+
* Converts a string to [[CalendarInterval]] case-insensitively.
481+
*
482+
* @throws IllegalArgumentException if the input string is not in valid interval format.
483+
*/
561484
def stringToInterval(input: UTF8String): CalendarInterval = {
562485
import ParseState._
486+
def throwIAE(msg: String, e: Exception = null) = {
487+
throw new IllegalArgumentException(s"Error parsing '$input' to interval, $msg", e)
488+
}
563489

564490
if (input == null) {
565-
return null
491+
throwIAE("interval string cannot be null")
566492
}
567493
// scalastyle:off caselocale .toLowerCase
568494
val s = input.trim.toLowerCase
569495
// scalastyle:on
570496
val bytes = s.getBytes
571497
if (bytes.isEmpty) {
572-
return null
498+
throwIAE("interval string cannot be empty")
573499
}
574500
var state = PREFIX
575501
var i = 0
@@ -588,13 +514,19 @@ object IntervalUtils {
588514
}
589515
}
590516

517+
def currentWord: UTF8String = {
518+
val strings = s.split(UTF8String.blankString(1), -1)
519+
val lenRight = s.substring(i, s.numBytes()).split(UTF8String.blankString(1), -1).length
520+
strings(strings.length - lenRight)
521+
}
522+
591523
while (i < bytes.length) {
592524
val b = bytes(i)
593525
state match {
594526
case PREFIX =>
595527
if (s.startsWith(intervalStr)) {
596528
if (s.numBytes() == intervalStr.numBytes()) {
597-
return null
529+
throwIAE("interval string cannot be empty")
598530
} else {
599531
i += intervalStr.numBytes()
600532
}
@@ -627,7 +559,7 @@ object IntervalUtils {
627559
fractionScale = (NANOS_PER_SECOND / 10).toInt
628560
i += 1
629561
state = VALUE_FRACTIONAL_PART
630-
case _ => return null
562+
case _ => throwIAE( s"unrecognized number '$currentWord'")
631563
}
632564
case TRIM_BEFORE_VALUE => trimToNextState(b, VALUE)
633565
case VALUE =>
@@ -636,13 +568,13 @@ object IntervalUtils {
636568
try {
637569
currentValue = Math.addExact(Math.multiplyExact(10, currentValue), (b - '0'))
638570
} catch {
639-
case _: ArithmeticException => return null
571+
case e: ArithmeticException => throwIAE(e.getMessage, e)
640572
}
641573
case ' ' => state = TRIM_BEFORE_UNIT
642574
case '.' =>
643575
fractionScale = (NANOS_PER_SECOND / 10).toInt
644576
state = VALUE_FRACTIONAL_PART
645-
case _ => return null
577+
case _ => throwIAE(s"invalid value '$currentWord'")
646578
}
647579
i += 1
648580
case VALUE_FRACTIONAL_PART =>
@@ -653,14 +585,17 @@ object IntervalUtils {
653585
case ' ' =>
654586
fraction /= NANOS_PER_MICROS.toInt
655587
state = TRIM_BEFORE_UNIT
656-
case _ => return null
588+
case _ if '0' <= b && b <= '9' =>
589+
throwIAE(s"interval can only support nanosecond precision, '$currentWord' is out" +
590+
s" of range")
591+
case _ => throwIAE(s"invalid value '$currentWord'")
657592
}
658593
i += 1
659594
case TRIM_BEFORE_UNIT => trimToNextState(b, UNIT_BEGIN)
660595
case UNIT_BEGIN =>
661596
// Checks that only seconds can have the fractional part
662597
if (b != 's' && fractionScale >= 0) {
663-
return null
598+
throwIAE(s"'$currentWord' cannot have fractional part")
664599
}
665600
if (isNegative) {
666601
currentValue = -currentValue
@@ -704,26 +639,26 @@ object IntervalUtils {
704639
} else if (s.matchAt(microsStr, i)) {
705640
microseconds = Math.addExact(microseconds, currentValue)
706641
i += microsStr.numBytes()
707-
} else return null
708-
case _ => return null
642+
} else throwIAE(s"invalid unit '$currentWord'")
643+
case _ => throwIAE(s"invalid unit '$currentWord'")
709644
}
710645
} catch {
711-
case _: ArithmeticException => return null
646+
case e: ArithmeticException => throwIAE(e.getMessage, e)
712647
}
713648
state = UNIT_SUFFIX
714649
case UNIT_SUFFIX =>
715650
b match {
716651
case 's' => state = UNIT_END
717652
case ' ' => state = TRIM_BEFORE_SIGN
718-
case _ => return null
653+
case _ => throwIAE(s"invalid unit '$currentWord'")
719654
}
720655
i += 1
721656
case UNIT_END =>
722657
b match {
723658
case ' ' =>
724659
i += 1
725660
state = TRIM_BEFORE_SIGN
726-
case _ => return null
661+
case _ => throwIAE(s"invalid unit '$currentWord'")
727662
}
728663
}
729664
}

0 commit comments

Comments
 (0)