@@ -41,11 +41,12 @@ def to_df(r):
 
 
 class StreamingResult:
-    def __init__(self, c_result, conn, result_func):
+    def __init__(self, c_result, conn, result_func, supports_record_batch):
         self._result = c_result
         self._result_func = result_func
         self._conn = conn
         self._exhausted = False
+        self._supports_record_batch = supports_record_batch
 
     def fetch(self):
         """Fetch next chunk of streaming results"""
@@ -80,15 +81,182 @@ def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
+        self.cancel()
+
+    def close(self):
+        self.cancel()
 
     def cancel(self):
-        self._exhausted = True
+        if not self._exhausted:
+            self._exhausted = True
+            try:
+                self._conn.streaming_cancel_query(self._result)
+            except Exception as e:
+                raise RuntimeError(f"Failed to cancel streaming query: {str(e)}") from e
 
-        try:
-            self._conn.streaming_cancel_query(self._result)
-        except Exception as e:
-            raise RuntimeError(f"Failed to cancel streaming query: {str(e)}") from e
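With this change, `StreamingResult` cleans up after itself: `__exit__` now delegates to `cancel()`, a `close()` alias is added, and the `_exhausted` guard makes cancellation idempotent, so exiting the context manager after the stream is already drained no longer double-cancels. A minimal usage sketch (connection and query are illustrative; `fetch()` returning `None` on exhaustion is the contract from the docstring above):

# Hypothetical consumer of the streaming API shown in this diff
with conn.send_query("SELECT number FROM system.numbers LIMIT 10", format="CSV") as stream:
    while True:
        chunk = stream.fetch()
        if chunk is None:  # stream exhausted
            break
        print(chunk)
# __exit__ has already called cancel(); a second cancel() is a no-op
stream.cancel()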
+    def record_batch(self, rows_per_batch: int = 1000000) -> pa.RecordBatchReader:
+        """
+        Create a PyArrow RecordBatchReader from this StreamingResult.
+
+        This method requires that the StreamingResult was created with the Arrow
+        format. It wraps the streaming result in a ChdbRecordBatchReader to provide
+        efficient batching with a configurable batch size.
+
+        Args:
+            rows_per_batch (int): Number of rows per batch. Defaults to 1000000.
+
+        Returns:
+            pa.RecordBatchReader: PyArrow RecordBatchReader for efficient streaming.
+
+        Raises:
+            ValueError: If the StreamingResult was not created with the Arrow format.
+        """
+        if not self._supports_record_batch:
+            raise ValueError(
+                "record_batch() can only be used with arrow format. "
+                "Please use format='Arrow' when calling send_query."
+            )
+
+        chdb_reader = ChdbRecordBatchReader(self, rows_per_batch)
+        return pa.RecordBatchReader.from_batches(chdb_reader.schema(), chdb_reader)
+
+
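For the new `record_batch()` API, a usage sketch (assuming the package-level `connect()` entry point; the table and row counts are illustrative). The format must be Arrow, otherwise the `ValueError` above is raised:

import chdb

conn = chdb.connect(":memory:")
stream = conn.send_query(
    "SELECT number FROM system.numbers LIMIT 3000000", format="Arrow"
)
reader = stream.record_batch(rows_per_batch=1000000)  # pa.RecordBatchReader
for batch in reader:  # the reader is iterable
    print(batch.num_rows)  # roughly one million rows per combined batch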
+class ChdbRecordBatchReader:
+    """
+    A PyArrow RecordBatchReader wrapper for chdb StreamingResult.
+
+    This class provides an efficient way to read large result sets as PyArrow
+    RecordBatches, with a configurable batch size to balance memory usage and
+    performance.
+    """
+
+    def __init__(self, chdb_stream_result, batch_size_rows):
+        self._stream_result = chdb_stream_result
+        self._schema = None
+        self._closed = False
+        self._pending_batches = []
+        self._accumulator = []
+        self._batch_size_rows = batch_size_rows
+        self._current_rows = 0
+        self._first_batch = None
+        self._first_batch_consumed = True
+        # Resolve the schema eagerly; this also buffers the first chunk's batches
+        self._schema = self.schema()
+
+    def schema(self):
+        if self._schema is None:
+            # Fetch the first chunk to determine the schema
+            chunk = self._stream_result.fetch()
+            if chunk is not None:
+                arrow_bytes = chunk.bytes()
+                reader = pa.RecordBatchFileReader(arrow_bytes)
+                self._schema = reader.schema
+
+                # Stash the decoded batches so the probed chunk is not lost
+                table = reader.read_all()
+                if table.num_rows > 0:
+                    batches = table.to_batches()
+                    self._first_batch = batches[0]
+                    if len(batches) > 1:
+                        self._pending_batches = batches[1:]
+                    self._first_batch_consumed = False
+                else:
+                    self._first_batch = None
+                    self._first_batch_consumed = True
+            else:
+                # Empty result set: no chunks at all
+                self._schema = pa.schema([])
+                self._first_batch = None
+                self._first_batch_consumed = True
+                self._closed = True
+        return self._schema
+
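Note the side effect in `schema()`: it fetches the first chunk eagerly, reads the schema from the chunk's IPC header, and parks the decoded batches in `_first_batch`/`_pending_batches` so no rows are dropped. Each chunk is treated as a complete Arrow IPC file (not a bare stream), which is why `pa.RecordBatchFileReader` applies. A self-contained sketch of that framing, independent of chdb:

import pyarrow as pa

# Write a small table in Arrow IPC *file* format - the framing assumed per chunk
table = pa.table({"x": [1, 2, 3]})
sink = pa.BufferOutputStream()
with pa.ipc.new_file(sink, table.schema) as writer:
    writer.write_table(table)
arrow_bytes = sink.getvalue()

reader = pa.RecordBatchFileReader(arrow_bytes)
print(reader.schema)               # schema is readable from the header alone
print(reader.read_all().num_rows)  # 3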
+    def read_next_batch(self):
+        # Serve any batch left over from a previous oversized chunk first
+        if self._accumulator:
+            return self._accumulator.pop(0)
+
+        if self._closed:
+            raise StopIteration
+
+        while True:
+            batch = None
+
+            # 1. Return the first batch if it has not been consumed yet
+            if not self._first_batch_consumed:
+                self._first_batch_consumed = True
+                batch = self._first_batch
+
+            # 2. Check pending batches from the current chunk
+            elif self._pending_batches:
+                batch = self._pending_batches.pop(0)
+
+            # 3. Fetch a new chunk from the chdb stream
+            else:
+                chunk = self._stream_result.fetch()
+                if chunk is None:
+                    # No more data - return accumulated batches if any
+                    break
+
+                arrow_bytes = chunk.bytes()
+                if not arrow_bytes:
+                    continue
+
+                reader = pa.RecordBatchFileReader(arrow_bytes)
+                table = reader.read_all()
+
+                if table.num_rows > 0:
+                    batches = table.to_batches()
+                    batch = batches[0]
+                    if len(batches) > 1:
+                        self._pending_batches = batches[1:]
+                else:
+                    continue
+
+            # Process the batch if we got one
+            if batch is not None:
+                self._accumulator.append(batch)
+                self._current_rows += batch.num_rows
+
+                # If enough rows have accumulated, return a combined batch
+                if self._current_rows >= self._batch_size_rows:
+                    if len(self._accumulator) == 1:
+                        result = self._accumulator.pop(0)
+                    else:
+                        if hasattr(pa, 'concat_batches'):
+                            result = pa.concat_batches(self._accumulator)
+                            self._accumulator = []
+                        else:
+                            # Older PyArrow: return batches one at a time
+                            result = self._accumulator.pop(0)
+
+                    self._current_rows = 0
+                    return result
+
+        # End of stream - return any accumulated batches
+        if self._accumulator:
+            if len(self._accumulator) == 1:
+                result = self._accumulator.pop(0)
+            else:
+                if hasattr(pa, 'concat_batches'):
+                    result = pa.concat_batches(self._accumulator)
+                    self._accumulator = []
+                else:
+                    result = self._accumulator.pop(0)
+
+            self._current_rows = 0
+            self._closed = True
+            return result
+
+        # No more data
+        self._closed = True
+        raise StopIteration
+
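The loop accumulates decoded batches until at least `rows_per_batch` rows are buffered, then combines them with `pa.concat_batches`. The `hasattr` guard exists because `concat_batches` is only available in newer PyArrow releases; on older versions batches are returned one at a time and `rows_per_batch` becomes best-effort. The combining step in isolation:

import pyarrow as pa

small = [pa.record_batch({"x": list(range(i, i + 4))}) for i in (0, 4, 8)]
if hasattr(pa, "concat_batches"):  # newer PyArrow only
    combined = pa.concat_batches(small)
    print(combined.num_rows)  # 12 - three small batches merged into one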
+    def close(self):
+        if not self._closed:
+            self._stream_result.close()
+            self._closed = True
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.read_next_batch()
 
 
 class Connection:
@@ -112,12 +280,13 @@ def query(self, query: str, format: str = "CSV") -> Any:
 
     def send_query(self, query: str, format: str = "CSV") -> StreamingResult:
         lower_output_format = format.lower()
+        supports_record_batch = lower_output_format == "arrow"
         result_func = _process_result_format_funs.get(lower_output_format, lambda x: x)
         if lower_output_format in _arrow_format:
             format = "Arrow"
 
         c_stream_result = self._conn.send_query(query, format)
-        return StreamingResult(c_stream_result, self._conn, result_func)
+        return StreamingResult(c_stream_result, self._conn, result_func, supports_record_batch)
 
     def close(self) -> None:
         # print("close")
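One detail worth noting: `supports_record_batch` is keyed on the lowercased format string, so `format="arrow"` and `format="Arrow"` both enable `record_batch()`; any other entry in `_arrow_format` would still be sent to the engine as Arrow but would not enable it. The failure mode for non-Arrow formats (connection setup as in the earlier sketch):

stream = conn.send_query("SELECT 1", format="CSV")
try:
    stream.record_batch()
except ValueError as e:
    print(e)  # "record_batch() can only be used with arrow format. ..."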