
Commit 2d4f024

Add PyArrow RecordBatchReader support for streaming data export (#354)

* feat: add PyArrow RecordBatchReader support for streaming data export
* fix: fix format
* fix: fix tests

1 parent d42d5eb, commit 2d4f024

9 files changed (+404, -17 lines)

.github/workflows/build_linux_arm64_wheels-gh.yml (1 addition, 1 deletion)

@@ -122,7 +122,7 @@ jobs:
     - name: Run tests
       run: |
         python3 -m pip install dist/*.whl
-        python3 -m pip install pandas pyarrow psutil
+        python3 -m pip install pandas pyarrow psutil deltalake
         python3 -c "import chdb; res = chdb.query('select 1112222222,555', 'CSV'); print(res)"
         make test
       continue-on-error: false

.github/workflows/build_linux_x86_wheels.yml (1 addition, 1 deletion)

@@ -120,7 +120,7 @@ jobs:
     - name: Run tests
       run: |
         python3 -m pip install dist/*.whl
-        python3 -m pip install pandas pyarrow psutil
+        python3 -m pip install pandas pyarrow psutil deltalake
         python3 -c "import chdb; res = chdb.query('select 1112222222,555', 'CSV'); print(res)"
         make test
       continue-on-error: false

.github/workflows/build_macos_arm64_wheels.yml (1 addition, 1 deletion)

@@ -124,7 +124,7 @@ jobs:
     - name: Run tests
       run: |
         python3 -m pip install dist/*.whl
-        python3 -m pip install pandas pyarrow psutil
+        python3 -m pip install pandas pyarrow psutil deltalake
         python3 -c "import chdb; res = chdb.query('select 1112222222,555', 'CSV'); print(res)"
         make test
       continue-on-error: false

.github/workflows/build_macos_x86_wheels.yml (1 addition, 1 deletion)

@@ -124,7 +124,7 @@ jobs:
     - name: Run tests
      run: |
         python3 -m pip install dist/*.whl
-        python3 -m pip install pandas pyarrow psutil
+        python3 -m pip install pandas pyarrow psutil deltalake
         python3 -c "import chdb; res = chdb.query('select 1112222222,555', 'CSV'); print(res)"
         make test
       continue-on-error: false

README-zh.md (23 additions, 2 deletions)

@@ -184,16 +184,37 @@ while True:
     if chunk is None:
         break
     if rows_cnt > 0:
-        stream_result.cancel()
+        stream_result.close()
         break
     rows_cnt += chunk.rows_read()

 print(rows_cnt) # 65409

+# 示例4:使用PyArrow RecordBatchReader进行批量导出以及与其他库集成
+import pyarrow as pa
+from deltalake import write_deltalake
+
+# 获取arrow格式的流式结果
+stream_result = sess.send_query("SELECT * FROM numbers(100000)", "Arrow")
+
+# 创建自定义批次大小的RecordBatchReader(默认rows_per_batch=1000000)
+batch_reader = stream_result.record_batch(rows_per_batch=10000)
+
+# 将RecordBatchReader与外部库(如Delta Lake)一起使用
+write_deltalake(
+    table_or_uri="./my_delta_table",
+    data=batch_reader,
+    mode="overwrite"
+)
+
+stream_result.close()
 sess.close()
 ```

-参见: [test_streaming_query.py](tests/test_streaming_query.py)
+**重要提示**：使用流式查询时，如果`StreamingResult`没有被完全消耗（由于错误或提前终止），必须显式调用`stream_result.close()`来释放资源，或使用`with`语句进行自动清理。否则可能会阻塞后续查询。
+
+参见: [test_streaming_query.py](tests/test_streaming_query.py) 和 [test_arrow_record_reader_deltalake.py](tests/test_arrow_record_reader_deltalake.py)
 </details>

 更多示例，请参见 [examples](examples) 和 [tests](tests)
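The note added above also points at `with`-statement cleanup; the commit wires `StreamingResult.__exit__` to `cancel()`, so the context manager releases a partially consumed stream. Below is a minimal sketch of that pattern, assuming `sess` is an existing chdb session as in the README examples (e.g. `chdb.session.Session()`):

```python
# Sketch: context-managed streaming query (assumes an existing chdb session `sess`).
with sess.send_query("SELECT * FROM numbers(100000)", "Arrow") as stream_result:
    reader = stream_result.record_batch(rows_per_batch=10000)
    first = next(iter(reader))   # read only the first batch, then stop early
    print(first.num_rows)        # roughly 10000 rows in this batch

# Leaving the block calls __exit__ -> cancel(), so the unread remainder of the
# stream is released and later queries on the same session are not blocked.
```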

README.md (23 additions, 2 deletions)

@@ -262,16 +262,37 @@ while True:
     if chunk is None:
         break
     if rows_cnt > 0:
-        stream_result.cancel()
+        stream_result.close()
         break
     rows_cnt += chunk.rows_read()

 print(rows_cnt) # 65409

+# Example 4: Using PyArrow RecordBatchReader for batch export and integration with other libraries
+import pyarrow as pa
+from deltalake import write_deltalake
+
+# Get streaming result in arrow format
+stream_result = sess.send_query("SELECT * FROM numbers(100000)", "Arrow")
+
+# Create RecordBatchReader with custom batch size (default rows_per_batch=1000000)
+batch_reader = stream_result.record_batch(rows_per_batch=10000)
+
+# Use RecordBatchReader with external libraries like Delta Lake
+write_deltalake(
+    table_or_uri="./my_delta_table",
+    data=batch_reader,
+    mode="overwrite"
+)
+
+stream_result.close()
 sess.close()
 ```

-For more details, see [test_streaming_query.py](tests/test_streaming_query.py).
+**Important Note**: When using streaming queries, if the `StreamingResult` is not fully consumed (due to errors or early termination), you must explicitly call `stream_result.close()` to release resources, or use the `with` statement for automatic cleanup. Failure to do so may block subsequent queries.
+
+For more details, see [test_streaming_query.py](tests/test_streaming_query.py) and [test_arrow_record_reader_deltalake.py](tests/test_arrow_record_reader_deltalake.py).
 </details>
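Because `record_batch()` returns a standard `pyarrow.RecordBatchReader`, it also works with plain PyArrow code, not just Delta Lake. A minimal sketch, assuming the same `sess` as in the README example above:

```python
import pyarrow as pa

# Stream the query result and materialize it with plain PyArrow.
stream_result = sess.send_query("SELECT * FROM numbers(100000)", "Arrow")
reader = stream_result.record_batch(rows_per_batch=10000)

table = reader.read_all()   # pa.Table assembled from the streamed batches
df = table.to_pandas()      # optional hand-off to pandas
print(table.num_rows)       # expected: 100000

stream_result.close()       # release the stream once done, as the note advises
```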

chdb/state/sqlitelike.py (177 additions, 8 deletions)

@@ -41,11 +41,12 @@ def to_df(r):


 class StreamingResult:
-    def __init__(self, c_result, conn, result_func):
+    def __init__(self, c_result, conn, result_func, supports_record_batch):
         self._result = c_result
         self._result_func = result_func
         self._conn = conn
         self._exhausted = False
+        self._supports_record_batch = supports_record_batch

     def fetch(self):
         """Fetch next chunk of streaming results"""

@@ -80,15 +81,182 @@ def __enter__(self):
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
+        self.cancel()
+
+    def close(self):
+        self.cancel()

     def cancel(self):
-        self._exhausted = True
+        if not self._exhausted:
+            self._exhausted = True
+            try:
+                self._conn.streaming_cancel_query(self._result)
+            except Exception as e:
+                raise RuntimeError(f"Failed to cancel streaming query: {str(e)}") from e

-        try:
-            self._conn.streaming_cancel_query(self._result)
-        except Exception as e:
-            raise RuntimeError(f"Failed to cancel streaming query: {str(e)}") from e
+    def record_batch(self, rows_per_batch: int = 1000000) -> pa.RecordBatchReader:
+        """
+        Create a PyArrow RecordBatchReader from this StreamingResult.
+
+        This method requires that the StreamingResult was created with arrow format.
+        It wraps the streaming result with ChdbRecordBatchReader to provide efficient
+        batching with configurable batch sizes.
+
+        Args:
+            rows_per_batch (int): Number of rows per batch. Defaults to 1000000.
+
+        Returns:
+            pa.RecordBatchReader: PyArrow RecordBatchReader for efficient streaming
+
+        Raises:
+            ValueError: If the StreamingResult was not created with arrow format
+        """
+        if not self._supports_record_batch:
+            raise ValueError(
+                "record_batch() can only be used with arrow format. "
+                "Please use format='Arrow' when calling send_query."
+            )
+
+        chdb_reader = ChdbRecordBatchReader(self, rows_per_batch)
+        return pa.RecordBatchReader.from_batches(chdb_reader.schema(), chdb_reader)
+
+
+class ChdbRecordBatchReader:
+    """
+    A PyArrow RecordBatchReader wrapper for chdb StreamingResult.
+
+    This class provides an efficient way to read large result sets as PyArrow RecordBatches
+    with configurable batch sizes to optimize memory usage and performance.
+    """
+
+    def __init__(self, chdb_stream_result, batch_size_rows):
+        self._stream_result = chdb_stream_result
+        self._schema = None
+        self._closed = False
+        self._pending_batches = []
+        self._accumulator = []
+        self._batch_size_rows = batch_size_rows
+        self._current_rows = 0
+        self._first_batch = None
+        self._first_batch_consumed = True
+        self._schema = self.schema()
+
+    def schema(self):
+        if self._schema is None:
+            # Get the first chunk to determine schema
+            chunk = self._stream_result.fetch()
+            if chunk is not None:
+                arrow_bytes = chunk.bytes()
+                reader = pa.RecordBatchFileReader(arrow_bytes)
+                self._schema = reader.schema
+
+                table = reader.read_all()
+                if table.num_rows > 0:
+                    batches = table.to_batches()
+                    self._first_batch = batches[0]
+                    if len(batches) > 1:
+                        self._pending_batches = batches[1:]
+                    self._first_batch_consumed = False
+                else:
+                    self._first_batch = None
+                    self._first_batch_consumed = True
+            else:
+                self._schema = pa.schema([])
+                self._first_batch = None
+                self._first_batch_consumed = True
+                self._closed = True
+        return self._schema
+
+    def read_next_batch(self):
+        if self._accumulator:
+            result = self._accumulator.pop(0)
+            return result
+
+        if self._closed:
+            raise StopIteration
+
+        while True:
+            batch = None
+
+            # 1. Return the first batch if not consumed yet
+            if not self._first_batch_consumed:
+                self._first_batch_consumed = True
+                batch = self._first_batch
+
+            # 2. Check pending batches from current chunk
+            elif self._pending_batches:
+                batch = self._pending_batches.pop(0)
+
+            # 3. Fetch new chunk from chdb stream
+            else:
+                chunk = self._stream_result.fetch()
+                if chunk is None:
+                    # No more data - return accumulated batches if any
+                    break
+
+                arrow_bytes = chunk.bytes()
+                if not arrow_bytes:
+                    continue
+
+                reader = pa.RecordBatchFileReader(arrow_bytes)
+                table = reader.read_all()
+
+                if table.num_rows > 0:
+                    batches = table.to_batches()
+                    batch = batches[0]
+                    if len(batches) > 1:
+                        self._pending_batches = batches[1:]
+                else:
+                    continue
+
+            # Process the batch if we got one
+            if batch is not None:
+                self._accumulator.append(batch)
+                self._current_rows += batch.num_rows
+
+                # If accumulated enough rows, return combined batch
+                if self._current_rows >= self._batch_size_rows:
+                    if len(self._accumulator) == 1:
+                        result = self._accumulator.pop(0)
+                    else:
+                        if hasattr(pa, 'concat_batches'):
+                            result = pa.concat_batches(self._accumulator)
+                            self._accumulator = []
+                        else:
+                            result = self._accumulator.pop(0)
+
+                    self._current_rows = 0
+                    return result
+
+        # End of stream - return any accumulated batches
+        if self._accumulator:
+            if len(self._accumulator) == 1:
+                result = self._accumulator.pop(0)
+            else:
+                if hasattr(pa, 'concat_batches'):
+                    result = pa.concat_batches(self._accumulator)
+                    self._accumulator = []
+                else:
+                    result = self._accumulator.pop(0)
+
+            self._current_rows = 0
+            self._closed = True
+            return result
+
+        # No more data
+        self._closed = True
+        raise StopIteration
+
+    def close(self):
+        if not self._closed:
+            self._stream_result.close()
+            self._closed = True
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.read_next_batch()


 class Connection:

@@ -112,12 +280,13 @@ def query(self, query: str, format: str = "CSV") -> Any:

     def send_query(self, query: str, format: str = "CSV") -> StreamingResult:
         lower_output_format = format.lower()
+        supports_record_batch = lower_output_format == "arrow"
         result_func = _process_result_format_funs.get(lower_output_format, lambda x: x)
         if lower_output_format in _arrow_format:
             format = "Arrow"

         c_stream_result = self._conn.send_query(query, format)
-        return StreamingResult(c_stream_result, self._conn, result_func)
+        return StreamingResult(c_stream_result, self._conn, result_func, supports_record_batch)

     def close(self) -> None:
         # print("close")
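A quick way to exercise the new code path is a small test-style sketch: `record_batch()` raises `ValueError` unless the stream was opened with the Arrow format, and `ChdbRecordBatchReader` coalesces chunks so that each yielded batch (except possibly the last) holds roughly `rows_per_batch` rows. This assumes the chdb session API (`chdb.session.Session`) used in the README examples:

```python
from chdb import session as chs

sess = chs.Session()

# record_batch() is only valid for Arrow-format streams.
csv_stream = sess.send_query("SELECT * FROM numbers(10)", "CSV")
try:
    csv_stream.record_batch()
except ValueError as e:
    print(e)  # "record_batch() can only be used with arrow format. ..."
finally:
    csv_stream.close()

# Arrow-format streams yield coalesced batches of about rows_per_batch rows.
arrow_stream = sess.send_query("SELECT * FROM numbers(100000)", "Arrow")
reader = arrow_stream.record_batch(rows_per_batch=25000)
sizes = [batch.num_rows for batch in reader]
print(sum(sizes))  # 100000 rows in total across all batches

arrow_stream.close()
sess.close()
```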
