Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xarray_beam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@
DatasetToZarr as DatasetToZarr,
)

__version__ = '0.11.0' # automatically synchronized to pyproject.toml
__version__ = '0.11.1' # automatically synchronized to pyproject.toml
37 changes: 26 additions & 11 deletions xarray_beam/_src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,9 @@ def to_zarr(
self,
path: str,
*,
zarr_chunks_per_shard: Mapping[str, int] | None = None,
zarr_chunks_per_shard: (
Mapping[str | types.EllipsisType, int] | None
) = None,
zarr_chunks: UnnormalizedChunks | None = None,
zarr_shards: UnnormalizedChunks | None = None,
zarr_format: int | None = None,
Expand All @@ -640,6 +642,9 @@ def to_zarr(
path: path to write to.
zarr_chunks_per_shard: If provided, write this dataset into Zarr shards,
each with at most this many Zarr chunks per shard (requires Zarr v3).
Dimensions not included in ``zarr_chunks_per_shard`` default to 1 chunk
per shard, unless a dict key of ellipsis (...) is used to indicate a
different default.
zarr_chunks: Explicit chunk sizes to use for storing data in Zarr, as an
alternative to specifying ``zarr_chunks_per_shard``. Zarr chunk sizes
must evenly divide the existing chunk sizes of this dataset.
Expand Down Expand Up @@ -675,22 +680,32 @@ def to_zarr(
)
if zarr_shards is None:
zarr_shards = self.chunks

chunks_per_shard = dict(zarr_chunks_per_shard)
if ... in chunks_per_shard:
default_cps = chunks_per_shard.pop(...)
else:
default_cps = 1

extra_keys = set(chunks_per_shard) - set(self.template.dims)
if extra_keys:
raise ValueError(
f'{zarr_chunks_per_shard=} includes keys that are not dimensions '
f' in template: {extra_keys}'
)

zarr_chunks = {}
for dim, existing_chunk_size in zarr_shards.items():
multiple = zarr_chunks_per_shard.get(dim)
if multiple is None:
raise ValueError(
f'cannot write a dataset with chunks {self.chunks} to Zarr with '
f'{zarr_chunks_per_shard=}, which does not contain a value for '
f'dimension {dim!r}'
)
zarr_chunks[dim], remainder = divmod(existing_chunk_size, multiple)
for dim, shard_size in zarr_shards.items():
cps = chunks_per_shard.get(dim, default_cps)
chunk_size, remainder = divmod(shard_size, cps)
if remainder != 0:
raise ValueError(
f'cannot write a dataset with chunks {self.chunks} to Zarr with '
f'{zarr_chunks_per_shard=}, which do not evenly divide into '
'chunks'
f'chunks. Computed chunk size for dimension {dim!r} is '
f'{chunk_size}, based on {cps} chunks per shard.'
)
zarr_chunks[dim] = chunk_size
elif zarr_chunks is None:
if zarr_shards is not None:
raise ValueError('cannot supply zarr_shards without zarr_chunks')
Expand Down
48 changes: 39 additions & 9 deletions xarray_beam/_src/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,40 @@ def test_to_zarr_chunks_per_shard(self):
self.assertEqual(opened['foo'].encoding['chunks'], (3,))
self.assertEqual(opened['foo'].encoding['shards'], (6,))

with self.subTest('default_one'):
temp_dir = self.create_tempdir().full_path
with beam.Pipeline() as p:
p |= beam_ds.to_zarr(temp_dir, zarr_chunks_per_shard={})
opened, chunks = xbeam.open_zarr(temp_dir)
xarray.testing.assert_identical(ds, opened)
self.assertEqual(chunks, {'x': 6})
self.assertEqual(opened['foo'].encoding['chunks'], (6,))
self.assertEqual(opened['foo'].encoding['shards'], (6,))

with self.subTest('ellipsis'):
temp_dir = self.create_tempdir().full_path
with beam.Pipeline() as p:
p |= beam_ds.to_zarr(temp_dir, zarr_chunks_per_shard={...: 2})
opened, chunks = xbeam.open_zarr(temp_dir)
xarray.testing.assert_identical(ds, opened)
self.assertEqual(chunks, {'x': 3})
self.assertEqual(opened['foo'].encoding['chunks'], (3,))
self.assertEqual(opened['foo'].encoding['shards'], (6,))

with self.subTest('ellipsis_with_dim'):
temp_dir = self.create_tempdir().full_path
ds2 = xarray.Dataset({'foo': (('x', 'y'), np.zeros((12, 10)))})
beam_ds2 = xbeam.Dataset.from_xarray(ds2, {'x': 6, 'y': 5})
with beam.Pipeline() as p:
p |= beam_ds2.to_zarr(
temp_dir, zarr_chunks_per_shard={'x': 3, ...: 1}
)
opened, chunks = xbeam.open_zarr(temp_dir)
xarray.testing.assert_identical(ds2, opened)
self.assertEqual(chunks, {'x': 2, 'y': 5})
self.assertEqual(opened['foo'].encoding['chunks'], (2, 5))
self.assertEqual(opened['foo'].encoding['shards'], (6, 5))

with self.subTest('explicit_shards'):
temp_dir = self.create_tempdir().full_path
ds = xarray.Dataset({'foo': ('x', np.arange(24))})
Expand Down Expand Up @@ -738,25 +772,21 @@ def test_to_zarr_chunks_per_shard(self):
temp_dir, zarr_chunks_per_shard={'x': 2}, zarr_chunks={'x': 3}
)

with self.subTest('missing_dim_error'):
with self.subTest('extra_key_error'):
ds = xarray.Dataset({'foo': ('x', np.arange(12))})
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6})
with self.assertRaisesWithLiteralMatch(
with self.assertRaisesRegex(
ValueError,
"cannot write a dataset with chunks {'x': 6} to Zarr with "
"zarr_chunks_per_shard={'y': 2}, which does not contain a value for "
"dimension 'x'",
'zarr_chunks_per_shard=.* includes keys that are not dimensions',
):
beam_ds.to_zarr(temp_dir, zarr_chunks_per_shard={'y': 2})

with self.subTest('uneven_division_error'):
ds = xarray.Dataset({'foo': ('x', np.arange(12))})
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6})
with self.assertRaisesWithLiteralMatch(
with self.assertRaisesRegex(
ValueError,
"cannot write a dataset with chunks {'x': 6} to Zarr with "
"zarr_chunks_per_shard={'x': 5}, which do not evenly divide into "
'chunks',
r'cannot write a dataset with chunks .*zarr_chunks_per_shard=.* which do not evenly divide',
):
beam_ds.to_zarr(temp_dir, zarr_chunks_per_shard={'x': 5})

Expand Down
Loading