@@ -1441,9 +1441,8 @@ class SparkConnectPlanner(
 }

 if (rel.hasData) {
-  val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
-    Iterator(rel.getData.toByteArray),
-    TaskContext.get())
+  val (rows, structType) =
+    ArrowConverters.fromIPCStream(rel.getData.toByteArray, TaskContext.get())
   if (structType == null) {
     throw InvalidInputErrors.inputDataForLocalRelationNoSchema()
   }
@@ -264,6 +264,109 @@ private[sql] object ArrowConverters extends Logging {
}
}

/**
 * An iterator that converts input data, given as a byte array, into InternalRow
 * instances.
 *
 * The input must be a valid Arrow IPC stream; that is, the first message is always
 * the schema, followed by N record batches.
 *
 * @param input serialized Arrow IPC stream
 * @param context Spark TaskContext; when non-null, all opened resources are closed
 *                on task completion
 */
private[sql] class InternalRowIteratorFromIPCStream(
input: Array[Byte],
context: TaskContext) extends Iterator[InternalRow] {

// Keep all the resources we have opened in order; they must be closed
// in reverse order at the end.
private val resources = new ArrayBuffer[AutoCloseable]()

// Create an allocator used for all Arrow related memory.
protected val allocator: BufferAllocator = ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}",
0,
Long.MaxValue)
resources.append(allocator)

private val reader = try {
new ArrowStreamReader(new ByteArrayInputStream(input), allocator)
} catch {
case e: Exception =>
closeAll(resources.toSeq.reverse: _*)
throw new IllegalArgumentException(
s"Failed to create ArrowStreamReader: ${e.getMessage}", e)
}
resources.append(reader)

private val root: VectorSchemaRoot = try {
reader.getVectorSchemaRoot
} catch {
case e: Exception =>
closeAll(resources.toSeq.reverse: _*)
throw new IllegalArgumentException(
s"Failed to read schema from IPC stream: ${e.getMessage}", e)
}
resources.append(root)

val schema: StructType = try {
ArrowUtils.fromArrowSchema(root.getSchema)
} catch {
case e: Exception =>
closeAll(resources.toSeq.reverse: _*)
throw new IllegalArgumentException(s"Failed to convert Arrow schema: ${e.getMessage}", e)
}

// TODO: wrap failures here in the same IllegalArgumentException handling used above
private var rowIterator: Iterator[InternalRow] = vectorSchemaRootToIter(root)

// Metrics to track batch processing
private var _batchesLoaded: Int = 0
private var _totalRowsProcessed: Long = 0L

if (context != null) {
context.addTaskCompletionListener[Unit] { _ =>
closeAll(resources.toSeq.reverse: _*)
}
}

// Public accessors for metrics
def batchesLoaded: Int = _batchesLoaded
def totalRowsProcessed: Long = _totalRowsProcessed

// Loads the next batch from the Arrow reader; returns true if a batch was
// loaded and false once the stream is exhausted.
private def loadNextBatch(): Boolean = {
if (reader.loadNextBatch()) {
rowIterator = vectorSchemaRootToIter(root)
_batchesLoaded += 1
true
} else {
false
}
}

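// True while the current batch still has rows; otherwise tries to load the
// next batch and re-checks, returning false once the stream is exhausted.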
override def hasNext: Boolean = {
if (rowIterator.hasNext) {
true
} else {
if (!loadNextBatch()) {
false
} else {
hasNext
}
}
}

override def next(): InternalRow = {
if (!hasNext) {
throw new NoSuchElementException("No more elements in iterator")
}
_totalRowsProcessed += 1
rowIterator.next()
}
}
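
// A minimal sketch (not part of this change) of producing the input this iterator
// expects: a single Arrow IPC stream whose first message is the schema, followed by
// one or more record batches, written with Arrow's ArrowStreamWriter. The name
// `ipcBytes` is illustrative only.
//
//   import java.io.ByteArrayOutputStream
//   import org.apache.arrow.memory.RootAllocator
//   import org.apache.arrow.vector.{IntVector, VectorSchemaRoot}
//   import org.apache.arrow.vector.ipc.ArrowStreamWriter
//   import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
//   import scala.jdk.CollectionConverters._
//
//   val allocator = new RootAllocator(Long.MaxValue)
//   val field = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null)
//   val schema = new Schema(List(field).asJava)
//   val root = VectorSchemaRoot.create(schema, allocator)
//   val out = new ByteArrayOutputStream()
//   val writer = new ArrowStreamWriter(root, null, out)
//   writer.start()                                   // writes the schema message
//   val ids = root.getVector("id").asInstanceOf[IntVector]
//   ids.allocateNew(3)
//   (0 until 3).foreach(i => ids.setSafe(i, i))
//   root.setRowCount(3)
//   writer.writeBatch()                              // one record batch; repeat for more
//   writer.end()
//   val ipcBytes: Array[Byte] = out.toByteArray      // valid input for this iterator
//   writer.close(); root.close(); allocator.close()  // release Arrow memory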

/**
 * An InternalRow iterator which parses data from serialized ArrowRecordBatches; subclasses
 * should implement [[nextBatch]] to parse data from binary records.
@@ -382,6 +485,23 @@ private[sql] object ArrowConverters extends Logging {
(iterator, iterator.schema)
}

/**
 * Creates an iterator that deserializes an Arrow IPC stream, given as a byte array,
 * containing exactly one schema and a varying number of record batches. Returns the
 * iterator over the resulting InternalRows together with the deserialized schema.
 */
private[sql] def fromIPCStream(input: Array[Byte], context: TaskContext):
(Iterator[InternalRow], StructType) = {
fromIPCStreamWithIterator(input, context)
}

// Variant of fromIPCStream for tests, exposing the concrete iterator type and its metrics
private[sql] def fromIPCStreamWithIterator(input: Array[Byte], context: TaskContext):
(InternalRowIteratorFromIPCStream, StructType) = {
val iterator = new InternalRowIteratorFromIPCStream(input, context)
(iterator, iterator.schema)
}
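
// A hypothetical usage sketch (not part of this change), assuming `ipcBytes` holds a
// valid Arrow IPC stream, e.g. built as in the sketch above. TaskContext.get() may be
// null outside of a task; the iterator then skips registering a completion listener.
//
//   val (iter, schema) = ArrowConverters.fromIPCStreamWithIterator(ipcBytes, TaskContext.get())
//   val rows = iter.toSeq                            // drains every record batch
//   assert(rows.size.toLong == iter.totalRowsProcessed)
//   assert(iter.batchesLoaded >= 1)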

/**
* Convert an arrow batch container into an iterator of InternalRow.
*/