Skip to content

Commit facb61f

Browse files
committed
add input tests
1 parent 7e71571 commit facb61f

8 files changed

+618
-14
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
These tests capture the input/output type interfaces between Python UDFs and the engine.
2+
3+
## When this test fails:
4+
- Look at the diff in the test output
5+
- To regenerate golden files, simply delete the existing golden file and re-run the test.

python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt

Lines changed: 43 additions & 0 deletions
Large diffs are not rendered by default.

python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_disabled.txt

Lines changed: 43 additions & 0 deletions
Large diffs are not rendered by default.

python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_enabled.txt

Lines changed: 43 additions & 0 deletions
Large diffs are not rendered by default.

python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_legacy_pandas.txt

Lines changed: 43 additions & 0 deletions
Large diffs are not rendered by default.

python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py

Lines changed: 403 additions & 0 deletions
Large diffs are not rendered by default.

python/pyspark/sql/tests/udf_type_tests/test_udf_types.py renamed to python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# http://www.apache.org/licenses/LICENSE-2.0
1010
#
1111
# Unless required by applicable law or agreed to in writing, software
12-
# distributed under the Apache License is distributed on an "AS IS" BASIS,
12+
# distributed under the License is distributed on an "AS IS" BASIS,
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
@@ -45,21 +45,21 @@
4545
TimestampType,
4646
)
4747
from pyspark.testing.sqlutils import ReusedSQLTestCase
48-
from .type_table_utils import generate_word_diff, format_type_table
48+
from .type_table_utils import generate_table_diff, format_type_table
4949

5050

5151
# Note: The values of 'SQL Type' are DDL formatted strings, which can be used as `returnType`s.
5252
# Note: The values inside the table are generated by `repr`. 'X' means it throws an exception
5353
# during the conversion.
5454
# Note: Python 3.11.9, Pandas 2.2.3 and PyArrow 17.0.0 are used.
5555
# Note: 'X' means it throws an exception during the conversion.
56-
class UDFTypeTests(ReusedSQLTestCase):
56+
class UDFReturnTypeTests(ReusedSQLTestCase):
5757
@classmethod
5858
def setUpClass(cls):
59-
super(UDFTypeTests, cls).setUpClass()
59+
super(UDFReturnTypeTests, cls).setUpClass()
6060

6161
def setUp(self):
62-
super(UDFTypeTests, self).setUp()
62+
super(UDFReturnTypeTests, self).setUp()
6363
self.test_data = [
6464
None,
6565
True,
@@ -185,7 +185,7 @@ def _compare_or_create_golden_file(self, actual_output, golden_file, test_name):
185185
expected_output = f.read()
186186

187187
if actual_output != expected_output:
188-
diff_output = generate_word_diff(actual_output, expected_output)
188+
diff_output = generate_table_diff(actual_output, expected_output)
189189
self.fail(
190190
f"""
191191
Results don't match golden file for \:{test_name}\".\n

python/pyspark/sql/tests/udf_type_tests/type_table_utils.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
11
#!/usr/bin/env python3
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one or more
4+
# contributor license agreements. See the NOTICE file distributed with
5+
# this work for additional information regarding copyright ownership.
6+
# The ASF licenses this file to You under the Apache License, Version 2.0
7+
# (the "License"); you may not use this file except in compliance with
8+
# the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
218

319
import os
420
import sys
@@ -86,6 +102,7 @@ def format_table_diff(
86102
expected_rows: List[List[str]],
87103
actual_rows: List[List[str]],
88104
use_colors: bool = True,
105+
cell_width: int = CELL_WIDTH,
89106
) -> str:
90107
"""Format a table diff with cell-level highlighting."""
91108
output_lines = []
@@ -96,7 +113,7 @@ def format_table_diff(
96113
)
97114
output_lines.append("=" * len(title))
98115

99-
col_widths = [CELL_WIDTH] * len(header)
116+
col_widths = [cell_width] * len(header)
100117

101118
def format_row(cells: List[str], prefix: str = "", color: str = "") -> str:
102119
"""Format a single row with proper alignment."""
@@ -171,7 +188,7 @@ def create_border(char: str = "-") -> str:
171188
if expected_cell != actual_cell:
172189
row_has_changes = True
173190
diff_cell = highlight_cell_diff(
174-
expected_cell, actual_cell, use_colors, CELL_WIDTH
191+
expected_cell, actual_cell, use_colors, cell_width
175192
)
176193
diff_row.append(diff_cell)
177194
else:
@@ -205,15 +222,16 @@ def create_border(char: str = "-") -> str:
205222
return "\n".join(output_lines)
206223

207224

208-
def generate_word_diff(actual, expected):
225+
def generate_table_diff(actual, expected, cell_width=CELL_WIDTH):
209226
"""Generate a table-aware diff between actual and expected output."""
210227
try:
211228
expected_header, expected_rows = parse_table_content(expected)
212229
actual_header, actual_rows = parse_table_content(actual)
213230

214231
if expected_header and actual_header:
215-
use_colors = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
216-
return format_table_diff(expected_header, expected_rows, actual_rows, use_colors)
232+
return format_table_diff(
233+
expected_header, expected_rows, actual_rows, True, cell_width
234+
)
217235
except Exception:
218236
pass
219237

@@ -257,7 +275,7 @@ def format_type_table(results, header, column_width=30):
257275
return "\n".join(output_lines)
258276

259277

260-
def compare_files(file1_path, file2_path):
278+
def compare_files(file1_path, file2_path, cell_width=CELL_WIDTH):
261279
"""Compare two files and show the differences."""
262280
if not os.path.exists(file1_path):
263281
print(f"Error: File '{file1_path}' does not exist")
@@ -287,7 +305,7 @@ def compare_files(file1_path, file2_path):
287305
print("Files differ. Generating word-wise diff...")
288306
print()
289307

290-
diff_output = generate_word_diff(content2, content1)
308+
diff_output = generate_table_diff(content2, content1, cell_width)
291309
print(diff_output)
292310
return False
293311

@@ -296,10 +314,16 @@ def main():
296314
parser = argparse.ArgumentParser(description="Compare two table files using word-wise diff")
297315
parser.add_argument("file1", help="First file (expected)")
298316
parser.add_argument("file2", help="Second file (actual)")
317+
parser.add_argument(
318+
"--cell-width",
319+
type=int,
320+
default=CELL_WIDTH,
321+
help=f"Width of each table cell (default: {CELL_WIDTH})",
322+
)
299323

300324
args = parser.parse_args()
301325

302-
success = compare_files(args.file1, args.file2)
326+
success = compare_files(args.file1, args.file2, args.cell_width)
303327
sys.exit(0 if success else 1)
304328

305329

0 commit comments

Comments
 (0)