From f903bfb2ea00fe64762dc4475bd9b75fa5d789dd Mon Sep 17 00:00:00 2001 From: zeruibao Date: Sun, 24 Aug 2025 15:24:53 -0700 Subject: [PATCH 1/9] save --- python/pyspark/sql/pandas/serializers.py | 65 ++++++++++--------- ...nsformWithStateInPySparkPythonRunner.scala | 38 ++++++----- 2 files changed, 54 insertions(+), 49 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 73546d2320bd3..eb483e20f1916 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1535,14 +1535,18 @@ def generate_data_batches(batches): This function must avoid materializing multiple Arrow RecordBatches into memory at the same time. And data chunks from the same grouping key should appear sequentially. """ - for batch in batches: - data_pandas = [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) - ] - key_series = [data_pandas[o] for o in self.key_offsets] - batch_key = tuple(s[0] for s in key_series) - yield (batch_key, data_pandas) + import pandas as pd + + def row_stream(): + for batch in batches: + data_pandas = pa.Table.from_batches([batch]).to_pandas() + for row in data_pandas.itertuples(index=False): + batch_keys = tuple(row[s] for s in self.key_offsets) + yield (batch_keys, row) + + for batch_keys, group_rows in groupby(row_stream(), key=lambda x: x[0]): + df = pd.DataFrame([row for _, row in group_rows]) + yield (batch_keys, [df[col] for col in df.columns]) _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) data_batches = generate_data_batches(_batches) @@ -1616,6 +1620,7 @@ def generate_data_batches(batches): into the data generator. """ + import pandas as pd def flatten_columns(cur_batch, col_name): state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name)) state_field_names = [ @@ -1635,30 +1640,28 @@ def flatten_columns(cur_batch, col_name): .add("inputData", dataSchema) .add("initState", initStateSchema) We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python - data generator. All rows in the same batch have the same grouping key. + data generator. Rows in the same batch may have diiferent grouping keys. """ - for batch in batches: - flatten_state_table = flatten_columns(batch, "inputData") - data_pandas = [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(flatten_state_table.itercolumns()) - ] - - flatten_init_table = flatten_columns(batch, "initState") - init_data_pandas = [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(flatten_init_table.itercolumns()) - ] - key_series = [data_pandas[o] for o in self.key_offsets] - init_key_series = [init_data_pandas[o] for o in self.init_key_offsets] - - if any(s.empty for s in key_series): - # If any row is empty, assign batch_key using init_key_series - batch_key = tuple(s[0] for s in init_key_series) - else: - # If all rows are non-empty, create batch_key from key_series - batch_key = tuple(s[0] for s in key_series) - yield (batch_key, data_pandas, init_data_pandas) + def row_stream(): + for batch in batches: + flatten_data_pd = flatten_columns(batch, "inputData").to_pandas() + flatten_init_pd = flatten_columns(batch, "initState").to_pandas() + for row_data, row_init in zip( + flatten_data_pd.itertuples(index=False), + flatten_init_pd.itertuples(index=False)): + if len(row_data) == 0: + # If row is empty, assign batch_key using row_init + batch_keys = tuple(row_init[s] for s in self.key_offsets) + else: + batch_keys = tuple(row_data[s] for s in self.key_offsets) + yield (batch_keys, row_data, row_init) + + for batch_keys, group_rows in groupby(row_stream(), key=lambda x: x[0]): + data_pandas = pd.DataFrame([row_data for _, row_data, _ in group_rows]) + init_data_pandas = pd.DataFrame([row_init for _, _, row_init in group_rows]) + yield (batch_keys, + [data_pandas[col] for col in data_pandas.columns], + [init_data_pandas[col] for col in init_data_pandas.columns]) _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) data_batches = generate_data_batches(_batches) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala index 51dc179c901ab..f9f415dcac876 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala @@ -66,6 +66,7 @@ class TransformWithStateInPySparkPythonRunner( private var pandasWriter: BaseStreamingArrowWriter = _ + // Grouping multiple keys into one arrow batch override protected def writeNextBatchToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, @@ -75,8 +76,9 @@ class TransformWithStateInPySparkPythonRunner( pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch) } - if (inputIterator.hasNext) { - val startData = dataOut.size() + // Sending all rows to Python as arrow batch with mixed keys + val startData = dataOut.size() + while (inputIterator.hasNext) { val next = inputIterator.next() val dataIter = next._2 @@ -84,14 +86,13 @@ class TransformWithStateInPySparkPythonRunner( val dataRow = dataIter.next() pandasWriter.writeRow(dataRow) } - pandasWriter.finalizeCurrentArrowBatch() - val deltaData = dataOut.size() - startData - pythonMetrics("pythonDataSent") += deltaData - true - } else { - super[PythonArrowInput].close() - false } + + pandasWriter.finalizeCurrentArrowBatch() + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData + super[PythonArrowInput].close() + false } } @@ -125,6 +126,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner( private var pandasWriter: BaseStreamingArrowWriter = _ + // Grouping multiple keys into one arrow batch override protected def writeNextBatchToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, @@ -134,8 +136,9 @@ class TransformWithStateInPySparkPythonInitialStateRunner( pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch) } - if (inputIterator.hasNext) { - val startData = dataOut.size() + // Sending all rows to Python as arrow batch with mixed keys + val startData = dataOut.size() + while (inputIterator.hasNext) { // a new grouping key with data & init state iter val next = inputIterator.next() val dataIter = next._2 @@ -150,14 +153,13 @@ class TransformWithStateInPySparkPythonInitialStateRunner( else InternalRow.empty pandasWriter.writeRow(InternalRow(dataRow, initRow)) } - pandasWriter.finalizeCurrentArrowBatch() - val deltaData = dataOut.size() - startData - pythonMetrics("pythonDataSent") += deltaData - true - } else { - super[PythonArrowInput].close() - false } + + pandasWriter.finalizeCurrentArrowBatch() + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData + super[PythonArrowInput].close() + false } } From 0b71419c83ca3c7a4ba8d87712846b00b04d0cbe Mon Sep 17 00:00:00 2001 From: zeruibao Date: Sun, 24 Aug 2025 18:18:11 -0700 Subject: [PATCH 2/9] save --- python/pyspark/sql/pandas/serializers.py | 46 ++++++++++-------------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index eb483e20f1916..005d6a86f9366 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1539,8 +1539,7 @@ def generate_data_batches(batches): def row_stream(): for batch in batches: - data_pandas = pa.Table.from_batches([batch]).to_pandas() - for row in data_pandas.itertuples(index=False): + for _, row in batch.to_pandas().iterrows(): batch_keys = tuple(row[s] for s in self.key_offsets) yield (batch_keys, row) @@ -1603,7 +1602,6 @@ def __init__( self.init_key_offsets = None def load_stream(self, stream): - import pyarrow as pa from pyspark.sql.streaming.stateful_processor_util import ( TransformWithStateInPandasFuncMode, ) @@ -1621,19 +1619,6 @@ def generate_data_batches(batches): """ import pandas as pd - def flatten_columns(cur_batch, col_name): - state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name)) - state_field_names = [ - state_column.type[i].name for i in range(state_column.type.num_fields) - ] - state_field_arrays = [ - state_column.field(i) for i in range(state_column.type.num_fields) - ] - table_from_fields = pa.Table.from_arrays( - state_field_arrays, names=state_field_names - ) - return table_from_fields - """ The arrow batch is written in the schema: schema: StructType = new StructType() @@ -1644,21 +1629,26 @@ def flatten_columns(cur_batch, col_name): """ def row_stream(): for batch in batches: - flatten_data_pd = flatten_columns(batch, "inputData").to_pandas() - flatten_init_pd = flatten_columns(batch, "initState").to_pandas() - for row_data, row_init in zip( - flatten_data_pd.itertuples(index=False), - flatten_init_pd.itertuples(index=False)): - if len(row_data) == 0: - # If row is empty, assign batch_key using row_init - batch_keys = tuple(row_init[s] for s in self.key_offsets) + for _, row in batch.to_pandas().iterrows(): + input_data = row["inputData"] + init_state = row["initState"] + + key_series = [list(input_data.values())[o] for o in self.key_offsets] + init_key_series = [list(init_state.values())[o] for o in self.init_key_offsets] + + if any(s is None for s in key_series): + # If any row is empty, assign batch_key using init_key_series + batch_keys = tuple(s for s in init_key_series) else: - batch_keys = tuple(row_data[s] for s in self.key_offsets) - yield (batch_keys, row_data, row_init) + batch_keys = tuple(s for s in key_series) + + yield batch_keys, input_data, init_state for batch_keys, group_rows in groupby(row_stream(), key=lambda x: x[0]): - data_pandas = pd.DataFrame([row_data for _, row_data, _ in group_rows]) - init_data_pandas = pd.DataFrame([row_init for _, _, row_init in group_rows]) + data_pairs = [(d, i) for _, d, i in group_rows] + data_pandas = pd.DataFrame([d for d, _ in data_pairs]).dropna(how="all") + init_data_pandas = pd.DataFrame([i for _, i in data_pairs]).dropna(how="all") + yield (batch_keys, [data_pandas[col] for col in data_pandas.columns], [init_data_pandas[col] for col in init_data_pandas.columns]) From 764450a7f9480080a928c7f76be08525a3a9971c Mon Sep 17 00:00:00 2001 From: zeruibao Date: Mon, 25 Aug 2025 14:30:22 -0700 Subject: [PATCH 3/9] save --- python/pyspark/sql/pandas/serializers.py | 76 ++++++++++++------- .../test_pandas_transform_with_state.py | 2 + 2 files changed, 52 insertions(+), 26 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 005d6a86f9366..711b6efceecb5 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1539,7 +1539,7 @@ def generate_data_batches(batches): def row_stream(): for batch in batches: - for _, row in batch.to_pandas().iterrows(): + for row in batch.to_pandas().itertuples(index=False): batch_keys = tuple(row[s] for s in self.key_offsets) yield (batch_keys, row) @@ -1619,39 +1619,63 @@ def generate_data_batches(batches): """ import pandas as pd + from itertools import tee """ The arrow batch is written in the schema: schema: StructType = new StructType() .add("inputData", dataSchema) .add("initState", initStateSchema) We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python - data generator. Rows in the same batch may have diiferent grouping keys. + data generator. Rows in the same batch may have different grouping keys. """ - def row_stream(): - for batch in batches: - for _, row in batch.to_pandas().iterrows(): - input_data = row["inputData"] - init_state = row["initState"] - - key_series = [list(input_data.values())[o] for o in self.key_offsets] - init_key_series = [list(init_state.values())[o] for o in self.init_key_offsets] - - if any(s is None for s in key_series): - # If any row is empty, assign batch_key using init_key_series - batch_keys = tuple(s for s in init_key_series) - else: - batch_keys = tuple(s for s in key_series) - - yield batch_keys, input_data, init_state - - for batch_keys, group_rows in groupby(row_stream(), key=lambda x: x[0]): - data_pairs = [(d, i) for _, d, i in group_rows] - data_pandas = pd.DataFrame([d for d, _ in data_pairs]).dropna(how="all") - init_data_pandas = pd.DataFrame([i for _, i in data_pairs]).dropna(how="all") - yield (batch_keys, - [data_pandas[col] for col in data_pandas.columns], - [init_data_pandas[col] for col in init_data_pandas.columns]) + batches_gent_1, batches_gent_2 = tee(batches) + def input_data_stream(): + for batch in batches_gent_1: + input_data_df = pd.DataFrame( + {f.name: batch.column(batch.schema.get_field_index("inputData")).field(f.name).to_pandas() + for f in batch.column(batch.schema.get_field_index("inputData")).type}) + + for row in input_data_df.itertuples(index=False): + batch_key = [row[o] for o in self.key_offsets] + yield batch_key, row + + def init_state_stream(): + for batch in batches_gent_2: + init_state_df = pd.DataFrame( + {f.name: batch.column(batch.schema.get_field_index("initState")).field(f.name).to_pandas() + for f in batch.column(batch.schema.get_field_index("initState")).type}) + + for row in init_state_df.itertuples(index=False): + batch_key = [row[o] for o in self.init_key_offsets] + yield batch_key, row + + def groupby_pair(gen1, gen2, keyfunc): + g1, g2 = groupby(gen1, key=keyfunc), groupby(gen2, key=keyfunc) + next1 = next(g1, (None, iter(()))) + next2 = next(g2, (None, iter(()))) + + k1, grp1 = next1 + k2, grp2 = next2 + + while k1 is not None or k2 is not None: + key = min(k for k in (k1, k2) if k is not None) + yield key, grp1 if k1 == key else iter(()), grp2 if k2 == key else iter(()) + if k1 == key: k1, grp1 = next(g1, (None, iter(()))) + if k2 == key: k2, grp2 = next(g2, (None, iter(()))) + + for batch_key, input_data_iterator, init_state_iterator in groupby_pair(input_data_stream(), + init_state_stream(), + keyfunc=lambda x: x[0]): + input_data_pandas = pd.DataFrame([d for _, d in input_data_iterator]) + init_state_pandas = pd.DataFrame([d for _, d in init_state_iterator]) + + # print(f"@@@ input_data_pandas: {input_data_pandas}") + # print(f"@@@ init_state_pandas: {init_state_pandas}") + + yield (batch_key, + [input_data_pandas[col] for col in input_data_pandas.columns], + [init_state_pandas[col] for col in init_state_pandas.columns]) _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) data_batches = generate_data_batches(_batches) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index d3bda545e1c99..e9116b1f89a91 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -1913,6 +1913,8 @@ def conf(cls): cfg.set("spark.sql.session.timeZone", "UTC") # TODO SPARK-49046 this config is to stop query from FEB sink gracefully cfg.set("spark.sql.streaming.noDataMicroBatches.enabled", "false") + # cfg.set("spark.sql.execution.pyspark.udf.faulthandler.enabled", "true") + # cfg.set("spark.python.worker.faulthandler.enabled", "true") return cfg From 7c2744fef2455fe87666b80fb2ba995a9e300904 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Mon, 25 Aug 2025 14:45:25 -0700 Subject: [PATCH 4/9] save --- python/pyspark/sql/pandas/serializers.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 711b6efceecb5..c32fbb759f8bb 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1728,15 +1728,15 @@ def generate_data_batches(batches): same time. And data chunks from the same grouping key should appear sequentially. """ for batch in batches: - DataRow = Row(*(batch.schema.names)) + DataRow = Row(*batch.schema.names) - # This is supposed to be the same. - batch_key = tuple(batch[o][0].as_py() for o in self.key_offsets) + # Iterate row by row without converting the whole batch + num_cols = batch.num_columns for row_idx in range(batch.num_rows): - row = DataRow( - *(batch.column(i)[row_idx].as_py() for i in range(batch.num_columns)) - ) - yield (batch_key, row) + # build the key for this row + row_key = tuple(batch[o][row_idx].as_py() for o in self.key_offsets) + row = DataRow(*(batch.column(i)[row_idx].as_py() for i in range(num_cols))) + yield row_key, row _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream) data_batches = generate_data_batches(_batches) From 97d166fcc959b05210db7de22d1c1fdb809f7492 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Mon, 25 Aug 2025 14:57:11 -0700 Subject: [PATCH 5/9] save --- python/pyspark/sql/pandas/serializers.py | 35 +++++++++--------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index c32fbb759f8bb..078584ba72c26 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1670,9 +1670,6 @@ def groupby_pair(gen1, gen2, keyfunc): input_data_pandas = pd.DataFrame([d for _, d in input_data_iterator]) init_state_pandas = pd.DataFrame([d for _, d in init_state_iterator]) - # print(f"@@@ input_data_pandas: {input_data_pandas}") - # print(f"@@@ init_state_pandas: {init_state_pandas}") - yield (batch_key, [input_data_pandas[col] for col in input_data_pandas.columns], [init_state_pandas[col] for col in init_state_pandas.columns]) @@ -1824,18 +1821,17 @@ def extract_rows(cur_batch, col_name, key_offsets): table = pa.Table.from_arrays(data_field_arrays, names=data_field_names) if table.num_rows == 0: - return (None, iter([])) + return iter([]) else: - batch_key = tuple(table.column(o)[0].as_py() for o in key_offsets) - - rows = [] - for row_idx in range(table.num_rows): - row = DataRow( - *(table.column(i)[row_idx].as_py() for i in range(table.num_columns)) - ) - rows.append(row) + def row_iter(): + for row_idx in range(table.num_rows): + batch_key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets) + row = DataRow( + *(table.column(i)[row_idx].as_py() for i in range(table.num_columns)) + ) + yield (batch_key, row) - return (batch_key, iter(rows)) + return row_iter() """ The arrow batch is written in the schema: @@ -1846,22 +1842,17 @@ def extract_rows(cur_batch, col_name, key_offsets): data generator. All rows in the same batch have the same grouping key. """ for batch in batches: - (input_batch_key, input_data_iter) = extract_rows( + input_data_iter = extract_rows( batch, "inputData", self.key_offsets ) - (init_batch_key, init_state_iter) = extract_rows( + init_state_iter = extract_rows( batch, "initState", self.init_key_offsets ) - if input_batch_key is None: - batch_key = init_batch_key - else: - batch_key = input_batch_key - - for init_state_row in init_state_iter: + for batch_key, init_state_row in init_state_iter: yield (batch_key, None, init_state_row) - for input_data_row in input_data_iter: + for batch_key, input_data_row in input_data_iter: yield (batch_key, input_data_row, None) _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream) From cc6fa0c60d8974e9addbab13763177334c12c27e Mon Sep 17 00:00:00 2001 From: zeruibao Date: Mon, 25 Aug 2025 15:00:02 -0700 Subject: [PATCH 6/9] save --- python/pyspark/sql/pandas/serializers.py | 56 ++++++++++++------- .../test_pandas_transform_with_state.py | 2 - 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 078584ba72c26..7b1b56da18df6 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1620,6 +1620,7 @@ def generate_data_batches(batches): import pandas as pd from itertools import tee + """ The arrow batch is written in the schema: schema: StructType = new StructType() @@ -1630,11 +1631,17 @@ def generate_data_batches(batches): """ batches_gent_1, batches_gent_2 = tee(batches) + def input_data_stream(): for batch in batches_gent_1: input_data_df = pd.DataFrame( - {f.name: batch.column(batch.schema.get_field_index("inputData")).field(f.name).to_pandas() - for f in batch.column(batch.schema.get_field_index("inputData")).type}) + { + f.name: batch.column(batch.schema.get_field_index("inputData")) + .field(f.name) + .to_pandas() + for f in batch.column(batch.schema.get_field_index("inputData")).type + } + ) for row in input_data_df.itertuples(index=False): batch_key = [row[o] for o in self.key_offsets] @@ -1643,8 +1650,13 @@ def input_data_stream(): def init_state_stream(): for batch in batches_gent_2: init_state_df = pd.DataFrame( - {f.name: batch.column(batch.schema.get_field_index("initState")).field(f.name).to_pandas() - for f in batch.column(batch.schema.get_field_index("initState")).type}) + { + f.name: batch.column(batch.schema.get_field_index("initState")) + .field(f.name) + .to_pandas() + for f in batch.column(batch.schema.get_field_index("initState")).type + } + ) for row in init_state_df.itertuples(index=False): batch_key = [row[o] for o in self.init_key_offsets] @@ -1661,18 +1673,22 @@ def groupby_pair(gen1, gen2, keyfunc): while k1 is not None or k2 is not None: key = min(k for k in (k1, k2) if k is not None) yield key, grp1 if k1 == key else iter(()), grp2 if k2 == key else iter(()) - if k1 == key: k1, grp1 = next(g1, (None, iter(()))) - if k2 == key: k2, grp2 = next(g2, (None, iter(()))) + if k1 == key: + k1, grp1 = next(g1, (None, iter(()))) + if k2 == key: + k2, grp2 = next(g2, (None, iter(()))) - for batch_key, input_data_iterator, init_state_iterator in groupby_pair(input_data_stream(), - init_state_stream(), - keyfunc=lambda x: x[0]): + for batch_key, input_data_iterator, init_state_iterator in groupby_pair( + input_data_stream(), init_state_stream(), keyfunc=lambda x: x[0] + ): input_data_pandas = pd.DataFrame([d for _, d in input_data_iterator]) init_state_pandas = pd.DataFrame([d for _, d in init_state_iterator]) - yield (batch_key, - [input_data_pandas[col] for col in input_data_pandas.columns], - [init_state_pandas[col] for col in init_state_pandas.columns]) + yield ( + batch_key, + [input_data_pandas[col] for col in input_data_pandas.columns], + [init_state_pandas[col] for col in init_state_pandas.columns], + ) _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) data_batches = generate_data_batches(_batches) @@ -1823,11 +1839,15 @@ def extract_rows(cur_batch, col_name, key_offsets): if table.num_rows == 0: return iter([]) else: + def row_iter(): for row_idx in range(table.num_rows): batch_key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets) row = DataRow( - *(table.column(i)[row_idx].as_py() for i in range(table.num_columns)) + *( + table.column(i)[row_idx].as_py() + for i in range(table.num_columns) + ) ) yield (batch_key, row) @@ -1839,15 +1859,11 @@ def row_iter(): .add("inputData", dataSchema) .add("initState", initStateSchema) We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python - data generator. All rows in the same batch have the same grouping key. + data generator. All rows in the same batch may have different grouping keys. """ for batch in batches: - input_data_iter = extract_rows( - batch, "inputData", self.key_offsets - ) - init_state_iter = extract_rows( - batch, "initState", self.init_key_offsets - ) + input_data_iter = extract_rows(batch, "inputData", self.key_offsets) + init_state_iter = extract_rows(batch, "initState", self.init_key_offsets) for batch_key, init_state_row in init_state_iter: yield (batch_key, None, init_state_row) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index e9116b1f89a91..d3bda545e1c99 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -1913,8 +1913,6 @@ def conf(cls): cfg.set("spark.sql.session.timeZone", "UTC") # TODO SPARK-49046 this config is to stop query from FEB sink gracefully cfg.set("spark.sql.streaming.noDataMicroBatches.enabled", "false") - # cfg.set("spark.sql.execution.pyspark.udf.faulthandler.enabled", "true") - # cfg.set("spark.python.worker.faulthandler.enabled", "true") return cfg From 3772f367a1cd6160986db418627d162ae8d4a38e Mon Sep 17 00:00:00 2001 From: zeruibao Date: Mon, 25 Aug 2025 16:33:47 -0700 Subject: [PATCH 7/9] save --- python/pyspark/sql/pandas/serializers.py | 59 ++++++++++++------------ 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 7b1b56da18df6..55b51c208f73c 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1540,12 +1540,12 @@ def generate_data_batches(batches): def row_stream(): for batch in batches: for row in batch.to_pandas().itertuples(index=False): - batch_keys = tuple(row[s] for s in self.key_offsets) - yield (batch_keys, row) + batch_key = tuple(row[s] for s in self.key_offsets) + yield (batch_key, row) - for batch_keys, group_rows in groupby(row_stream(), key=lambda x: x[0]): + for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]): df = pd.DataFrame([row for _, row in group_rows]) - yield (batch_keys, [df[col] for col in df.columns]) + yield (batch_key, [df[col] for col in df.columns]) _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) data_batches = generate_data_batches(_batches) @@ -1631,35 +1631,28 @@ def generate_data_batches(batches): """ batches_gent_1, batches_gent_2 = tee(batches) - - def input_data_stream(): - for batch in batches_gent_1: - input_data_df = pd.DataFrame( + columns_map: dict[str, Optional[list[str]]] = {"inputData": None, "initState": None} + + def data_stream(batch_iter, field_name, key_offsets): + nonlocal columns_map + for batch in batch_iter: + if columns_map[field_name] is None: + columns_map[field_name] = [ + f.name + for f in batch.column(batch.schema.get_field_index(field_name)).type + ] + + data_df = pd.DataFrame( { - f.name: batch.column(batch.schema.get_field_index("inputData")) + f.name: batch.column(batch.schema.get_field_index(field_name)) .field(f.name) .to_pandas() - for f in batch.column(batch.schema.get_field_index("inputData")).type + for f in batch.column(batch.schema.get_field_index(field_name)).type } ) - for row in input_data_df.itertuples(index=False): - batch_key = [row[o] for o in self.key_offsets] - yield batch_key, row - - def init_state_stream(): - for batch in batches_gent_2: - init_state_df = pd.DataFrame( - { - f.name: batch.column(batch.schema.get_field_index("initState")) - .field(f.name) - .to_pandas() - for f in batch.column(batch.schema.get_field_index("initState")).type - } - ) - - for row in init_state_df.itertuples(index=False): - batch_key = [row[o] for o in self.init_key_offsets] + for row in data_df.itertuples(index=False): + batch_key = tuple(row[o] for o in key_offsets) yield batch_key, row def groupby_pair(gen1, gen2, keyfunc): @@ -1679,10 +1672,16 @@ def groupby_pair(gen1, gen2, keyfunc): k2, grp2 = next(g2, (None, iter(()))) for batch_key, input_data_iterator, init_state_iterator in groupby_pair( - input_data_stream(), init_state_stream(), keyfunc=lambda x: x[0] + data_stream(batches_gent_1, "inputData", self.key_offsets), + data_stream(batches_gent_2, "initState", self.init_key_offsets), + keyfunc=lambda x: x[0], ): - input_data_pandas = pd.DataFrame([d for _, d in input_data_iterator]) - init_state_pandas = pd.DataFrame([d for _, d in init_state_iterator]) + input_data_pandas = pd.DataFrame( + [d for _, d in input_data_iterator], columns=columns_map["inputData"] + ) + init_state_pandas = pd.DataFrame( + [d for _, d in init_state_iterator], columns=columns_map["initState"] + ) yield ( batch_key, From d154239a8f79275f3b0342a4f891fb0e75d22670 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Mon, 25 Aug 2025 16:56:42 -0700 Subject: [PATCH 8/9] save --- python/pyspark/sql/pandas/serializers.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 55b51c208f73c..2f1dc5b6f5669 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1656,20 +1656,25 @@ def data_stream(batch_iter, field_name, key_offsets): yield batch_key, row def groupby_pair(gen1, gen2, keyfunc): - g1, g2 = groupby(gen1, key=keyfunc), groupby(gen2, key=keyfunc) - next1 = next(g1, (None, iter(()))) - next2 = next(g2, (None, iter(()))) + """ + Iterate over two sorted generators in parallel, grouped by the same key. + Yields (key, group1, group2), where groups are iterators (empty if no items for the key). + """ + + def safe_next(group_iter): + return next(group_iter, (None, iter(()))) - k1, grp1 = next1 - k2, grp2 = next2 + g1, g2 = groupby(gen1, key=keyfunc), groupby(gen2, key=keyfunc) + k1, grp1 = safe_next(g1) + k2, grp2 = safe_next(g2) while k1 is not None or k2 is not None: key = min(k for k in (k1, k2) if k is not None) yield key, grp1 if k1 == key else iter(()), grp2 if k2 == key else iter(()) if k1 == key: - k1, grp1 = next(g1, (None, iter(()))) + k1, grp1 = safe_next(g1) if k2 == key: - k2, grp2 = next(g2, (None, iter(()))) + k2, grp2 = safe_next(g2) for batch_key, input_data_iterator, init_state_iterator in groupby_pair( data_stream(batches_gent_1, "inputData", self.key_offsets), From 6ba1b75d1a4f941a94f4b91015ab6b95935baec1 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Mon, 25 Aug 2025 22:26:07 -0700 Subject: [PATCH 9/9] save --- python/pyspark/sql/pandas/serializers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 2f1dc5b6f5669..a3a303390ead8 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1520,7 +1520,6 @@ def load_stream(self, stream): Please refer the doc of inner function `generate_data_batches` for more details how this function works in overall. """ - import pyarrow as pa from pyspark.sql.streaming.stateful_processor_util import ( TransformWithStateInPandasFuncMode, ) @@ -1658,7 +1657,7 @@ def data_stream(batch_iter, field_name, key_offsets): def groupby_pair(gen1, gen2, keyfunc): """ Iterate over two sorted generators in parallel, grouped by the same key. - Yields (key, group1, group2), where groups are iterators (empty if no items for the key). + Yields (key, group1, group2), where groups are iterators. """ def safe_next(group_iter):