diff --git a/README.md b/README.md
index 0656710e..6ad54c68 100644
--- a/README.md
+++ b/README.md
@@ -1082,6 +1082,44 @@ dataset = StreamingDatasetWithTransform(data_dir, cache_dir=str(cache_dir), shuf
+
+
+ ✅ Multi-Sample Transform datasets while Streaming 🔗
+
+
+Sometimes you need to return several transformed variations of the same item, each with subtle differences. The multi-sample feature lets you generate `sample_count` samples per input while streaming, without storing intermediate results. Pass a single transform that accepts a `sample_idx` argument, and the dataset calls it once for each generated sample, as in the example below.
+
+```python
+from torchvision import transforms
+
+
+def transform_fn(x, sample_idx):
+    """Apply a different rotation to each generated sample, based on `sample_idx`."""
+    angles = [0, 15, -15, 30]
+    angle = angles[sample_idx % len(angles)]
+
+    torch_transform = transforms.Compose([
+        transforms.Resize((256, 256)),
+        transforms.Lambda(lambda img: transforms.functional.rotate(img, angle)),  # apply the per-sample rotation
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225],
+        ),
+    ])
+    return torch_transform(x)
+
+
+dataset = StreamingDataset(
+    data_dir,
+    cache_dir=str(cache_dir),
+    shuffle=False,
+    transform=[transform_fn],
+    sample_count=4,  # generate 4 transformed samples per input
+)
+```
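+
+With `sample_count=4`, the dataset exposes four entries for every stored item, so `len(dataset)` is four times the number of optimized items. Below is a minimal sketch of consuming such a dataset with `StreamingDataLoader`; it assumes the `dataset` defined above and RGB image inputs.
+
+```python
+from litdata import StreamingDataLoader
+
+# Each stored item yields 4 transformed samples, so the loader iterates
+# over 4x as many entries as the optimized dataset contains.
+dataloader = StreamingDataLoader(dataset, batch_size=8, num_workers=2)
+
+for batch in dataloader:
+    # With the transform above, `batch` is a float tensor of shape
+    # [batch_size, 3, 256, 256].
+    pass
+```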
+
+
+
✅ Split datasets for train, val, test 🔗
diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 1dca6c5b..9a049363 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -13,6 +13,7 @@
import logging
import os
+from inspect import signature
from time import time
from typing import Any, Callable, Optional, Union
@@ -62,6 +63,7 @@ def __init__(
index_path: Optional[str] = None,
force_override_state_dict: bool = False,
transform: Optional[Union[Callable, list[Callable]]] = None,
+ sample_count: int = 1,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -89,6 +91,7 @@ def __init__(
If `index_path` is a full file path, it will use that directly.
force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.
transform: Optional transformation function or list of functions to apply to each item in the dataset.
+            sample_count: Number of samples generated per item; if > 1, the transform must accept a `sample_idx` argument.
"""
_check_version_and_prompt_upgrade(__version__)
@@ -202,12 +205,27 @@ def __init__(
self.storage_options = storage_options
self.session_options = session_options
self.max_pre_download = max_pre_download
+ self.sample_count = sample_count
if transform is not None:
transform = transform if isinstance(transform, list) else [transform]
for t in transform:
if not callable(t):
raise ValueError(f"Transform should be a callable. Found {t}")
self.transform = transform
+
+        # Multi-sample mode requires exactly one transform that accepts a `sample_idx` parameter
+        invalid_transform = self.sample_count > 1 and (
+            not hasattr(self, "transform")
+            or len(self.transform) > 1
+            or "sample_idx" not in signature(self.transform[0]).parameters
+        )
+        if invalid_transform:
+            logger.warning(
+                "Invalid transform configuration detected. "
+                "`sample_count > 1` requires a single transform that accepts a `sample_idx` parameter. "
+                "Reverting `sample_count` to 1 and returning data as-is."
+            )
+            self.sample_count = 1
self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache
@property
@@ -282,7 +300,7 @@ def _create_shuffler(self, cache: Cache) -> Shuffle:
return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last)
def __len__(self) -> int:
- return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1)
+ return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1) * self.sample_count
def set_batch_size(self, batch_size: int) -> None:
self.batch_size = batch_size
@@ -324,7 +342,7 @@ def __iter__(self) -> "StreamingDataset":
self.worker_intervals = workers_intervals[worker_rank]
# The max number of samples to return from `__next__` (in worker)
- self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)
+ self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) * self.sample_count
# Handle restart
if self._state_dict:
@@ -407,7 +425,9 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any])
# replay the indexes for the current chunks
interval = self.worker_intervals[self.worker_next_chunk_index]
- current_indexes = np.arange(interval[1], interval[2])
+
+        # scale the interval bounds by sample_count for the multi-sample case
+ current_indexes = np.arange(interval[1] * self.sample_count, interval[2] * self.sample_count)
# re-shuffle the indexes
current_indexes = self.shuffler(
@@ -424,6 +444,17 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any])
self.worker_next_chunk_index += 1
def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
+        # Map the expanded multi-sample index back to the underlying item index and its sample_idx
+ if self.sample_count > 1:
+ if isinstance(index, int):
+ sample_idx = index % self.sample_count
+ index = index // self.sample_count
+ elif isinstance(index, ChunkedIndex):
+ sample_idx = index.index % self.sample_count
+ index.index = index.index // self.sample_count
+ else:
+ raise ValueError("Slices are not supported when using `sample_count > 1`.")
+
if self.cache is None:
self.worker_env = _WorkerEnv.detect()
self.cache = self._create_cache(worker_env=self.worker_env)
@@ -437,10 +468,11 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
_my_cache_indices = [ChunkedIndex(*self.cache._get_chunk_index_from_index(idx)) for idx in _my_indices]
return [self.cache[chnk_idx] for chnk_idx in _my_cache_indices]
item = self.cache[index]
+
if hasattr(self, "transform"):
if isinstance(self.transform, list):
for transform_fn in self.transform:
- item = transform_fn(item)
+ item = transform_fn(item) if self.sample_count == 1 else transform_fn(item, sample_idx)
else:
item = self.transform(item)
@@ -476,7 +508,9 @@ def __next__(self) -> Any:
# `next_worker_chunks_index` is the index of the chunk that we will be working on now
interval = self.worker_intervals[self.worker_next_chunk_index]
- current_indexes = np.arange(interval[1], interval[2])
+
+            # expand the interval bounds by sample_count for the multi-sample case
+ current_indexes = np.arange(interval[1] * self.sample_count, interval[2] * self.sample_count)
assert self.shuffler is not None
assert self.num_chunks is not None
diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index ddb517e4..2c2e9c92 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -1,3 +1,4 @@
+import logging
import os
import sys
@@ -496,3 +497,126 @@ def test_dataloader_dataset_transform_inheritance(tmpdir, shuffle):
# Verify that the transform is applied correctly
for i, item in enumerate(complete_data):
assert item == i * 2, f"Expected {i * 2}, got {item}"
+
+
+# Define a simple multi-sample transform function
+def multisample_transform_fn(x, sample_idx, *args, **kwargs):
+    """Multiply the input by the sample index."""
+    return x * sample_idx
+
+
+def test_dataloader_dataset_transform_multisample(tmpdir):
+    """Test that the multi-sample transform is applied correctly through the dataloader."""
+ # Create a simple dataset
+ # Create directories for cache and data
+ cache_dir = os.path.join(tmpdir, "cache_dir")
+ data_dir = os.path.join(tmpdir, "data_dir")
+ os.makedirs(cache_dir)
+ os.makedirs(data_dir)
+
+ # Create a dataset with 100 items, 20 items per chunk
+ cache = Cache(str(data_dir), chunk_size=20)
+ for i in range(100):
+ cache[i] = i
+ cache.done()
+ cache.merge()
+
+ dataset = StreamingDataset(
+ data_dir, cache_dir=str(cache_dir), shuffle=False, transform=multisample_transform_fn, sample_count=3
+ )
+ dataset_length = len(dataset)
+ assert dataset_length == 300
+
+ # ACT
+ dl = StreamingDataLoader(dataset, batch_size=10, num_workers=1, shuffle=False)
+
+ complete_data = []
+ for batch in dl:
+ complete_data.extend(batch)
+
+ # ASSERT
+ # Verify that the multisample transform is applied correctly
+ for i, item in enumerate(complete_data):
+ if i % 3 == 0:
+            assert item == (i // 3) * 0, f"Expected {(i // 3) * 0}, got {item}"
+        elif i % 3 == 1:
+            assert item == (i // 3) * 1, f"Expected {(i // 3) * 1}, got {item}"
+        else:
+            assert item == (i // 3) * 2, f"Expected {(i // 3) * 2}, got {item}"
+
+
+# Define simple transform functions
+def transform_fn_sq(x, sample_idx):
+    """Multiply the input by the sample index."""
+ return x * sample_idx
+
+
+def transform_fn_add(x, sample_idx):
+ """A simple transform function that adds the sample_idx to the input."""
+ return x + sample_idx
+
+
+def transform_fn_no_sample_idx(x):
+    """Return the input unchanged; deliberately missing the `sample_idx` parameter."""
+ return x
+
+
+def test_dataloader_dataset_transform_invalid_config(tmpdir, caplog):
+    """Test that invalid multi-sample transform configurations fall back to sample_count=1 with a warning."""
+ # Create a simple dataset
+ # Create directories for cache and data
+ cache_dir = os.path.join(tmpdir, "cache_dir")
+ data_dir = os.path.join(tmpdir, "data_dir")
+ os.makedirs(cache_dir)
+ os.makedirs(data_dir)
+
+ # Create a dataset with 100 items, 20 items per chunk
+ cache = Cache(str(data_dir), chunk_size=20)
+ for i in range(100):
+ cache[i] = i
+ cache.done()
+ cache.merge()
+
+ # Verify that logger warning happens when transform is not given
+ with caplog.at_level(logging.WARNING):
+ dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4)
+
+ assert "Invalid transform configuration detected." in caplog.text
+ dataset_length = len(dataset)
+ assert dataset_length == 100
+
+ # Verify that logger warning happens when multiple transforms are given
+ with caplog.at_level(logging.WARNING):
+ dataset = StreamingDataset(
+ data_dir,
+ cache_dir=str(cache_dir),
+ shuffle=False,
+ sample_count=4,
+ transform=[transform_fn_sq, transform_fn_add],
+ )
+
+ assert "Invalid transform configuration detected." in caplog.text
+ dataset_length = len(dataset)
+ assert dataset_length == 100
+
+ # Verify that logger warning happens when sample_idx parameter is missing
+ with caplog.at_level(logging.WARNING):
+ dataset = StreamingDataset(
+ data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4, transform=transform_fn_no_sample_idx
+ )
+
+ assert "Invalid transform configuration detected." in caplog.text
+ dataset_length = len(dataset)
+ assert dataset_length == 100
+
+ # ACT
+ dl = StreamingDataLoader(dataset, batch_size=10, num_workers=1, shuffle=False)
+
+ complete_data = []
+ for batch in dl:
+ complete_data.extend(batch)
+
+ # ASSERT
+ # Verify that the multisample transform is applied correctly
+    # Verify that data is returned unchanged when the multi-sample configuration is invalid
+ assert item == i, f"Expected {i}, got {item}"
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index c3cd4637..f5aea6fa 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -1813,3 +1813,106 @@ def transform(self, x, *args, **kwargs):
# Verify that the transform is applied correctly
for i, item in enumerate(complete_data):
assert item == i * 2, f"Expected {i * 2}, got {item}"
+
+
+def test_dataset_transform_multisample(tmpdir):
+    """Test that the multi-sample dataset transform is applied correctly."""
+ # Create a simple dataset
+ # Create directories for cache and data
+ cache_dir = os.path.join(tmpdir, "cache_dir")
+ data_dir = os.path.join(tmpdir, "data_dir")
+ os.makedirs(cache_dir)
+ os.makedirs(data_dir)
+
+ # Create a dataset with 100 items, 20 items per chunk
+ cache = Cache(str(data_dir), chunk_size=20)
+ for i in range(100):
+ cache[i] = i
+ cache.done()
+ cache.merge()
+
+ # Define simple transform functions
+ def transform_fn_sq(x, sample_idx):
+        """Multiply the input by the sample index."""
+ return x * sample_idx
+
+ sample_count = 3
+ dataset = StreamingDataset(
+ data_dir, cache_dir=str(cache_dir), shuffle=False, transform=transform_fn_sq, sample_count=sample_count
+ )
+ dataset_length = len(dataset)
+ assert dataset_length == 300
+
+ # ASSERT
+ # Verify that the transform functions are applied correctly
+ for i, item in enumerate(dataset):
+ assert item is not None
+ if i % sample_count == 0:
+ assert item == (i // sample_count) * 0, f"Expected {(i // sample_count) * 0}, got {item}"
+ elif i % sample_count == 1:
+ assert item == (i // sample_count) * 1, f"Expected {(i // sample_count) * 1}, got {item}"
+ else:
+ assert item == (i // sample_count) * 2, f"Expected {(i // sample_count) * 2}, got {item}"
+
+
+def test_dataset_transform_multisample_invalid_config(tmpdir, caplog):
+    """Test that invalid multi-sample transform configurations fall back to sample_count=1 with a warning."""
+ # Create a simple dataset
+ # Create directories for cache and data
+ cache_dir = os.path.join(tmpdir, "cache_dir")
+ data_dir = os.path.join(tmpdir, "data_dir")
+ os.makedirs(cache_dir)
+ os.makedirs(data_dir)
+
+ # Define simple transform functions
+ def transform_fn_sq(x, sample_idx):
+        """Multiply the input by the sample index."""
+ return x * sample_idx
+
+ def transform_fn_add(x, sample_idx):
+ """A simple transform function that adds the sample_idx to the input."""
+ return x + sample_idx
+
+ def transform_fn_no_sample_idx(x):
+        """A simple transform function that is missing the `sample_idx` parameter."""
+ return x
+
+ # Create a dataset with 100 items, 20 items per chunk
+ cache = Cache(str(data_dir), chunk_size=20)
+ for i in range(100):
+ cache[i] = i
+ cache.done()
+ cache.merge()
+
+ # ASSERT
+ # Verify that logger warning happens when transform is not given
+ with caplog.at_level(logging.WARNING):
+ dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4)
+
+ assert "Invalid transform configuration detected." in caplog.text
+ dataset_length = len(dataset)
+ assert dataset_length == 100
+
+ # Verify that logger warning happens when multiple transforms are given
+ with caplog.at_level(logging.WARNING):
+ dataset = StreamingDataset(
+ data_dir,
+ cache_dir=str(cache_dir),
+ shuffle=False,
+ sample_count=4,
+ transform=[transform_fn_sq, transform_fn_add],
+ )
+
+ assert "Invalid transform configuration detected." in caplog.text
+ dataset_length = len(dataset)
+ assert dataset_length == 100
+
+ # Verify that logger warning happens when sample_idx parameter is missing
+ with caplog.at_level(logging.WARNING):
+ dataset = StreamingDataset(
+ data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4, transform=transform_fn_no_sample_idx
+ )
+
+ assert "Invalid transform configuration detected." in caplog.text
+ dataset_length = len(dataset)
+ assert dataset_length == 100