Skip to content

Commit 7bcfe4c

Browse files
committed
INTPYTHON-165 Refactor nested data handling
1 parent acb19b5 commit 7bcfe4c

File tree

3 files changed

+105
-156
lines changed

3 files changed

+105
-156
lines changed

bindings/python/pymongoarrow/context.py

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from bson.codec_options import DEFAULT_CODEC_OPTIONS
15-
from pyarrow import Table, timestamp
15+
from pyarrow import ListArray, StructArray, Table, timestamp
16+
from pyarrow.types import is_struct
1617

1718
from pymongoarrow.types import _BsonArrowTypes, _get_internal_typemap
1819

@@ -52,6 +53,7 @@
5253
_BsonArrowTypes.date64: Date64Builder,
5354
_BsonArrowTypes.null: NullBuilder,
5455
}
56+
5557
except ImportError:
5658
pass
5759

@@ -90,35 +92,76 @@ def from_schema(cls, schema, codec_options=DEFAULT_CODEC_OPTIONS):
9092
builder_map = {}
9193
tzinfo = codec_options.tzinfo
9294
str_type_map = _get_internal_typemap(schema.typemap)
93-
for fname, ftype in str_type_map.items():
94-
builder_cls = _TYPE_TO_BUILDER_CLS[ftype]
95-
encoded_fname = fname.encode("utf-8")
96-
97-
# special-case initializing builders for parameterized types
98-
if builder_cls == DatetimeBuilder:
99-
arrow_type = schema.typemap[fname]
100-
if tzinfo is not None and arrow_type.tz is None:
101-
arrow_type = timestamp(arrow_type.unit, tz=tzinfo)
102-
builder_map[encoded_fname] = DatetimeBuilder(dtype=arrow_type)
103-
elif builder_cls == DocumentBuilder:
104-
arrow_type = schema.typemap[fname]
105-
builder_map[encoded_fname] = DocumentBuilder(arrow_type, tzinfo)
106-
elif builder_cls == ListBuilder:
107-
arrow_type = schema.typemap[fname]
108-
builder_map[encoded_fname] = ListBuilder(arrow_type, tzinfo)
109-
elif builder_cls == BinaryBuilder:
110-
subtype = schema.typemap[fname].subtype
111-
builder_map[encoded_fname] = BinaryBuilder(subtype)
112-
else:
113-
builder_map[encoded_fname] = builder_cls()
95+
_parse_types(str_type_map, builder_map, tzinfo)
11496
return cls(schema, builder_map)
11597

11698
def finish(self):
117-
arrays = []
118-
names = []
119-
for fname, builder in self.builder_map.items():
120-
arrays.append(builder.finish())
121-
names.append(fname.decode("utf-8"))
122-
if self.schema is not None:
123-
return Table.from_arrays(arrays=arrays, schema=self.schema.to_arrow())
124-
return Table.from_arrays(arrays=arrays, names=names)
99+
return self._finish(self.builder_map, self.schema)
100+
101+
@staticmethod
102+
def _finish(builder_map, schema):
103+
to_remove = []
104+
# Traverse the builder map right to left.
105+
for key, value in reversed(builder_map.items()):
106+
field = key.decode("utf-8")
107+
arr = value.finish()
108+
if isinstance(value, DocumentBuilder):
109+
full_names = [f"{field}.{name}" for name in arr]
110+
arrs = [builder_map[c.encode("utf-8")] for c in full_names]
111+
builder_map[field] = StructArray.from_arrays(arrs, names=arr)
112+
to_remove.extend(full_names)
113+
elif isinstance(value, ListBuilder):
114+
child = field + "[]"
115+
to_remove.append(child)
116+
builder_map[key] = ListArray.from_arrays(arr, builder_map.get(child, []))
117+
else:
118+
builder_map[key] = arr
119+
120+
for field in to_remove:
121+
key = field.encode("utf-8")
122+
if key in builder_map:
123+
del builder_map[key]
124+
125+
arrays = list(builder_map.values())
126+
if schema is not None:
127+
return Table.from_arrays(arrays=arrays, schema=schema.to_arrow())
128+
return Table.from_arrays(arrays=arrays, names=list(builder_map.keys()))
129+
130+
131+
def _parse_types(str_type_map, builder_map, tzinfo):
132+
for fname, (ftype, arrow_type) in str_type_map.items():
133+
builder_cls = _TYPE_TO_BUILDER_CLS[ftype]
134+
encoded_fname = fname.encode("utf-8")
135+
# special-case initializing builders for parameterized types
136+
if builder_cls == DatetimeBuilder:
137+
if tzinfo is not None and arrow_type.tz is None:
138+
arrow_type = timestamp(arrow_type.unit, tz=tzinfo) # noqa: PLW2901
139+
builder_map[encoded_fname] = DatetimeBuilder(dtype=arrow_type)
140+
elif builder_cls == DocumentBuilder:
141+
builder_map[encoded_fname] = DocumentBuilder()
142+
# construct a sub type map here
143+
sub_type_map = {}
144+
for i in range(arrow_type.num_fields):
145+
field = arrow_type[i]
146+
sub_name = f"{fname}.{field.name}"
147+
sub_type_map[sub_name] = field.type
148+
sub_type_map = _get_internal_typemap(sub_type_map)
149+
_parse_types(sub_type_map, builder_map, tzinfo)
150+
continue
151+
elif builder_cls == ListBuilder:
152+
builder_map[encoded_fname] = ListBuilder()
153+
if is_struct(arrow_type.value_type):
154+
# construct a sub type map here
155+
sub_type_map = {}
156+
for i in range(arrow_type.value_type.num_fields):
157+
field = arrow_type.value_type[i]
158+
sub_name = f"{fname}[].{field.name}"
159+
sub_type_map[sub_name] = field.type
160+
sub_type_map = _get_internal_typemap(sub_type_map)
161+
_parse_types(sub_type_map, builder_map, tzinfo)
162+
continue
163+
elif builder_cls == BinaryBuilder:
164+
subtype = arrow_type.subtype
165+
builder_map[encoded_fname] = BinaryBuilder(subtype)
166+
else:
167+
builder_map[encoded_fname] = builder_cls()

bindings/python/pymongoarrow/lib.pyx

Lines changed: 31 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
437437
it = missing_builders.begin()
438438
while it != missing_builders.end():
439439
builder = NullBuilder()
440+
key = dereference(it).first
440441
context.builder_map[key] = builder
441442
null_builder = builder
442443
for _ in range(count):
@@ -839,146 +840,50 @@ cdef class Decimal128Builder(_ArrayBuilderBase):
839840
return self.builder
840841

841842

842-
cdef object get_field_builder(object field, object tzinfo):
843-
""""Find the appropriate field builder given a pyarrow field"""
844-
cdef object field_builder
845-
cdef DataType field_type
846-
if isinstance(field, DataType):
847-
field_type = field
848-
else:
849-
field_type = field.type
850-
if _atypes.is_int32(field_type):
851-
field_builder = Int32Builder()
852-
elif _atypes.is_int64(field_type):
853-
field_builder = Int64Builder()
854-
elif _atypes.is_float64(field_type):
855-
field_builder = DoubleBuilder()
856-
elif _atypes.is_timestamp(field_type):
857-
if tzinfo and field_type.tz is None:
858-
field_type = timestamp(field_type.unit, tz=tzinfo)
859-
field_builder = DatetimeBuilder(field_type)
860-
elif _atypes.is_string(field_type):
861-
field_builder = StringBuilder()
862-
elif _atypes.is_large_string(field_type):
863-
field_builder = StringBuilder()
864-
elif _atypes.is_boolean(field_type):
865-
field_builder = BoolBuilder()
866-
elif _atypes.is_struct(field_type):
867-
field_builder = DocumentBuilder(field_type, tzinfo)
868-
elif _atypes.is_list(field_type):
869-
field_builder = ListBuilder(field_type, tzinfo)
870-
elif _atypes.is_large_list(field_type):
871-
field_builder = ListBuilder(field_type, tzinfo)
872-
elif _atypes.is_null(field_type):
873-
field_builder = NullBuilder()
874-
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.objectid:
875-
field_builder = ObjectIdBuilder()
876-
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.decimal128:
877-
field_builder = Decimal128Builder()
878-
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.binary:
879-
field_builder = BinaryBuilder(field_type.subtype)
880-
else:
881-
field_builder = StringBuilder()
882-
return field_builder
883-
884-
885-
cdef class DocumentBuilder(_ArrayBuilderBase):
843+
cdef class DocumentBuilder:
844+
"""The document builder stores a map of field names that can be retrieved as a set."""
886845
cdef:
887-
shared_ptr[CStructBuilder] builder
888-
object dtype
889-
object context
890-
891-
def __cinit__(self, StructType dtype, tzinfo=None, MemoryPool memory_pool=None):
892-
cdef StringBuilder field_builder
893-
cdef vector[shared_ptr[CArrayBuilder]] c_field_builders
894-
cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
895-
896-
self.dtype = dtype
897-
if not _atypes.is_struct(dtype):
898-
raise ValueError("dtype must be a struct()")
899-
900-
self.context = context = PyMongoArrowContext(None, {})
901-
context.tzinfo = tzinfo
902-
builder_map = context.builder_map
903-
904-
for field in dtype:
905-
field_builder = <StringBuilder>get_field_builder(field, tzinfo)
906-
builder_map[field.name.encode('utf-8')] = field_builder
907-
c_field_builders.push_back(<shared_ptr[CArrayBuilder]>field_builder.builder)
846+
map[cstring, int32_t] field_map
908847

909-
self.builder.reset(new CStructBuilder(pyarrow_unwrap_data_type(dtype), pool, c_field_builders))
910-
self.type_marker = BSON_TYPE_DOCUMENT
911-
912-
@property
913-
def dtype(self):
914-
return self.dtype
915-
916-
cdef append_raw(self, const uint8_t * buf, size_t length):
917-
# Populate the child builders.
918-
process_raw_bson_stream(buf, length, self.context, None)
919-
# Append an element to the Struct. "All child-builders' Append method
920-
# must be called independently to maintain data-structure consistency."
921-
# Pass "true" for is_valid.
922-
self.builder.get().Append(True)
923-
924-
cpdef append(self, value):
925-
if not isinstance(value, bytes):
926-
value = bson.encode(value)
927-
self.append_raw(value, len(value))
928-
929-
cpdef append_null(self):
930-
self.builder.get().AppendNull()
848+
cdef add_field_raw(self, char * field):
849+
self.field_map[field] = 1
931850

932-
def __len__(self):
933-
return self.builder.get().length()
851+
cpdef add_field(self, field):
852+
self.field_map[field] = 1
934853

935-
cpdef finish(self):
936-
cdef shared_ptr[CArray] out
937-
with nogil:
938-
self.builder.get().Finish(&out)
939-
return pyarrow_wrap_array(out)
854+
def finish(self):
855+
it = self.field_map.begin()
856+
names = set()
857+
while it != self.field_map.end():
858+
names.add(dereference(it).first)
859+
preincrement(it)
860+
return names
940861

941-
cdef shared_ptr[CStructBuilder] unwrap(self):
942-
return self.builder
943862

863+
# NEXT STEP: test the DocumentBuilder and ListBuilder
944864

945865
cdef class ListBuilder(_ArrayBuilderBase):
866+
"""The list builder stores an int32 list of offsets and a counter with the current value.""""
946867
cdef:
947-
shared_ptr[CListBuilder] builder
948-
_ArrayBuilderBase child_builder
949-
object dtype
950-
object context
868+
shared_ptr[CInt32Builder] builder
869+
int32_t count
951870

952-
def __cinit__(self, DataType dtype, tzinfo=None, MemoryPool memory_pool=None, value_builder=None):
953-
cdef StringBuilder field_builder
871+
def __cinit__(self, MemoryPool memory_pool=None):
954872
cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
955-
cdef shared_ptr[CArrayBuilder] grandchild_builder
956-
self.dtype = dtype
957-
if not (_atypes.is_list(dtype) or _atypes.is_large_list(dtype)):
958-
raise ValueError("dtype must be a list_() or large_list()")
959-
self.context = context = PyMongoArrowContext(None, {})
960-
self.context.tzinfo = tzinfo
961-
field_builder = <StringBuilder>get_field_builder(self.dtype.value_type, tzinfo)
962-
grandchild_builder = <shared_ptr[CArrayBuilder]>field_builder.builder
963-
self.child_builder = field_builder
964-
self.builder.reset(new CListBuilder(pool, grandchild_builder, pyarrow_unwrap_data_type(dtype)))
965-
self.type_marker = BSON_TYPE_ARRAY
873+
self.builder.reset(new CInt32Builder(pool))
874+
self.count = 0
966875

876+
cdef append_offset_raw(self):
877+
self.builder.get().Append(value)
967878

968-
@property
969-
def dtype(self):
970-
return self.dtype
879+
cpdef append_offset(self):
880+
self.builder.get().Append(value)
971881

972-
cdef append_raw(self, const uint8_t * buf, size_t length):
973-
# Append an element to the array.
974-
# arr_value_builder will be appended to by process_bson_stream.
975-
self.builder.get().Append(True)
976-
process_raw_bson_stream(buf, length, self.context, self.child_builder)
882+
cdef append_raw(self, int32_t value):
883+
self.count += 1
977884

978885
cpdef append(self, value):
979-
if not isinstance(value, bytes):
980-
value = bson.encode(value)
981-
self.append_raw(value, len(value))
886+
self.count += 1
982887

983888
cpdef append_null(self):
984889
self.builder.get().AppendNull()
@@ -987,12 +892,13 @@ cdef class ListBuilder(_ArrayBuilderBase):
987892
return self.builder.get().length()
988893

989894
cpdef finish(self):
895+
self.append_offset()
990896
cdef shared_ptr[CArray] out
991897
with nogil:
992898
self.builder.get().Finish(&out)
993899
return pyarrow_wrap_array(out)
994900

995-
cdef shared_ptr[CListBuilder] unwrap(self):
901+
cdef shared_ptr[CInt32Builder] unwrap(self):
996902
return self.builder
997903

998904

bindings/python/pymongoarrow/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def _get_internal_typemap(typemap):
310310
for fname, ftype in typemap.items():
311311
for checker, internal_id in _TYPE_CHECKER_TO_INTERNAL_TYPE.items():
312312
if checker(ftype):
313-
internal_typemap[fname] = internal_id
313+
internal_typemap[fname] = (internal_id, ftype)
314314
break
315315

316316
if fname not in internal_typemap:

0 commit comments

Comments
 (0)