Skip to content

Commit 21dce79

Browse files
committed
[SPARK-54052][PYTHON] Add a bridge object to workaround Py4J limitation
### What changes were proposed in this pull request? This PR proposes to add a PythonErrorUtils object to work around a Py4J limitation: Py4J does not support access to Java default methods. ### Why are the changes needed? To make the change easier and less error-prone. ### Does this PR introduce _any_ user-facing change? No. Virtually a refactoring change. ### How was this patch tested? A unit test was added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #52755 from HyukjinKwon/bridge-class. Authored-by: Hyukjin Kwon <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 477e6e8 commit 21dce79

File tree

3 files changed

+82
-38
lines changed

3 files changed

+82
-38
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.api.python
19+
20+
import java.util
21+
22+
import org.apache.spark.{BreakingChangeInfo, QueryContext, SparkThrowable}
23+
24+
/**
 * Bridge object that exposes the accessors of a [[SparkThrowable]] as plain
 * methods on a singleton, so that PySpark can invoke them through Py4J.
 *
 * Py4J cannot call Java interface default methods directly, so PySpark goes
 * through these forwarders instead of calling the throwable's own methods.
 * Every method here simply forwards to the same-named accessor on the given
 * [[SparkThrowable]].
 */
private[spark] object PythonErrorUtils {

  /** Forwards `SparkThrowable.getCondition`. */
  def getCondition(e: SparkThrowable): String = e.getCondition

  /**
   * Legacy-named accessor; forwards to `getCondition` (presumably mirroring
   * SparkThrowable's deprecated `getErrorClass` default method — verify
   * against the SparkThrowable definition).
   */
  def getErrorClass(e: SparkThrowable): String = e.getCondition

  /** Forwards `SparkThrowable.getSqlState`. */
  def getSqlState(e: SparkThrowable): String = e.getSqlState

  /** Forwards `SparkThrowable.isInternalError`. */
  def isInternalError(e: SparkThrowable): Boolean = e.isInternalError

  /** Forwards `SparkThrowable.getBreakingChangeInfo`. */
  def getBreakingChangeInfo(e: SparkThrowable): BreakingChangeInfo = e.getBreakingChangeInfo

  /** Forwards `SparkThrowable.getMessageParameters`. */
  def getMessageParameters(e: SparkThrowable): util.Map[String, String] = e.getMessageParameters

  /** Forwards `SparkThrowable.getDefaultMessageTemplate`. */
  def getDefaultMessageTemplate(e: SparkThrowable): String = e.getDefaultMessageTemplate

  /** Forwards `SparkThrowable.getQueryContext`. */
  def getQueryContext(e: SparkThrowable): Array[QueryContext] = e.getQueryContext
}

core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
package org.apache.spark.deploy
1919

20-
import org.apache.spark.SparkFunSuite
20+
import org.apache.spark.{SparkFunSuite, SparkThrowable}
21+
import org.apache.spark.api.python.PythonErrorUtils
2122
import org.apache.spark.util.Utils
2223

2324
class PythonRunnerSuite extends SparkFunSuite {
@@ -64,4 +65,31 @@ class PythonRunnerSuite extends SparkFunSuite {
6465
intercept[IllegalArgumentException] { PythonRunner.formatPaths("hdfs:/some.py,foo.py") }
6566
intercept[IllegalArgumentException] { PythonRunner.formatPaths("foo.py,hdfs:/some.py") }
6667
}
68+
69+
test("SPARK-54052: PythonErrorUtils should have corresponding methods in SparkThrowable") {
  // Names of the methods declared directly on the SparkThrowable interface
  // (its default methods — the ones Py4J cannot invoke directly).
  val throwableMethods: Set[String] = classOf[SparkThrowable]
    .getMethods
    .collect { case m if m.getDeclaringClass == classOf[SparkThrowable] => m.getName }
    .toSet

  // Names of the public forwarders defined on the PythonErrorUtils object,
  // excluding compiler-generated (synthetic / $-mangled) members.
  val bridgeMethods: Set[String] = PythonErrorUtils.getClass
    .getDeclaredMethods
    .filterNot(_.isSynthetic)
    .map(_.getName)
    .filterNot(_.contains("$"))
    .toSet

  // The bridge must stay in lockstep with the interface: report both
  // directions of drift in the failure message.
  val missing = throwableMethods.diff(bridgeMethods)
  val extra = bridgeMethods.diff(throwableMethods)
  assert(
    bridgeMethods == throwableMethods,
    s"""
       |PythonErrorUtils methods and SparkThrowable default methods differ!
       |Missing in PythonErrorUtils: ${missing.mkString(", ")}
       |Extra in PythonErrorUtils: ${extra.mkString(", ")}
       |""".stripMargin
  )
}
6795
}

python/pyspark/errors/exceptions/captured.py

Lines changed: 12 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def getCondition(self) -> Optional[str]:
107107
if self._origin is not None and is_instance_of(
108108
gw, self._origin, "org.apache.spark.SparkThrowable"
109109
):
110-
return self._origin.getCondition()
110+
utils = SparkContext._jvm.PythonErrorUtils # type: ignore[union-attr]
111+
return utils.getCondition(self._origin)
111112
else:
112113
return None
113114

@@ -118,68 +119,48 @@ def getErrorClass(self) -> Optional[str]:
118119
def getMessageParameters(self) -> Optional[Dict[str, str]]:
    """Return the message parameters of the wrapped JVM SparkThrowable.

    Returns a plain dict copied from the JVM map, or None when there is no
    JVM origin or it is not a SparkThrowable.
    """
    from pyspark import SparkContext
    from py4j.java_gateway import is_instance_of

    assert SparkContext._gateway is not None

    gw = SparkContext._gateway
    origin = self._origin
    if origin is None or not is_instance_of(gw, origin, "org.apache.spark.SparkThrowable"):
        return None
    # Go through the JVM-side bridge object because Py4J cannot invoke
    # interface default methods directly.
    bridge = SparkContext._jvm.PythonErrorUtils  # type: ignore[union-attr]
    return dict(bridge.getMessageParameters(origin))
137133

138134
def getSqlState(self) -> Optional[str]:
    """Return the SQLSTATE of the wrapped JVM SparkThrowable.

    Returns None when there is no JVM origin or it is not a SparkThrowable.
    """
    from pyspark import SparkContext
    from py4j.java_gateway import is_instance_of

    assert SparkContext._gateway is not None
    gw = SparkContext._gateway
    origin = self._origin
    if origin is None or not is_instance_of(gw, origin, "org.apache.spark.SparkThrowable"):
        return None
    # Go through the JVM-side bridge object because Py4J cannot invoke
    # interface default methods directly.
    bridge = SparkContext._jvm.PythonErrorUtils  # type: ignore[union-attr]
    return bridge.getSqlState(origin)
156147

157148
def getMessage(self) -> str:
158149
from pyspark import SparkContext
159150
from py4j.java_gateway import is_instance_of
160-
from py4j.protocol import Py4JError
161151

162152
assert SparkContext._gateway is not None
163153
gw = SparkContext._gateway
164154

165155
if self._origin is not None and is_instance_of(
166156
gw, self._origin, "org.apache.spark.SparkThrowable"
167157
):
168-
try:
169-
error_class = self._origin.getCondition()
170-
except Py4JError as e:
171-
if "py4j.Py4JException" in str(e) and "Method getCondition" in str(e):
172-
return ""
173-
raise e
174-
try:
175-
message_parameters = self._origin.getMessageParameters()
176-
except Py4JError as e:
177-
if "py4j.Py4JException" in str(e) and "Method getMessageParameters" in str(e):
178-
return ""
179-
raise e
158+
utils = SparkContext._jvm.PythonErrorUtils # type: ignore[union-attr]
159+
errorClass = utils.getCondition(self._origin)
160+
messageParameters = utils.getMessageParameters(self._origin)
180161

181162
error_message = getattr(gw.jvm, "org.apache.spark.SparkThrowableHelper").getMessage(
182-
error_class, message_parameters
163+
errorClass, messageParameters
183164
)
184165

185166
return error_message
@@ -189,7 +170,6 @@ def getMessage(self) -> str:
189170
def getQueryContext(self) -> List[BaseQueryContext]:
190171
from pyspark import SparkContext
191172
from py4j.java_gateway import is_instance_of
192-
from py4j.protocol import Py4JError
193173

194174
assert SparkContext._gateway is not None
195175

@@ -198,13 +178,8 @@ def getQueryContext(self) -> List[BaseQueryContext]:
198178
gw, self._origin, "org.apache.spark.SparkThrowable"
199179
):
200180
contexts: List[BaseQueryContext] = []
201-
try:
202-
context = self._origin.getQueryContext()
203-
except Py4JError as e:
204-
if "py4j.Py4JException" in str(e) and "Method getQueryContext" in str(e):
205-
return []
206-
raise e
207-
for q in context:
181+
utils = SparkContext._jvm.PythonErrorUtils # type: ignore[union-attr]
182+
for q in utils.getQueryContext(self._origin):
208183
if q.contextType().toString() == "SQL":
209184
contexts.append(SQLQueryContext(q))
210185
else:

0 commit comments

Comments
 (0)