diff --git a/README.md b/README.md index 0656710e..6ad54c68 100644 --- a/README.md +++ b/README.md @@ -1082,6 +1082,44 @@ dataset = StreamingDatasetWithTransform(data_dir, cache_dir=str(cache_dir), shuf + +
+ ✅ Multi-Sample Transform datasets while Streaming 🔗
+
+Sometimes you need to generate several variants of a single stored sample, each with a subtle variation (for example, different augmentations). The multi-sample feature lets you apply such a transformation while streaming, without storing any intermediate results.
+
+```python
+def transform_fn(x, sample_idx):
+    """Apply a different rotation to each generated sample, selected by `sample_idx`."""
+    angles = [0, 15, -15, 30]
+    angle = angles[sample_idx % len(angles)]
+
+    torch_transform = transforms.Compose([
+        transforms.Resize((256, 256)),
+        transforms.Lambda(lambda x: transforms.functional.rotate(x, angle)),  # apply rotation
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225],
+        ),
+    ])
+    return torch_transform(x)
+
+dataset = StreamingDataset(
+    data_dir,
+    cache_dir=str(cache_dir),
+    shuffle=False,
+    transform=[transform_fn],
+    sample_count=4,  # generate 4 transformed samples per input
+)
+```
+
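+With `sample_count=4`, the dataset reports four times as many items, and each stored item is served four times with `sample_idx` running from 0 to 3. A minimal sketch of how this behaves (assuming the `dataset` built above and the `StreamingDataLoader` shown elsewhere in this README):
+
+```python
+from litdata import StreamingDataLoader
+
+# The reported length is the number of stored items multiplied by sample_count.
+print(len(dataset))
+
+loader = StreamingDataLoader(dataset, batch_size=4)
+for batch in loader:
+    # With shuffle=False, each batch of 4 holds the four rotated variants
+    # of a single stored image.
+    pass
+```
+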
+
✅ Split datasets for train, val, test 🔗 diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 1dca6c5b..9a049363 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -13,6 +13,7 @@ import logging import os +from inspect import signature from time import time from typing import Any, Callable, Optional, Union @@ -62,6 +63,7 @@ def __init__( index_path: Optional[str] = None, force_override_state_dict: bool = False, transform: Optional[Union[Callable, list[Callable]]] = None, + sample_count: int = 1, ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. @@ -89,6 +91,7 @@ def __init__( If `index_path` is a full file path, it will use that directly. force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict. transform: Optional transformation function or list of functions to apply to each item in the dataset. + sample_count: Number of samples to return for each index access. """ _check_version_and_prompt_upgrade(__version__) @@ -202,12 +205,27 @@ def __init__( self.storage_options = storage_options self.session_options = session_options self.max_pre_download = max_pre_download + self.sample_count = sample_count if transform is not None: transform = transform if isinstance(transform, list) else [transform] for t in transform: if not callable(t): raise ValueError(f"Transform should be a callable. Found {t}") self.transform = transform + + # define invalid transform conditions for multisample case + invalid_transform = self.sample_count > 1 and ( + not hasattr(self, "transform") + or len(self.transform) > 1 + or "sample_idx" not in signature(self.transform[0]).parameters + ) + if invalid_transform: + logger.warning( + "Invalid transform configuration detected. " + "Either no transform, multiple transforms, or missing `sample_idx` parameter. " + "Reverting `sample_count` to 1 and returning data as-is." 
+ ) + self.sample_count = 1 self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache @property @@ -282,7 +300,7 @@ def _create_shuffler(self, cache: Cache) -> Shuffle: return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last) def __len__(self) -> int: - return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1) + return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1) * self.sample_count def set_batch_size(self, batch_size: int) -> None: self.batch_size = batch_size @@ -324,7 +342,7 @@ def __iter__(self) -> "StreamingDataset": self.worker_intervals = workers_intervals[worker_rank] # The max number of samples to return from `__next__` (in worker) - self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) + self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) * self.sample_count # Handle restart if self._state_dict: @@ -407,7 +425,9 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any]) # replay the indexes for the current chunks interval = self.worker_intervals[self.worker_next_chunk_index] - current_indexes = np.arange(interval[1], interval[2]) + + # multiply the interval by the sample_count for multisample case + current_indexes = np.arange(interval[1] * self.sample_count, interval[2] * self.sample_count) # re-shuffle the indexes current_indexes = self.shuffler( @@ -424,6 +444,17 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any]) self.worker_next_chunk_index += 1 def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: + # Deflate index for multisample case + if self.sample_count > 1: + if isinstance(index, int): + sample_idx = index % self.sample_count + index = index // self.sample_count + elif isinstance(index, ChunkedIndex): + sample_idx = index.index % self.sample_count + index.index = index.index // self.sample_count + else: + raise ValueError("Slices are not supported when using `sample_count > 1`.") + if self.cache is None: self.worker_env = _WorkerEnv.detect() self.cache = self._create_cache(worker_env=self.worker_env) @@ -437,10 +468,11 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any: _my_cache_indices = [ChunkedIndex(*self.cache._get_chunk_index_from_index(idx)) for idx in _my_indices] return [self.cache[chnk_idx] for chnk_idx in _my_cache_indices] item = self.cache[index] + if hasattr(self, "transform"): if isinstance(self.transform, list): for transform_fn in self.transform: - item = transform_fn(item) + item = transform_fn(item) if self.sample_count == 1 else transform_fn(item, sample_idx) else: item = self.transform(item) @@ -476,7 +508,9 @@ def __next__(self) -> Any: # `next_worker_chunks_index` is the index of the chunk that we will be working on now interval = self.worker_intervals[self.worker_next_chunk_index] - current_indexes = np.arange(interval[1], interval[2]) + + # current_indexes = np.arange(interval[1] * self.multisample_factor, interval[2] * self.multisample_factor) + current_indexes = np.arange(interval[1] * self.sample_count, interval[2] * self.sample_count) assert self.shuffler is not None assert self.num_chunks is not None diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index ddb517e4..2c2e9c92 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -1,3 +1,4 @@ +import logging import 
os import sys @@ -496,3 +497,126 @@ def test_dataloader_dataset_transform_inheritance(tmpdir, shuffle): # Verify that the transform is applied correctly for i, item in enumerate(complete_data): assert item == i * 2, f"Expected {i * 2}, got {item}" + + +# Define a simple transform function +def multisample_transform_fn(x, sample_idx, *args, **kwargs): + """A simple transform function that doubles the input.""" + return x * sample_idx + + +def test_dataloader_dataset_transform_multisample(tmpdir): + """Test if the dataset's transform is applied correctly with dataloader.""" + # Create a simple dataset + # Create directories for cache and data + cache_dir = os.path.join(tmpdir, "cache_dir") + data_dir = os.path.join(tmpdir, "data_dir") + os.makedirs(cache_dir) + os.makedirs(data_dir) + + # Create a dataset with 100 items, 20 items per chunk + cache = Cache(str(data_dir), chunk_size=20) + for i in range(100): + cache[i] = i + cache.done() + cache.merge() + + dataset = StreamingDataset( + data_dir, cache_dir=str(cache_dir), shuffle=False, transform=multisample_transform_fn, sample_count=3 + ) + dataset_length = len(dataset) + assert dataset_length == 300 + + # ACT + dl = StreamingDataLoader(dataset, batch_size=10, num_workers=1, shuffle=False) + + complete_data = [] + for batch in dl: + complete_data.extend(batch) + + # ASSERT + # Verify that the multisample transform is applied correctly + for i, item in enumerate(complete_data): + if i % 3 == 0: + assert item == (i // 3) * 0, f"Expected {i * 0}, got {item}" + elif i % 3 == 1: + assert item == (i // 3) * 1, f"Expected {i * 1}, got {item}" + else: + assert item == (i // 3) * 2, f"Expected {i * 2}, got {item}" + + +# Define simple transform functions +def transform_fn_sq(x, sample_idx): + """A simple transform function that doubles the input.""" + return x * sample_idx + + +def transform_fn_add(x, sample_idx): + """A simple transform function that adds the sample_idx to the input.""" + return x + sample_idx + + +def transform_fn_no_sample_idx(x): + """A simple transform function that doubles the input.""" + return x + + +def test_dataloader_dataset_transform_invalid_config(tmpdir, caplog): + """Test if the dataset's transform is applied correctly with dataloader.""" + # Create a simple dataset + # Create directories for cache and data + cache_dir = os.path.join(tmpdir, "cache_dir") + data_dir = os.path.join(tmpdir, "data_dir") + os.makedirs(cache_dir) + os.makedirs(data_dir) + + # Create a dataset with 100 items, 20 items per chunk + cache = Cache(str(data_dir), chunk_size=20) + for i in range(100): + cache[i] = i + cache.done() + cache.merge() + + # Verify that logger warning happens when transform is not given + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4) + + assert "Invalid transform configuration detected." in caplog.text + dataset_length = len(dataset) + assert dataset_length == 100 + + # Verify that logger warning happens when multiple transforms are given + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset( + data_dir, + cache_dir=str(cache_dir), + shuffle=False, + sample_count=4, + transform=[transform_fn_sq, transform_fn_add], + ) + + assert "Invalid transform configuration detected." 
in caplog.text + dataset_length = len(dataset) + assert dataset_length == 100 + + # Verify that logger warning happens when sample_idx parameter is missing + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset( + data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4, transform=transform_fn_no_sample_idx + ) + + assert "Invalid transform configuration detected." in caplog.text + dataset_length = len(dataset) + assert dataset_length == 100 + + # ACT + dl = StreamingDataLoader(dataset, batch_size=10, num_workers=1, shuffle=False) + + complete_data = [] + for batch in dl: + complete_data.extend(batch) + + # ASSERT + # Verify that the multisample transform is applied correctly + for i, item in enumerate(complete_data): + assert item == i, f"Expected {i}, got {item}" diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index c3cd4637..f5aea6fa 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -1813,3 +1813,106 @@ def transform(self, x, *args, **kwargs): # Verify that the transform is applied correctly for i, item in enumerate(complete_data): assert item == i * 2, f"Expected {i * 2}, got {item}" + + +def test_dataset_transform_multisample(tmpdir): + """Test if the dataset transform is applied correctly.""" + # Create a simple dataset + # Create directories for cache and data + cache_dir = os.path.join(tmpdir, "cache_dir") + data_dir = os.path.join(tmpdir, "data_dir") + os.makedirs(cache_dir) + os.makedirs(data_dir) + + # Create a dataset with 100 items, 20 items per chunk + cache = Cache(str(data_dir), chunk_size=20) + for i in range(100): + cache[i] = i + cache.done() + cache.merge() + + # Define simple transform functions + def transform_fn_sq(x, sample_idx): + """A simple transform function that doubles the input.""" + return x * sample_idx + + sample_count = 3 + dataset = StreamingDataset( + data_dir, cache_dir=str(cache_dir), shuffle=False, transform=transform_fn_sq, sample_count=sample_count + ) + dataset_length = len(dataset) + assert dataset_length == 300 + + # ASSERT + # Verify that the transform functions are applied correctly + for i, item in enumerate(dataset): + assert item is not None + if i % sample_count == 0: + assert item == (i // sample_count) * 0, f"Expected {(i // sample_count) * 0}, got {item}" + elif i % sample_count == 1: + assert item == (i // sample_count) * 1, f"Expected {(i // sample_count) * 1}, got {item}" + else: + assert item == (i // sample_count) * 2, f"Expected {(i // sample_count) * 2}, got {item}" + + +def test_dataset_transform_multisample_invalid_config(tmpdir, caplog): + """Test if the dataset raises an error when is_multisample is True but transform is not a list.""" + # Create a simple dataset + # Create directories for cache and data + cache_dir = os.path.join(tmpdir, "cache_dir") + data_dir = os.path.join(tmpdir, "data_dir") + os.makedirs(cache_dir) + os.makedirs(data_dir) + + # Define simple transform functions + def transform_fn_sq(x, sample_idx): + """A simple transform function that doubles the input.""" + return x * sample_idx + + def transform_fn_add(x, sample_idx): + """A simple transform function that adds the sample_idx to the input.""" + return x + sample_idx + + def transform_fn_no_sample_idx(x): + """A simple transform function that misses the sample_idx parameter.""" + return x + + # Create a dataset with 100 items, 20 items per chunk + cache = Cache(str(data_dir), chunk_size=20) + for i in range(100): + cache[i] = i + cache.done() + cache.merge() + + 
# ASSERT + # Verify that logger warning happens when transform is not given + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4) + + assert "Invalid transform configuration detected." in caplog.text + dataset_length = len(dataset) + assert dataset_length == 100 + + # Verify that logger warning happens when multiple transforms are given + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset( + data_dir, + cache_dir=str(cache_dir), + shuffle=False, + sample_count=4, + transform=[transform_fn_sq, transform_fn_add], + ) + + assert "Invalid transform configuration detected." in caplog.text + dataset_length = len(dataset) + assert dataset_length == 100 + + # Verify that logger warning happens when sample_idx parameter is missing + with caplog.at_level(logging.WARNING): + dataset = StreamingDataset( + data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4, transform=transform_fn_no_sample_idx + ) + + assert "Invalid transform configuration detected." in caplog.text + dataset_length = len(dataset) + assert dataset_length == 100
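
For reference, the core arithmetic this patch introduces: `__len__` multiplies the base length by `sample_count`, and `__getitem__` deflates an expanded index into a base item index plus a `sample_idx`. A minimal standalone sketch of that mapping (the `deflate` helper below is hypothetical, for illustration only):

```python
sample_count = 3
base_len = 100  # items stored in the cache, as in the tests above

def deflate(expanded_index: int) -> tuple[int, int]:
    # Same mapping as __getitem__: which stored item to load,
    # and which of the sample_count variants to produce from it.
    return expanded_index // sample_count, expanded_index % sample_count

assert base_len * sample_count == 300  # what len(dataset) reports
assert deflate(0) == (0, 0)
assert deflate(1) == (0, 1)
assert deflate(2) == (0, 2)
assert deflate(3) == (1, 0)
```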