Skip to content

Commit 56b2f86

Browse files
frewsxcvborchero
andauthored
feat: Add support for Binary column types (#146)
Co-authored-by: Oliver Borchert <[email protected]>
1 parent 4b643a5 commit 56b2f86

File tree

8 files changed

+108
-0
lines changed

8 files changed

+108
-0
lines changed

dataframely/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .columns import (
2424
Any,
2525
Array,
26+
Binary,
2627
Bool,
2728
Categorical,
2829
Column,
@@ -77,6 +78,7 @@
7778
"read_parquet_metadata_schema",
7879
"read_parquet_metadata_collection",
7980
"Any",
81+
"Binary",
8082
"Bool",
8183
"Categorical",
8284
"Column",

dataframely/columns/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ._registry import column_from_dict
66
from .any import Any
77
from .array import Array
8+
from .binary import Binary
89
from .bool import Bool
910
from .categorical import Categorical
1011
from .datetime import Date, Datetime, Duration, Time
@@ -22,6 +23,7 @@
2223
"column_from_dict",
2324
"Any",
2425
"Array",
26+
"Binary",
2527
"Bool",
2628
"Categorical",
2729
"Date",

dataframely/columns/binary.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) QuantCo 2025-2025
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
from __future__ import annotations
5+
6+
import polars as pl
7+
8+
from dataframely._compat import pa, sa, sa_TypeEngine
9+
from dataframely.random import Generator
10+
11+
from ._base import Column
12+
from ._registry import register
13+
14+
15+
@register
16+
class Binary(Column):
17+
"""A column of binary values."""
18+
19+
@property
20+
def dtype(self) -> pl.DataType:
21+
return pl.Binary()
22+
23+
def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
24+
if dialect.name == "mssql":
25+
return sa.VARBINARY()
26+
return sa.LargeBinary()
27+
28+
@property
29+
def pyarrow_dtype(self) -> pa.DataType:
30+
return pa.large_binary()
31+
32+
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
33+
return generator.sample_binary(
34+
n,
35+
min_bytes=0,
36+
max_bytes=32,
37+
null_probability=self._null_probability,
38+
)

dataframely/random.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,31 @@ def sample_string(
148148
pl.Series(samples, dtype=pl.String), null_probability
149149
)
150150

151+
def sample_binary(
152+
self,
153+
n: int = 1,
154+
*,
155+
min_bytes: int,
156+
max_bytes: int,
157+
null_probability: float = 0.0,
158+
) -> pl.Series:
159+
"""Sample a list of binary values in the specified length range.
160+
161+
Args:
162+
n: The number of binary values to sample.
163+
min_bytes: The minimum number of bytes for each value.
164+
max_bytes: The maximum number of bytes for each value.
165+
null_probability: The probability of an element being ``null``.
166+
167+
Returns:
168+
A series with ``n`` elements of dtype ``Binary``.
169+
"""
170+
lengths = self.numpy_generator.integers(min_bytes, max_bytes + 1, size=n)
171+
samples = [self.numpy_generator.bytes(length) for length in lengths]
172+
return self._apply_null_mask(
173+
pl.Series(samples, dtype=pl.Binary), null_probability
174+
)
175+
151176
def sample_choice(
152177
self,
153178
n: int = 1,

dataframely/testing/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
dc.UInt64,
2323
dc.String,
2424
dc.Categorical,
25+
dc.Binary,
2526
]
2627
INTEGER_COLUMN_TYPES: list[type[dc.Column]] = [
2728
dc.Integer,

tests/column_types/test_binary.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) QuantCo 2025-2025
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
import polars as pl
5+
import pytest
6+
7+
import dataframely as dy
8+
from dataframely.columns import Column
9+
10+
11+
class BinarySchema(dy.Schema):
12+
a = dy.Binary()
13+
14+
15+
@pytest.mark.parametrize(
16+
("column", "dtype", "is_valid"),
17+
[
18+
(dy.Binary(), pl.Binary(), True),
19+
(dy.Binary(), pl.String(), False),
20+
(dy.Binary(), pl.Null(), False),
21+
],
22+
)
23+
def test_validate_dtype(column: Column, dtype: pl.DataType, is_valid: bool) -> None:
24+
assert column.validate_dtype(dtype) == is_valid

tests/columns/test_sql_schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
("column", "datatype"),
1616
[
1717
(dy.Any(), "SQL_VARIANT"),
18+
(dy.Binary(), "VARBINARY(max)"),
1819
(dy.Bool(), "BIT"),
1920
(dy.Date(), "DATE"),
2021
(dy.Datetime(), "DATETIME2(6)"),
@@ -61,6 +62,7 @@ def test_mssql_datatype(column: Column, datatype: str) -> None:
6162
@pytest.mark.parametrize(
6263
("column", "datatype"),
6364
[
65+
(dy.Binary(), "BYTEA"),
6466
(dy.Bool(), "BOOLEAN"),
6567
(dy.Date(), "DATE"),
6668
(dy.Datetime(), "TIMESTAMP WITHOUT TIME ZONE"),

tests/test_random.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def test_seeding_nonconstant() -> None:
4040
lambda generator, n: generator.sample_float(n, min=0, max=5),
4141
lambda generator, n: generator.sample_string(n, regex="[abc]"),
4242
lambda generator, n: generator.sample_choice(n, choices=[1, 2, 3]),
43+
lambda generator, n: generator.sample_binary(n, min_bytes=1, max_bytes=10),
4344
lambda generator, n: generator.sample_time(n, min=dt.time(0, 0), max=None),
4445
lambda generator, n: generator.sample_date(
4546
n, min=dt.date(1970, 1, 1), max=None
@@ -75,6 +76,9 @@ def test_sample_correct_n(
7576
lambda generator, n, prob: generator.sample_choice(
7677
n, choices=[1, 2, 3], null_probability=prob
7778
),
79+
lambda generator, n, prob: generator.sample_binary(
80+
n, min_bytes=1, max_bytes=10, null_probability=prob
81+
),
7882
lambda generator, n, prob: generator.sample_time(
7983
n, min=dt.time(0, 0), max=None, null_probability=prob
8084
),
@@ -131,6 +135,16 @@ def test_sample_string(generator: Generator) -> None:
131135
assert (samples.str.len_bytes() == 2).all()
132136

133137

138+
def test_sample_binary(generator: Generator) -> None:
139+
samples = generator.sample_binary(100, min_bytes=1, max_bytes=10)
140+
assert (
141+
samples.to_frame("s").select(pl.col("s").bin.size("b") >= 1).to_series().all()
142+
)
143+
assert (
144+
samples.to_frame("s").select(pl.col("s").bin.size("b") <= 10).to_series().all()
145+
)
146+
147+
134148
def test_sample_choice(generator: Generator) -> None:
135149
samples = generator.sample_choice(100_000, choices=[1, 2, 3])
136150
assert np.allclose(

0 commit comments

Comments
 (0)