173 changes: 100 additions & 73 deletions python/pyspark/sql/pandas/serializers.py
@@ -1520,7 +1520,6 @@ def load_stream(self, stream):
Please refer to the doc of the inner function `generate_data_batches` for more details on
how this function works overall.
"""
import pyarrow as pa
from pyspark.sql.streaming.stateful_processor_util import (
TransformWithStateInPandasFuncMode,
)
@@ -1535,14 +1534,17 @@ def generate_data_batches(batches):
This function must avoid materializing multiple Arrow RecordBatches into memory at the
same time. And data chunks from the same grouping key should appear sequentially.
"""
for batch in batches:
data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
]
key_series = [data_pandas[o] for o in self.key_offsets]
batch_key = tuple(s[0] for s in key_series)
yield (batch_key, data_pandas)
import pandas as pd

def row_stream():
for batch in batches:
for row in batch.to_pandas().itertuples(index=False):
batch_key = tuple(row[s] for s in self.key_offsets)
yield (batch_key, row)

for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
df = pd.DataFrame([row for _, row in group_rows])
yield (batch_key, [df[col] for col in df.columns])

_batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)
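A minimal, self-contained sketch of the regrouping pattern introduced above: stream rows out of successive Arrow record batches and let `itertools.groupby` re-chunk them by key, relying on rows for the same key arriving contiguously (as the JVM-side writer is expected to guarantee). The sample batches and `key_offsets` below are illustrative, not part of this PR.

```python
from itertools import groupby

import pandas as pd
import pyarrow as pa

key_offsets = [0]  # positions of the grouping-key columns in each row


def make_batches():
    # Two Arrow batches; key "b" straddles the batch boundary.
    yield pa.RecordBatch.from_pydict({"k": ["a", "a", "b"], "v": [1, 2, 3]})
    yield pa.RecordBatch.from_pydict({"k": ["b", "c"], "v": [4, 5]})


def row_stream(batches):
    # Stream rows one batch at a time so only one batch is materialized at once.
    for batch in batches:
        for row in batch.to_pandas().itertuples(index=False):
            yield tuple(row[o] for o in key_offsets), row


for key, rows in groupby(row_stream(make_batches()), key=lambda x: x[0]):
    df = pd.DataFrame([r for _, r in rows])
    print(key, len(df))
# ('a',) 2
# ('b',) 2   <- rows from two different Arrow batches end up in one group
# ('c',) 1
```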
@@ -1599,7 +1601,6 @@ def __init__(
self.init_key_offsets = None

def load_stream(self, stream):
import pyarrow as pa
from pyspark.sql.streaming.stateful_processor_util import (
TransformWithStateInPandasFuncMode,
)
@@ -1616,49 +1617,81 @@ def generate_data_batches(batches):
into the data generator.
"""

def flatten_columns(cur_batch, col_name):
state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
state_field_names = [
state_column.type[i].name for i in range(state_column.type.num_fields)
]
state_field_arrays = [
state_column.field(i) for i in range(state_column.type.num_fields)
]
table_from_fields = pa.Table.from_arrays(
state_field_arrays, names=state_field_names
)
return table_from_fields
import pandas as pd
from itertools import tee

"""
The arrow batch is written in the schema:
schema: StructType = new StructType()
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
data generator. All rows in the same batch have the same grouping key.
data generator. Rows in the same batch may have different grouping keys.
"""
for batch in batches:
flatten_state_table = flatten_columns(batch, "inputData")
data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(flatten_state_table.itercolumns())
]

flatten_init_table = flatten_columns(batch, "initState")
init_data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(flatten_init_table.itercolumns())
]
key_series = [data_pandas[o] for o in self.key_offsets]
init_key_series = [init_data_pandas[o] for o in self.init_key_offsets]
batches_gent_1, batches_gent_2 = tee(batches)
columns_map: dict[str, Optional[list[str]]] = {"inputData": None, "initState": None}

def data_stream(batch_iter, field_name, key_offsets):
nonlocal columns_map
for batch in batch_iter:
if columns_map[field_name] is None:
columns_map[field_name] = [
f.name
for f in batch.column(batch.schema.get_field_index(field_name)).type
]

data_df = pd.DataFrame(
{
f.name: batch.column(batch.schema.get_field_index(field_name))
.field(f.name)
.to_pandas()
for f in batch.column(batch.schema.get_field_index(field_name)).type
}
)

if any(s.empty for s in key_series):
# If any row is empty, assign batch_key using init_key_series
batch_key = tuple(s[0] for s in init_key_series)
else:
# If all rows are non-empty, create batch_key from key_series
batch_key = tuple(s[0] for s in key_series)
yield (batch_key, data_pandas, init_data_pandas)
for row in data_df.itertuples(index=False):
batch_key = tuple(row[o] for o in key_offsets)
yield batch_key, row

def groupby_pair(gen1, gen2, keyfunc):
"""
Iterate over two sorted generators in parallel, grouped by the same key.
Yields (key, group1, group2), where groups are iterators.
"""

def safe_next(group_iter):
return next(group_iter, (None, iter(())))

g1, g2 = groupby(gen1, key=keyfunc), groupby(gen2, key=keyfunc)
k1, grp1 = safe_next(g1)
k2, grp2 = safe_next(g2)

while k1 is not None or k2 is not None:
key = min(k for k in (k1, k2) if k is not None)
yield key, grp1 if k1 == key else iter(()), grp2 if k2 == key else iter(())
if k1 == key:
k1, grp1 = safe_next(g1)
if k2 == key:
k2, grp2 = safe_next(g2)

for batch_key, input_data_iterator, init_state_iterator in groupby_pair(
data_stream(batches_gent_1, "inputData", self.key_offsets),
data_stream(batches_gent_2, "initState", self.init_key_offsets),
keyfunc=lambda x: x[0],
):
input_data_pandas = pd.DataFrame(
[d for _, d in input_data_iterator], columns=columns_map["inputData"]
)
init_state_pandas = pd.DataFrame(
[d for _, d in init_state_iterator], columns=columns_map["initState"]
)

yield (
batch_key,
[input_data_pandas[col] for col in input_data_pandas.columns],
[init_state_pandas[col] for col in init_state_pandas.columns],
)

_batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)
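The `groupby_pair` helper above is the interesting piece: it walks two key-sorted row streams in lockstep and pairs up the groups that share a key, substituting an empty iterator when one side has no rows for that key. A toy illustration of the same merge shape, with made-up rows rather than PR data:

```python
from itertools import groupby


def groupby_pair(gen1, gen2, keyfunc):
    # Same shape as the helper in the diff: both inputs must already be sorted by key.
    def safe_next(grouped):
        return next(grouped, (None, iter(())))

    g1, g2 = groupby(gen1, key=keyfunc), groupby(gen2, key=keyfunc)
    k1, grp1 = safe_next(g1)
    k2, grp2 = safe_next(g2)
    while k1 is not None or k2 is not None:
        key = min(k for k in (k1, k2) if k is not None)
        yield key, grp1 if k1 == key else iter(()), grp2 if k2 == key else iter(())
        if k1 == key:
            k1, grp1 = safe_next(g1)
        if k2 == key:
            k2, grp2 = safe_next(g2)


input_rows = iter([("a", 1), ("a", 2), ("c", 3)])   # keyed input data
init_rows = iter([("b", 10), ("c", 30)])            # keyed initial state

for key, grp1, grp2 in groupby_pair(input_rows, init_rows, keyfunc=lambda r: r[0]):
    print(key, list(grp1), list(grp2))
# a [('a', 1), ('a', 2)] []
# b [] [('b', 10)]
# c [('c', 3)] [('c', 30)]
```

Note that each yielded group must be consumed before the next key is requested, since `itertools.groupby` invalidates a group iterator once the underlying iterator advances; the consumer in the diff does this by building a DataFrame from each group immediately.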
@@ -1711,15 +1744,15 @@ def generate_data_batches(batches):
same time. And data chunks from the same grouping key should appear sequentially.
"""
for batch in batches:
DataRow = Row(*(batch.schema.names))
DataRow = Row(*batch.schema.names)

# This is supposed to be the same.
batch_key = tuple(batch[o][0].as_py() for o in self.key_offsets)
# Iterate row by row without converting the whole batch
num_cols = batch.num_columns
for row_idx in range(batch.num_rows):
row = DataRow(
*(batch.column(i)[row_idx].as_py() for i in range(batch.num_columns))
)
yield (batch_key, row)
# build the key for this row
row_key = tuple(batch[o][row_idx].as_py() for o in self.key_offsets)
row = DataRow(*(batch.column(i)[row_idx].as_py() for i in range(num_cols)))
yield row_key, row

_batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)
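The same per-row keying, shown against a standalone pyarrow batch so the pattern is easy to test in isolation. `Row` is `pyspark.sql.Row`, and the batch contents and `key_offsets` are illustrative.

```python
import pyarrow as pa
from pyspark.sql import Row

key_offsets = [0]
batch = pa.RecordBatch.from_pydict({"k": ["a", "a", "b"], "v": [1, 2, 3]})
DataRow = Row(*batch.schema.names)  # a Row "class" with fields k, v

for row_idx in range(batch.num_rows):
    # Each row carries its own key: one Arrow batch may now mix grouping keys.
    row_key = tuple(batch.column(o)[row_idx].as_py() for o in key_offsets)
    row = DataRow(*(batch.column(i)[row_idx].as_py() for i in range(batch.num_columns)))
    print(row_key, row)
# ('a',) Row(k='a', v=1)
# ('a',) Row(k='a', v=2)
# ('b',) Row(k='b', v=3)
```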
@@ -1807,44 +1840,38 @@ def extract_rows(cur_batch, col_name, key_offsets):
table = pa.Table.from_arrays(data_field_arrays, names=data_field_names)

if table.num_rows == 0:
return (None, iter([]))
return iter([])
else:
batch_key = tuple(table.column(o)[0].as_py() for o in key_offsets)

rows = []
for row_idx in range(table.num_rows):
row = DataRow(
*(table.column(i)[row_idx].as_py() for i in range(table.num_columns))
)
rows.append(row)
def row_iter():
for row_idx in range(table.num_rows):
batch_key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets)
row = DataRow(
*(
table.column(i)[row_idx].as_py()
for i in range(table.num_columns)
)
)
yield (batch_key, row)

return (batch_key, iter(rows))
return row_iter()

"""
The arrow batch is written in the schema:
schema: StructType = new StructType()
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
data generator. All rows in the same batch have the same grouping key.
data generator. Rows in the same batch may have different grouping keys.
"""
for batch in batches:
(input_batch_key, input_data_iter) = extract_rows(
batch, "inputData", self.key_offsets
)
(init_batch_key, init_state_iter) = extract_rows(
batch, "initState", self.init_key_offsets
)

if input_batch_key is None:
batch_key = init_batch_key
else:
batch_key = input_batch_key
input_data_iter = extract_rows(batch, "inputData", self.key_offsets)
init_state_iter = extract_rows(batch, "initState", self.init_key_offsets)

for init_state_row in init_state_iter:
for batch_key, init_state_row in init_state_iter:
yield (batch_key, None, init_state_row)

for input_data_row in input_data_iter:
for batch_key, input_data_row in input_data_iter:
yield (batch_key, input_data_row, None)

_batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
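For context on what `extract_rows` unpacks: each incoming batch carries "inputData" and "initState" as struct columns, and the helper flattens one struct column into a plain table before reading rows from it. A sketch of that flattening step with made-up field names:

```python
import pyarrow as pa

batch = pa.RecordBatch.from_pydict(
    {
        "inputData": [{"k": "a", "v": 1}, {"k": "b", "v": 2}],
        "initState": [{"k": "a", "s": 10}, {"k": "b", "s": 20}],
    }
)


def flatten_struct_column(batch, col_name):
    # Pull the struct column apart into its child arrays and rebuild a flat table.
    col = batch.column(batch.schema.get_field_index(col_name))
    names = [col.type[i].name for i in range(col.type.num_fields)]
    arrays = [col.field(i) for i in range(col.type.num_fields)]
    return pa.Table.from_arrays(arrays, names=names)


print(flatten_struct_column(batch, "inputData").column_names)   # ['k', 'v']
print(flatten_struct_column(batch, "initState").column_names)   # ['k', 's']
```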
@@ -66,6 +66,7 @@ class TransformWithStateInPySparkPythonRunner(

private var pandasWriter: BaseStreamingArrowWriter = _

// Group rows for multiple grouping keys into one Arrow batch
override protected def writeNextBatchToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
@@ -75,23 +76,23 @@
pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
}

if (inputIterator.hasNext) {
val startData = dataOut.size()
// Send all rows to Python as Arrow batches with mixed keys
val startData = dataOut.size()
while (inputIterator.hasNext) {
val next = inputIterator.next()
val dataIter = next._2

while (dataIter.hasNext) {
val dataRow = dataIter.next()
pandasWriter.writeRow(dataRow)
}
pandasWriter.finalizeCurrentArrowBatch()
val deltaData = dataOut.size() - startData
pythonMetrics("pythonDataSent") += deltaData
true
} else {
super[PythonArrowInput].close()
false
}

pandasWriter.finalizeCurrentArrowBatch()
val deltaData = dataOut.size() - startData
pythonMetrics("pythonDataSent") += deltaData
super[PythonArrowInput].close()
false
}
}
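This Scala change is why the Python readers above switched from one key per batch to per-row keys: the writer now drains the whole input iterator before finalizing, so a single Arrow batch can carry several (contiguous) grouping keys. A small illustration of the consequence on the Python side, with illustrative data:

```python
import pyarrow as pa

# One batch, two grouping keys -- the shape the updated writer may now produce.
batch = pa.RecordBatch.from_pydict({"k": ["a", "a", "b"], "v": [1, 2, 3]})

old_assumption = batch.column(0)[0].as_py()   # "a" -- wrong for the last row
per_row_keys = [batch.column(0)[i].as_py() for i in range(batch.num_rows)]
print(old_assumption, per_row_keys)           # a ['a', 'a', 'b']
```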

@@ -125,6 +126,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner(

private var pandasWriter: BaseStreamingArrowWriter = _

// Group rows for multiple grouping keys into one Arrow batch
override protected def writeNextBatchToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
@@ -134,8 +136,9 @@
pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
}

if (inputIterator.hasNext) {
val startData = dataOut.size()
// Send all rows to Python as Arrow batches with mixed keys
val startData = dataOut.size()
while (inputIterator.hasNext) {
// a new grouping key with data & init state iter
val next = inputIterator.next()
val dataIter = next._2
@@ -150,14 +153,13 @@
else InternalRow.empty
pandasWriter.writeRow(InternalRow(dataRow, initRow))
}
pandasWriter.finalizeCurrentArrowBatch()
val deltaData = dataOut.size() - startData
pythonMetrics("pythonDataSent") += deltaData
true
} else {
super[PythonArrowInput].close()
false
}

pandasWriter.finalizeCurrentArrowBatch()
val deltaData = dataOut.size() - startData
pythonMetrics("pythonDataSent") += deltaData
super[PythonArrowInput].close()
false
}
}
