diff --git a/.gitignore b/.gitignore
index 954f6dfb..c60a7810 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,7 +3,7 @@ src
__pycache__/
*.py[cod]
*$py.class
-
+outputs/
.vscode
# C extensions
diff --git a/.gitmodules b/.gitmodules
index 21f138b4..e69de29b 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +0,0 @@
-[submodule "examples/imagenet-example"]
- path = examples/imagenet-example
- url = git@github.com:libffcv/ffcv-imagenet.git
diff --git a/.nojekyll b/.nojekyll
deleted file mode 100644
index e69de29b..00000000
diff --git a/README.md b/README.md
index 74c403a3..1b6a4b5b 100644
--- a/README.md
+++ b/README.md
@@ -1,133 +1,23 @@
-
-Fast Forward Computer Vision: train models at a fraction of the cost with accelerated data loading!
-
-
+# Fast Forward Computer Vision for Pretraining
+
[install]
-[quickstart]
-[features]
+[new features]
[docs]
-[support slack]
-[homepage]
[paper]
-
-Maintainers:
-Guillaume Leclerc,
-Andrew Ilyas and
-Logan Engstrom
-`ffcv` is a drop-in data loading system that dramatically increases data throughput in model training:
-
-- [Train an ImageNet model](#prepackaged-computer-vision-benchmarks)
-on one GPU in 35 minutes (98¢/model on AWS)
-- [Train a CIFAR-10 model](https://docs.ffcv.io/ffcv_examples/cifar10.html)
-on one GPU in 36 seconds (2¢/model on AWS)
-- Train a `$YOUR_DATASET` model `$REALLY_FAST` (for `$WAY_LESS`)
-
-Keep your training algorithm the same, just replace the data loader! Look at these speedups:
-
-
-
-`ffcv` also comes prepacked with [fast, simple code](https://github.com/libffcv/imagenet-example) for [standard vision benchmarks]((https://docs.ffcv.io/benchmarks.html)):
-
-
+This library is derived from [FFCV](https://github.com/libffcv/ffcv); it reduces memory usage and accelerates data loading.
## Installation
-### Linux
+### Environment Setup
```
-conda create -y -n ffcv python=3.9 cupy pkg-config libjpeg-turbo opencv pytorch torchvision cudatoolkit=11.3 numba -c pytorch -c conda-forge
+conda create -y -n ffcv "python>=3.9" cupy pkg-config "libjpeg-turbo>=3.0.0" "opencv=4.10.0" numba -c conda-forge
conda activate ffcv
-pip install ffcv
-```
-Troubleshooting note 1: if the above commands result in a package conflict error, try running ``conda config --env --set channel_priority flexible`` in the environment and rerunning the installation command.
-
-Troubleshooting note 2: on some systems (but rarely), you'll need to add the ``compilers`` package to the first command above.
-
-Troubleshooting note 3: courtesy of @kschuerholt, here is a [Dockerfile](https://github.com/kschuerholt/pytorch_cuda_opencv_ffcv_docker) that may help with conda-free installation
-
-### Windows
-* Install opencv4
- * Add `..../opencv/build/x64/vc15/bin` to PATH environment variable
-* Install libjpeg-turbo, download libjpeg-turbo-x.x.x-vc64.exe, not gcc64
- * Add `..../libjpeg-turbo64/bin` to PATH environment variable
-* Install pthread, download last release.zip
- * After unzip, rename Pre-build.2 folder to pthread
- * Open `pthread/include/pthread.h`, and add the code below to the top of the file.
- ```cpp
- #define HAVE_STRUCT_TIMESPEC
- ```
- * Add `..../pthread/dll` to PATH environment variable
-* Install cupy depending on your CUDA Toolkit version.
-* `pip install ffcv`
-
-## Citation
-If you use FFCV, please cite it as:
-
-```
-@inproceedings{leclerc2023ffcv,
- author = {Guillaume Leclerc and Andrew Ilyas and Logan Engstrom and Sung Min Park and Hadi Salman and Aleksander Madry},
- title = {{FFCV}: Accelerating Training by Removing Data Bottlenecks},
- year = {2023},
- booktitle = {Computer Vision and Pattern Recognition (CVPR)},
- note = {\url{https://github.com/libffcv/ffcv/}. commit xxxxxxx}
-}
-```
-(Make sure to replace xxxxxxx above with the hash of the commit used!)
-
-## Quickstart
-Accelerate *any* learning system with `ffcv`.
-First,
-convert your dataset into `ffcv` format (`ffcv` converts both indexed PyTorch datasets and
-WebDatasets):
-```python
-from ffcv.writer import DatasetWriter
-from ffcv.fields import RGBImageField, IntField
-
-# Your dataset (`torch.utils.data.Dataset`) of (image, label) pairs
-my_dataset = make_my_dataset()
-write_path = '/output/path/for/converted/ds.beton'
-
-# Pass a type for each data field
-writer = DatasetWriter(write_path, {
- # Tune options to optimize dataset size, throughput at train-time
- 'image': RGBImageField(max_resolution=256),
- 'label': IntField()
-})
-
-# Write dataset
-writer.from_indexed_dataset(my_dataset)
+conda install pytorch-cuda=11.3 torchvision -c pytorch -c nvidia
+pip install .
```
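+
+To verify the installation (a quick sanity check; it assumes the package exposes `__version__`, as used in `examples/benchmark.py`):
+```
+python -c "import ffcv; print(ffcv.__version__)"
+```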
-Then replace your old loader with the `ffcv` loader at train time (in PyTorch,
-no other changes required!):
-```python
-from ffcv.loader import Loader, OrderOption
-from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Cutout
-from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder
-
-# Random resized crop
-decoder = RandomResizedCropRGBImageDecoder((224, 224))
-
-# Data decoding and augmentation
-image_pipeline = [decoder, Cutout(), ToTensor(), ToTorchImage(), ToDevice(0)]
-label_pipeline = [IntDecoder(), ToTensor(), ToDevice(0)]
-
-# Pipeline for each data field
-pipelines = {
- 'image': image_pipeline,
- 'label': label_pipeline
-}
-
-# Replaces PyTorch data loader (`torch.utils.data.Dataloader`)
-loader = Loader(write_path, batch_size=bs, num_workers=num_workers,
- order=OrderOption.RANDOM, pipelines=pipelines)
-
-# rest of training / validation proceeds identically
-for epoch in range(epochs):
- ...
-```
-[See here](https://docs.ffcv.io/basics.html) for a more detailed guide to deploying `ffcv` for your dataset.
## Prepackaged Computer Vision Benchmarks
From gridding to benchmarking to fast research iteration, there are many reasons
@@ -135,12 +25,27 @@ to want faster model training. Below we present premade codebases for training
on ImageNet and CIFAR, including both (a) extensible codebases and (b)
numerous premade training configurations.
+## Make Dataset
+We provide a script, `examples/write_dataset.py`, to build the dataset. It supports five write modes:
+- `jpg`: compress all images to JPEG.
+- `png`: compress all images to PNG. This is lossless but slow to decode.
+- `raw`: store all images uncompressed.
+- `smart`: compress only the images whose raw size exceeds `threshold`; see the sketch after the example command below.
+- `proportion`: compress a random subset of the images, with the fraction given by `compress_probability`.
+
+```
+python examples/write_dataset.py --cfg.write_mode=smart --cfg.threshold=206432 --cfg.jpeg_quality=90 \
+ --cfg.num_workers=40 --cfg.max_resolution=500 \
+ --cfg.data_dir=$IMAGENET_DIR/train \
+ --cfg.write_path=$write_path
+```
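+
+For reference, a minimal sketch of the `smart` decision rule (as we read it from `RGBImageField.encode` in this repo: an image is stored raw unless its raw size exceeds `threshold` bytes, in which case it is JPEG-compressed):
+
+```python
+import numpy as np
+
+def smart_write_mode(image: np.ndarray, threshold: int) -> str:
+    # store small images raw; JPEG-compress once the raw size crosses the threshold
+    return 'jpg' if image.nbytes > threshold else 'raw'
+```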
### ImageNet
We provide a self-contained script for training ImageNet fast.
Above we plot the training time versus
accuracy frontier, and the dataloading speeds, for 1-GPU ResNet-18 and 8-GPU
ResNet-50 alongside a few baselines.
+TODO:
| Link to Config | top_1 | top_5 | # Epochs | Time (mins) | Architecture | Setup |
|:---------------------------------------------------------------------------------------------------------------------------------------|--------:|--------:|-----------:|--------------:|:---------------|:---------|
@@ -167,69 +72,36 @@ potential to raise the accuracy even further). You can find the training script
here.
## Features
-
-
-Computer vision or not, FFCV can help make training faster in a variety of
-resource-constrained settings!
-Our performance guide
-has a more detailed account of the ways in which FFCV can adapt to different
-performance bottlenecks.
-
-
-- **Plug-and-play with any existing training code**: Rather than changing
- aspects of model training itself, FFCV focuses on removing *data bottlenecks*,
- which turn out to be a problem everywhere from neural network training to
- linear regression. This means that:
-
- - FFCV can be introduced into any existing training code in just a few
- lines of code (e.g., just swapping out the data loader and optionally the
- augmentation pipeline);
- - You don't have to change the model itself to make it faster (e.g., feel
- free to analyze models *without* CutMix, Dropout, momentum scheduling, etc.);
- - FFCV can speed up a lot more beyond just neural network training---in
- fact, the more data-bottlenecked the application (e.g., linear regression,
- bulk inference, etc.), the faster FFCV will make it!
-
- See our [Getting started](https://docs.ffcv.io/basics.html) guide,
- [Example walkthroughs](https://docs.ffcv.io/examples.html), and
- [Code examples](https://github.com/libffcv/ffcv/tree/main/examples)
- to see how easy it is to get started!
-- **Fast data processing without the pain**: FFCV automatically handles data
- reading, pre-fetching, caching, and transfer between devices in an extremely
- efficiently way, so that users don't have to think about it.
-- **Automatically fused-and-compiled data processing**: By either using
- [pre-written](https://docs.ffcv.io/api/transforms.html) FFCV transformations
- or
- [easily writing custom ones](https://docs.ffcv.io/ffcv_examples/custom_transforms.html),
- users can
- take advantage of FFCV's compilation and pipelining abilities, which will
- automatically fuse and compile simple Python augmentations to machine code
- using [Numba](https://numba.pydata.org), and schedule them asynchronously to avoid
- loading delays.
-- **Load data fast from RAM, SSD, or networked disk**: FFCV exposes
- user-friendly options that can be adjusted based on the resources
- available. For example, if a dataset fits into memory, FFCV can cache it
- at the OS level and ensure that multiple concurrent processes all get fast
- data access. Otherwise, FFCV can use fast process-level caching and will
- optimize data loading to minimize the underlying number of disk reads. See
- [The Bottleneck Doctor](https://docs.ffcv.io/bottleneck_doctor.html)
- guide for more information.
-- **Training multiple models per GPU**: Thanks to fully asynchronous
- thread-based data loading, you can now interleave training multiple models on
- the same GPU efficiently, without any data-loading overhead. See
- [this guide](https://docs.ffcv.io/parameter_tuning.html) for more info.
-- **Dedicated tools for image handling**: All the features above work are
- equally applicable to all sorts of machine learning models, but FFCV also
- offers some vision-specific features, such as fast JPEG encoding and decoding,
- storing datasets as mixtures of raw and compressed images to trade off I/O
- overhead and compute overhead, etc. See the
- [Working with images](https://docs.ffcv.io/working_with_images.html) guide for
- more information.
-
-# Contributors
-
-- [Guillaume Leclerc](https://github.com/GuillaumeLeclerc)
-- [Logan Engstrom](http://loganengstrom.com/)
-- [Andrew Ilyas](http://andrewilyas.com/)
-- [Sam Park](http://sungminpark.com/)
-- [Hadi Salman](http://hadisalman.com/)
+
+Compared to the original FFCV, this library adds the following features (see the usage sketch after this list):
+
+- **crop decode**: the `RandomResizedCrop` and `CenterCrop` decoders now decode only the crop region of JPEG images, which saves memory and accelerates decoding.
+
+- **cache strategy**: the OS page cache can be swapped out under memory pressure, so the caching strategy is now configurable via the `FFCV_DEFAULT_CACHE_PROCESS` environment variable (or the `cache_type` argument of `Loader`). The choices are:
+  - `0`: OS cache
+  - `1`: process cache
+  - `2`: shared memory
+
+- **lossless compression**: PNG is supported for lossless compression; use `RGBImageField(write_mode='png')` to enable it.
+
+- **lower memory usage**: memory usage is optimized throughout, which also accelerates data loading.
+
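+A minimal usage sketch of the new options (keyword names follow the examples in this repo: `cache_type` on `Loader`, `write_mode='png'` on `RGBImageField`; paths are placeholders):
+
+```python
+import os
+# 0: OS cache, 1: process cache, 2: shared memory
+os.environ['FFCV_DEFAULT_CACHE_PROCESS'] = '2'
+
+from ffcv import Loader
+from ffcv.fields import RGBImageField, IntField
+from ffcv.writer import DatasetWriter
+
+# write a dataset with lossless PNG storage
+writer = DatasetWriter('/path/to/ds.ffcv', {
+    'image': RGBImageField(write_mode='png'),
+    'label': IntField(),
+})
+# then: writer.from_indexed_dataset(my_dataset)
+
+# per-loader override of the cache strategy
+loader = Loader('/path/to/ds.ffcv', batch_size=64, num_workers=10, cache_type=2)
+```
+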
+Throughput comparison (samples/s; higher is better):
+
+| img\_size    |     112 |     160 |     192 |     224 |     224 |     224 |     224 |     224 |    512 |
+|--------------|--------:|--------:|--------:|--------:|--------:|--------:|--------:|--------:|-------:|
+| batch\_size  |     512 |     512 |     512 |     128 |     256 |     512 |     512 |     512 |    512 |
+| num\_workers |      10 |      10 |      10 |      10 |      10 |       5 |      10 |      20 |     10 |
+| ours         | 23024.0 | 19396.5 | 16503.6 | 16536.1 | 16338.5 | 12369.7 | 14521.4 | 14854.6 | 4260.3 |
+| ffcv         | 16853.2 | 13906.3 | 13598.4 | 12192.7 | 11960.2 |  9112.7 | 12539.4 | 12601.8 | 3577.8 |
+
+Memory usage comparison (GB; lower is better):
+
+| img\_size    |  112 |  160 |  192 |  224 |  224 |  224 |  224 |  224 |  512 |
+|--------------|-----:|-----:|-----:|-----:|-----:|-----:|-----:|-----:|-----:|
+| batch\_size  |  512 |  512 |  512 |  128 |  256 |  512 |  512 |  512 |  512 |
+| num\_workers |   10 |   10 |   10 |   10 |   10 |    5 |   10 |   20 |   10 |
+| ours         |  9.0 |  9.8 | 11.4 |  5.8 |  7.7 | 11.4 | 11.4 | 11.4 | 34.0 |
+| ffcv         | 13.4 | 14.8 | 17.7 |  7.6 | 11.0 | 17.7 | 17.7 | 17.7 | 56.6 |
+
diff --git a/examples/benchmark.py b/examples/benchmark.py
new file mode 100644
index 00000000..a8658945
--- /dev/null
+++ b/examples/benchmark.py
@@ -0,0 +1,263 @@
+import argparse
+import builtins
+import datetime
+import json
+import math
+import os
+import sys
+import time
+from pathlib import Path
+
+from ffcv.loader import Loader, OrderOption
+import gin
+import numpy as np
+import timm
+import torch.backends.cudnn as cudnn
+from PIL import Image # a trick to solve loading lib problem
+from tqdm import tqdm
+
+assert timm.__version__ >= "0.6.12" # version check
+from torchvision import datasets
+import ffcv
+
+from psutil import Process, net_io_counters
+import socket
+from os import getpid
+
+from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage, RandomHorizontalFlip, View, Convert
+from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder, CenterCropRGBImageDecoder
+
+import torch
+
+IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
+IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
+
+@gin.configurable
+def SimplePipeline(img_size=224, scale=(0.2, 1), ratio=(3.0/4.0, 4.0/3.0), device='cuda'):
+    device = torch.device(device)
+    image_pipeline = [
+        RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale, ratio=ratio),
+        RandomHorizontalFlip(),
+        NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32),
+        ToTensor(), ToTorchImage(),
+    ]
+    label_pipeline = [IntDecoder(), ToTensor(), ToDevice(device), View(-1)]
+ # Pipeline for each data field
+ pipelines = {
+ 'image': image_pipeline,
+ 'label': label_pipeline
+ }
+ return pipelines
+
+
+def get_args_parser():
+ parser = argparse.ArgumentParser('Data loading benchmark', add_help=False)
+ parser.add_argument('--batch_size', default=64, type=int,
+ help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
+ parser.add_argument('--epochs', default=5, type=int)
+ parser.add_argument('--img_size', default=224,type=int)
+
+ # Dataset parameters
+ parser.add_argument('--data_set', default='ffcv')
+ parser.add_argument("--cache_type",type=int, default=0,)
+ parser.add_argument('--data_path', default=os.getenv("IMAGENET_DIR"), type=str,
+ help='dataset path')
+
+ parser.add_argument('--output_dir', default=None, type=str,
+ help='path where to save, empty for no saving')
+
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument('--seed', default=0, type=int)
+
+ parser.add_argument('--num_workers', default=10, type=int)
+ parser.add_argument('--pin_mem', action='store_true',
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
+ parser.set_defaults(pin_mem=True)
+
+ # distributed training parameters
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--local-rank','--local_rank', default=-1, type=int)
+ parser.add_argument('--dist_on_itp', action='store_true')
+ parser.add_argument('--dist_url', default='env://',
+ help='url used to set up distributed training')
+
+ return parser
+
+
+class ramqdm(tqdm):
+ """tqdm progress bar that reports RAM usage with each update"""
+ _empty_desc = "using ? GB RAM; ? CPU ? IO"
+ _desc = "{:.2f} GB RAM; {:.2f} % CPU {:.2f} MB IO"
+ _GB = 10**9
+ """"""
+ def __init__(self, *args, **kwargs):
+ """Override desc and get reference to current process"""
+ if "desc" in kwargs:
+ # prepend desc to the reporter mask:
+ self._empty_desc = kwargs["desc"] + " " + self._empty_desc
+ self._desc = kwargs["desc"] + " " + self._desc
+ del kwargs["desc"]
+ else:
+ # nothing to prepend, reporter mask is at start of sentence:
+ self._empty_desc = self._empty_desc.capitalize()
+ self._desc = self._desc.capitalize()
+ super().__init__(*args, desc=self._empty_desc, **kwargs)
+ self._process = Process(getpid())
+ self.metrics = []
+ """"""
+ def update(self, n=1):
+ """Calculate RAM usage and update progress bar"""
+ rss = self._process.memory_info().rss
+ ps = self._process.cpu_percent()
+ io_counters = self._process.io_counters().read_bytes
+ # net_io = net_io_counters().bytes_recv
+ # io_counters += net_io
+
+ current_desc = self._desc.format(rss/self._GB, ps, io_counters/1e6)
+ self.set_description(current_desc)
+ self.metrics.append({'mem':rss/self._GB, 'cpu':ps, 'io':io_counters/1e6})
+ super().update(n)
+
+ def summary(self):
+ res = {}
+ for key in self.metrics[0].keys():
+ res[key] = np.mean([i[key] for i in self.metrics])
+ return res
+
+@gin.configurable(denylist=["args"])
+def build_dataset(args, transform_fn=SimplePipeline):
+ transform_train = transform_fn(img_size=args.img_size)
+ if args.data_set == 'IF':
+ # simple augmentation
+ dataset_train = datasets.ImageFolder(args.data_path, transform=transform_train)
+ elif args.data_set == 'cifar10':
+ dataset_train = datasets.CIFAR10(args.data_path, transform=transform_train)
+ elif args.data_set == 'ffcv':
+ order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM
+        dataset_train = Loader(args.data_path, pipelines=transform_train,
+                               batch_size=args.batch_size, num_workers=args.num_workers,
+                               batches_ahead=4, cache_type=args.cache_type,
+                               order=order, distributed=args.distributed, seed=args.seed, drop_last=True)
+ else:
+ raise ValueError("Wrong dataset: ", args.data_set)
+ return dataset_train
+
+def load_one_epoch(args,loader):
+ start = time.time()
+ l=ramqdm(loader,disable=args.rank>0)
+
+ for x1,y in l:
+ x1.mean()
+ torch.cuda.synchronize()
+
+ end = time.time()
+
+ if args.rank ==0:
+ res = l.summary()
+ throughput=loader.reader.num_samples/(end-start)
+ res['throughput'] = throughput
+ return res
+
+
+def main(args):
+ init_distributed_mode(args)
+
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+ print("{}".format(args).replace(', ', ',\n'))
+
+ cudnn.benchmark = True
+
+ # build dataset
+ dataset_train = build_dataset(args)
+
+ num_tasks = args.world_size
+ global_rank = args.rank
+ if args.data_set != "ffcv":
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ print("Sampler_train = %s" % str(sampler_train))
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=True,
+ )
+ else:
+ data_loader_train = dataset_train
+
+ for epoch in range(args.epochs):
+ res = load_one_epoch(args,data_loader_train)
+ if res:
+ throughput = res['throughput']
+ print(f"Throughput: {throughput:.2f} samples/s for {args.data_path}.")
+ res.update(args.__dict__)
+ res['version'] = ffcv.__version__
+ res['hostname'] = socket.gethostname()
+ res['epoch'] = epoch
+ if args.output_dir:
+ with open(os.path.join(args.output_dir,"data_loading.txt"),"a") as file:
+ file.write(json.dumps(res)+"\n")
+
+
+def init_distributed_mode(args):
+ if hasattr(args,'dist_on_itp') and args.dist_on_itp:
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print('Not using distributed mode')
+        setup_for_distributed(is_master=True)  # hack
+        args.distributed = False
+        args.rank = 0  # later code reads args.rank even in the non-distributed case
+        args.world_size = 1
+        return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}, gpu {}'.format(
+ args.rank, args.dist_url, args.gpu), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ builtin_print = builtins.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ now = datetime.datetime.now().time()
+ builtin_print('[{}] '.format(now), *args, **kwargs) # print with time stamp
+
+ builtins.print = print
+
+
+
+if __name__ == '__main__':
+ parser = get_args_parser()
+ args = parser.parse_args()
+ main(args)
diff --git a/examples/docs_examples/linear_regression.py b/examples/docs_examples/linear_regression.py
index f9a6e81c..5431448a 100644
--- a/examples/docs_examples/linear_regression.py
+++ b/examples/docs_examples/linear_regression.py
@@ -57,7 +57,7 @@ def __len__(self):
train_loader = DataLoader(dataset, batch_size=2048, num_workers=8, shuffle=True)
else:
train_loader = Loader('/tmp/linreg_data.beton', batch_size=2048,
- num_workers=8, order=OrderOption.QUASI_RANDOM, os_cache=False,
+ num_workers=8, order=OrderOption.QUASI_RANDOM, cache_type=1,
pipelines={
'covariate': [NDArrayDecoder(), ToTensor(), ToDevice(ch.device('cuda:0'))],
'label': [NDArrayDecoder(), ToTensor(), Squeeze(), ToDevice(ch.device('cuda:0'))]
diff --git a/examples/imagenet-example b/examples/imagenet-example
deleted file mode 160000
index f134cbff..00000000
--- a/examples/imagenet-example
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit f134cbfff7f590954edc5c24275444b7dd2f57f6
diff --git a/examples/profiler.py b/examples/profiler.py
new file mode 100644
index 00000000..c959baee
--- /dev/null
+++ b/examples/profiler.py
@@ -0,0 +1,155 @@
+#%%
+
+import time
+from PIL import Image  # a trick to solve loading lib problem
+from ffcv.fields.rgb_image import *
+from ffcv.transforms import RandomHorizontalFlip, NormalizeImage, ToTensor, ToTorchImage, ToDevice
+import numpy as np
+import torchvision
+
+from ffcv import Loader
+import ffcv
+import argparse
+from tqdm.auto import tqdm,trange
+import torch.nn as nn
+import torch
+from psutil import Process, net_io_counters
+import json
+from os import getpid
+
+from ffcv.transforms.ops import Convert
+
+IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
+IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
+
+class ramqdm(tqdm):
+ """tqdm progress bar that reports RAM usage with each update"""
+ _empty_desc = "using ? GB RAM; ? CPU ? IO"
+ _desc = "{:.2f} GB RAM; {:.2f} % CPU {:.2f} MB IO"
+ _GB = 10**9
+ """"""
+ def __init__(self, *args, **kwargs):
+ """Override desc and get reference to current process"""
+ if "desc" in kwargs:
+ # prepend desc to the reporter mask:
+ self._empty_desc = kwargs["desc"] + " " + self._empty_desc
+ self._desc = kwargs["desc"] + " " + self._desc
+ del kwargs["desc"]
+ else:
+ # nothing to prepend, reporter mask is at start of sentence:
+ self._empty_desc = self._empty_desc.capitalize()
+ self._desc = self._desc.capitalize()
+ super().__init__(*args, desc=self._empty_desc, **kwargs)
+ self._process = Process(getpid())
+ self.metrics = []
+ """"""
+ def update(self, n=1):
+ """Calculate RAM usage and update progress bar"""
+ rss = self._process.memory_info().rss
+ ps = self._process.cpu_percent()
+ io_counters = self._process.io_counters().read_bytes
+ # net_io = net_io_counters().bytes_recv
+ # io_counters += net_io
+
+ current_desc = self._desc.format(rss/self._GB, ps, io_counters/1e6)
+ self.set_description(current_desc)
+ self.metrics.append({'mem':rss/self._GB, 'cpu':ps, 'io':io_counters/1e6})
+ super().update(n)
+
+ def summary(self):
+ res = {}
+ for key in self.metrics[0].keys():
+ res[key] = np.mean([i[key] for i in self.metrics])
+ return res
+
+
+def load_one_epoch(args,loader):
+ start = time.time()
+ l=ramqdm(loader)
+
+ for x1,y in l:
+ pass
+ end = time.time()
+    res = l.summary()
+    try:
+        throughput = loader.reader.num_samples / (end - start)
+    except AttributeError:
+        # a plain PyTorch DataLoader has no `reader`; fall back to the dataset length
+        throughput = len(loader.dataset) / (end - start)
+    res['throughput'] = throughput
+ x1 = x1.float()
+ print("Mean: ", x1.mean().item(), "Std: ", x1.std().item())
+ return res
+
+def main(args):
+ if args.no_ffcv:
+ tfms = torchvision.transforms.Compose([
+ torchvision.transforms.RandomResizedCrop(args.img_size),
+ torchvision.transforms.RandomHorizontalFlip(),
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Normalize(IMAGENET_MEAN/255, IMAGENET_STD/255),
+ ])
+ dataset = torchvision.datasets.ImageFolder(args.data_path, transform=tfms)
+ loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True)
+ else:
+ pipe = {
+ 'image': [RandomResizedCropRGBImageDecoder((args.img_size,args.img_size)),
+ RandomHorizontalFlip(),
+ ToTensor(),
+ # ToDevice(torch.device('cuda')),
+ ToTorchImage(),
+ # NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16),
+ # Convert(torch.float16),
+ ]
+ }
+ loader = Loader(args.data_path, batch_size=args.batch_size, num_workers=args.num_workers,
+ pipelines=pipe,
+ batches_ahead=2, distributed=False,seed=0,drop_last=True)
+
+
+ # warmup
+ load_one_epoch(args,loader)
+
+ for _ in range(args.repeat):
+ res = load_one_epoch(args,loader)
+ yield res
+
+#%%
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description="FFCV Profiler")
+ parser.add_argument("-r", "--repeat", type=int, default=3, help="number of samples to record one step for profile.")
+ parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size")
+ parser.add_argument("-p", "--data_path", type=str, help="data path", required=True)
+ parser.add_argument("--no_ffcv",default=False,action="store_true")
+ parser.add_argument("--num_workers", type=int, default=10, help="number of workers")
+ parser.add_argument("--exp", default=False, action="store_true", help="run experiments")
+ parser.add_argument("--img_size", type=int, default=224, help="image size")
+ parser.add_argument("--write_path", type=str, help='path to write result',default=None)
+ args = parser.parse_args()
+    if not args.exp:
+ for res in main(args):
+ throughput = res['throughput']
+ print(f"Throughput: {throughput:.2f} samples/s for {args.data_path}.")
+ res.update(args.__dict__)
+ if args.write_path:
+ with open(args.write_path,"a") as file:
+ file.write(json.dumps(res)+"\n")
+    else:
+        data = []
+        with open(args.write_path, "a") as file:
+            for num_workers in [10, 20, 40]:
+                for use_ffcv in [False, True]:
+                    for bs in [128, 256, 512]:
+                        args.num_workers = num_workers
+                        args.batch_size = bs
+                        args.no_ffcv = not use_ffcv  # main() checks args.no_ffcv
+                        row = dict(vars(args))  # copy, so each run gets its own record
+                        for res in main(args):
+                            row.update(res)
+                            file.write(json.dumps(row) + "\n")
+                            file.flush()
+                            print(row)
+                            data.append(row)
+        import pandas as pd
+        df = pd.DataFrame(data)
+        print(df)
+        exit(0)
\ No newline at end of file
diff --git a/examples/vis_loader.py b/examples/vis_loader.py
new file mode 100644
index 00000000..770f51a0
--- /dev/null
+++ b/examples/vis_loader.py
@@ -0,0 +1,47 @@
+import argparse
+import time
+from PIL import Image # a trick to solve loading lib problem
+from ffcv import Loader
+from ffcv.transforms import *
+from ffcv.fields.decoders import CenterCropRGBImageDecoder, RandomResizedCropRGBImageDecoder
+
+
+import numpy as np
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='FFCV Profiler')
+ parser.add_argument('data_path', type=str, default='data/imagenet', help='Path to the dataset')
+ parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
+ parser.add_argument('--write_path', type=str, default='viz.png', help='Path to write result')
+ args = parser.parse_args()
+
+ loader = Loader(args.data_path, batch_size=args.batch_size, num_workers=10, cache_type=0, pipelines={
+ 'image':[CenterCropRGBImageDecoder((224, 224),224/256),
+ ToTensor(),
+ ToTorchImage()]
+ }, batches_ahead=0,)
+
+ print("num samples: ", loader.reader.num_samples, "fields: ", loader.reader.field_names)
+ for x,_ in loader:
+ x1 = x.float()
+ print("Mean: ", x1.mean().item(), "Std: ", x1.std().item())
+ break
+
+ print('Done')
+ num = int(np.sqrt(args.batch_size))
+ import cv2
+
+ image = np.zeros((224*num, 224*num, 3), dtype=np.uint8)
+ for i in range(num):
+ for j in range(num):
+ if i*num+j >= args.batch_size:
+ break
+ img = x[i*num+j].numpy().transpose(1,2,0)
+ # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ image[i*224:(i+1)*224, j*224:(j+1)*224] = (img).astype(np.uint8)
+
+ if args.write_path:
+ Image.fromarray(image).save(args.write_path)
+
+
diff --git a/examples/write_dataset.py b/examples/write_dataset.py
new file mode 100644
index 00000000..a9b38bef
--- /dev/null
+++ b/examples/write_dataset.py
@@ -0,0 +1,106 @@
+"""example usage:
+export IMAGENET_DIR=/path/to/pytorch/format/imagenet/directory/
+export WRITE_DIR=/your/path/here/
+write_dataset train 500 0.50 90
+write_path=$WRITE_DIR/train500_0.5_90.ffcv
+echo "Writing ImageNet train dataset to ${write_path}"
+python examples/write_dataset.py \
+ --cfg.data_dir=$IMAGENET_DIR \
+ --cfg.write_path=$write_path \
+ --cfg.max_resolution=500 \
+ --cfg.write_mode=smart \
+ --cfg.compress_probability=0.50 \
+ --cfg.jpeg_quality=90
+"""
+from PIL import Image
+from torch.utils.data import Subset
+from ffcv.writer import DatasetWriter
+from ffcv.fields import IntField, RGBImageField
+from torchvision.datasets import ImageFolder
+import torchvision.datasets as torch_datasets
+
+from argparse import ArgumentParser
+from fastargs import Section, Param
+from fastargs.validation import And, OneOf
+from fastargs.decorators import param, section
+from fastargs import get_current_config
+import cv2
+import numpy as np
+
+# hack resizer
+# def resizer(image, target_resolution):
+# if target_resolution is None:
+# return image
+# original_size = np.array([image.shape[1], image.shape[0]])
+# ratio = target_resolution / original_size.min()
+# if ratio < 1:
+# new_size = (ratio * original_size).astype(int)
+# image = cv2.resize(image, tuple(new_size), interpolation=cv2.INTER_AREA)
+# return image
+# from ffcv.fields import rgb_image
+# rgb_image.resizer = resizer
+
+Section('cfg', 'arguments to give the writer').params(
+ dataset=Param(And(str, OneOf(['cifar', 'imagenet'])), 'Which dataset to write', default='imagenet'),
+ data_dir=Param(str, 'Where to find the PyTorch dataset', required=True),
+ write_path=Param(str, 'Where to write the new dataset', required=True),
+    write_mode=Param(str, 'Mode: raw, smart, jpg, png, or proportion', required=False, default='smart'),
+    max_resolution=Param(int, 'Max image side length; 0 keeps the original size.', required=False, default=0),
+ num_workers=Param(int, 'Number of workers to use', default=16),
+ chunk_size=Param(int, 'Chunk size for writing', default=100),
+ jpeg_quality=Param(float, 'Quality of jpeg images', default=90),
+ subset=Param(int, 'How many images to use (-1 for all)', default=-1),
+ compress_probability=Param(float, 'compress probability', default=0.5),
+ threshold=Param(int, 'threshold for smart mode to compress by jpeg', default=286432),
+)
+
+@section('cfg')
+@param('dataset')
+@param('data_dir')
+@param('write_path')
+@param('max_resolution')
+@param('num_workers')
+@param('chunk_size')
+@param('subset')
+@param('jpeg_quality')
+@param('write_mode')
+@param('compress_probability')
+@param('threshold')
+def main(dataset, data_dir, write_path, max_resolution, num_workers,
+ chunk_size, subset, jpeg_quality, write_mode,
+ compress_probability, threshold):
+    if dataset == 'imagenet':
+        my_dataset = ImageFolder(root=data_dir)
+    elif dataset == 'cifar':
+        # RGBImageField handles PIL images directly, so no transform is needed
+        my_dataset = torch_datasets.CIFAR10(root=data_dir, train=True, download=True)
+    else:
+        raise ValueError('Unknown dataset')
+
+
+ if subset > 0: my_dataset = Subset(my_dataset, range(subset))
+ writer = DatasetWriter(write_path, {
+ 'image': RGBImageField(write_mode=write_mode,
+ max_resolution=None if max_resolution==0 else max_resolution,
+ compress_probability=compress_probability,
+ jpeg_quality=jpeg_quality,
+ smart_threshold=threshold),
+ 'label': IntField(),
+ }, num_workers=num_workers)
+
+ writer.from_indexed_dataset(my_dataset, chunksize=chunk_size,shuffle_indices=False)
+
+if __name__ == '__main__':
+ config = get_current_config()
+ parser = ArgumentParser()
+ config.augment_argparse(parser)
+ config.collect_argparse_args(parser)
+ config.validate(mode='stderr')
+ config.summary()
+
+    args = config.get().cfg
+    assert args.write_path.endswith('.ffcv'), 'write_path must end with .ffcv'
+    # record the writer configuration next to the dataset file
+    with open(args.write_path.replace(".ffcv", ".meta"), 'w') as file:
+        file.write(str(args.__dict__))
+    main()
diff --git a/ffcv-conda.yml b/ffcv-conda.yml
index f332ea51..f3f39ca9 100644
--- a/ffcv-conda.yml
+++ b/ffcv-conda.yml
@@ -1,4 +1,4 @@
-name: ffcv19
+name: ffcv
channels:
- pytorch
- defaults
diff --git a/ffcv/.DS_Store b/ffcv/.DS_Store
deleted file mode 100644
index f1fad9b3..00000000
Binary files a/ffcv/.DS_Store and /dev/null differ
diff --git a/ffcv/benchmarks/decorator.py b/ffcv/benchmarks/decorator.py
index 4e8d75a7..9651eab3 100644
--- a/ffcv/benchmarks/decorator.py
+++ b/ffcv/benchmarks/decorator.py
@@ -1,3 +1,4 @@
+import tracemalloc
from itertools import product
from time import time
from collections import defaultdict
@@ -46,6 +47,8 @@ def run_all(runs=3, warm_up=1, pattern='*'):
for args in it_args:
# with redirect_stderr(FakeSink()):
+ # Start tracing memory allocations
+ tracemalloc.start()
if True:
benchmark: Benchmark = cls(**args)
with benchmark:
@@ -57,7 +60,9 @@ def run_all(runs=3, warm_up=1, pattern='*'):
start = time()
benchmark.run()
timings.append(time() - start)
-
+ # Stop tracing memory allocations
+ current, peak = tracemalloc.get_traced_memory()
+ tracemalloc.stop()
median_time = np.median(timings)
throughput = None
@@ -66,16 +71,13 @@ def run_all(runs=3, warm_up=1, pattern='*'):
throughput = args['n'] / median_time
unit = 'it/sec'
- if throughput < 1:
- unit = 'sec/it'
- throughput = 1 /throughput
-
- throughput = np.round(throughput * 10) / 10
results[suite_name].append({
**args,
'time': median_time,
- 'throughput': str(throughput) + ' ' + unit
+ f'throughput ({unit})': f"{throughput:.2f}",
+ 'current_memory (MB)': current / 10**6,
+ 'peak_memory (MB)': peak / 10**6,
})
it_args.close()
it_suite.close()
diff --git a/ffcv/benchmarks/suites/image_read.py b/ffcv/benchmarks/suites/image_read.py
index 89f09f46..7ce54d57 100644
--- a/ffcv/benchmarks/suites/image_read.py
+++ b/ffcv/benchmarks/suites/image_read.py
@@ -42,18 +42,21 @@ def __getitem__(self, index):
'length': [3000],
'mode': [
'raw',
- 'jpg'
+ 'jpg',
+ 'png',
],
'num_workers': [
1,
8,
- 16
+ 16,
+ 32,
],
'batch_size': [
500
],
'size': [
(32, 32), # CIFAR
+ (224,224),
(300, 500), # ImageNet
],
'compile': [
@@ -83,13 +86,12 @@ def __enter__(self):
self.handle.__enter__()
name = self.handle.name
- writer = DatasetWriter(self.length, name, {
+ writer = DatasetWriter(name, {
'index': IntField(),
'value': RGBImageField(write_mode=self.mode)
})
- with writer:
- writer.write_pytorch_dataset(self.dataset, num_workers=-1, chunksize=100)
+ writer.from_indexed_dataset(self.dataset, chunksize=100)
reader = Reader(name)
manager = OSCacheManager(reader)
diff --git a/ffcv/benchmarks/suites/jpeg_decode.py b/ffcv/benchmarks/suites/jpeg_decode.py
index 31fc7860..51c74449 100644
--- a/ffcv/benchmarks/suites/jpeg_decode.py
+++ b/ffcv/benchmarks/suites/jpeg_decode.py
@@ -14,9 +14,9 @@
@benchmark({
'n': [500],
'source_image': ['../../../test_data/pig.png'],
- 'image_width': [500, 256, 1024],
- 'quality': [50, 90],
- 'compile': [True]
+ 'image_width': [224, 500, 1024],
+ 'quality': [50, 80, 90, 95],
+ 'compile': [True],
})
class JPEGDecodeBenchmark(Benchmark):
diff --git a/ffcv/benchmarks/suites/memory_read.py b/ffcv/benchmarks/suites/memory_read.py
index e6072516..2666fa34 100644
--- a/ffcv/benchmarks/suites/memory_read.py
+++ b/ffcv/benchmarks/suites/memory_read.py
@@ -59,13 +59,12 @@ def __enter__(self):
handle = self.handle.__enter__()
name = handle.name
dataset = DummyDataset(self.num_samples, self.size_bytes)
- writer = DatasetWriter(self.num_samples, name, {
+ writer = DatasetWriter(name, {
'index': IntField(),
'value': BytesField()
- })
+ }, num_workers=-1)
- with writer:
- writer.write_pytorch_dataset(dataset, num_workers=-1, chunksize=100)
+ writer.from_indexed_dataset(dataset, chunksize=100)
reader = Reader(name)
manager = OSCacheManager(reader)
diff --git a/ffcv/fields/rgb_image.py b/ffcv/fields/rgb_image.py
index b6420f11..b90dbde8 100644
--- a/ffcv/fields/rgb_image.py
+++ b/ffcv/fields/rgb_image.py
@@ -12,7 +12,7 @@
from ..pipeline.state import State
from ..pipeline.compiler import Compiler
from ..pipeline.allocation_query import AllocationQuery
-from ..libffcv import imdecode, memcpy, resize_crop
+from ..libffcv import imdecode, cv_imdecode, imcropresizedecode, memcpy, resize_crop
if TYPE_CHECKING:
from ..memory_managers.base import MemoryManager
@@ -21,6 +21,7 @@
IMAGE_MODES = Dict()
IMAGE_MODES['jpg'] = 0
IMAGE_MODES['raw'] = 1
+IMAGE_MODES['png'] = 2
def encode_jpeg(numpy_image, quality):
@@ -33,6 +34,11 @@ def encode_jpeg(numpy_image, quality):
return result.reshape(-1)
+def encode_png(numpy_image):
+ # x=cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)
+ result = cv2.imencode('.png', numpy_image)[1]
+ result = np.frombuffer(result, np.uint8)
+ return result.reshape(-1)
def resizer(image, target_resolution):
if target_resolution is None:
@@ -41,7 +47,7 @@ def resizer(image, target_resolution):
ratio = target_resolution / original_size.max()
if ratio < 1:
new_size = (ratio * original_size).astype(int)
- image = cv2.resize(image, tuple(new_size), interpolation=cv2.INTER_AREA)
+ image = cv2.resize(image, tuple(new_size), interpolation=cv2.INTER_CUBIC)
return image
@@ -101,22 +107,24 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca
consider RandomResizedCropRGBImageDecoder or CenterCropRGBImageDecoder
instead."""
raise TypeError(msg)
-
- biggest_shape = (max_height, max_width, 3)
+
+ max_shape = ((np.uint64(widths)*np.uint64(heights)*3).max(),)
        my_dtype = np.dtype('<u1')

    def generate_code(self) -> Callable:
mem_read = self.memory_read
imdecode_c = Compiler.compile(imdecode)
+ cv_imdecode_c = Compiler.compile(cv_imdecode)
jpg = IMAGE_MODES['jpg']
raw = IMAGE_MODES['raw']
+ png = IMAGE_MODES['png']
my_range = Compiler.get_iterator()
my_memcpy = Compiler.compile(memcpy)
@@ -130,8 +138,12 @@ def decode(batch_indices, destination, metadata, storage_state):
if field['mode'] == jpg:
imdecode_c(image_data, destination[dst_ix],
height, width, height, width, 0, 0, 1, 1, False, False)
- else:
+ elif field['mode'] == raw:
my_memcpy(image_data, destination[dst_ix])
+ elif field['mode'] == png:
+ cv_imdecode_c(image_data, destination[dst_ix])
+ else:
+ pass
return destination[:len(batch_indices)]
@@ -144,9 +156,14 @@ class ResizedCropRGBImageDecoder(SimpleRGBImageDecoder, metaclass=ABCMeta):
It supports both variable and constant resolution datasets.
"""
- def __init__(self, output_size):
+ def __init__(self, output_size,interpolation):
super().__init__()
self.output_size = output_size
+ self.interpolation = interpolation
+ self.use_crop_decode = True
+
+ def use_crop_decode_(self, value):
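+        # toggle whether JPEG images are decoded only within the crop region (enabled by default)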
+ self.use_crop_decode = value
def declare_state_and_memory(self, previous_state: State) -> Tuple[State, AllocationQuery]:
widths = self.metadata['width']
@@ -156,27 +173,36 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca
self.max_height = np.uint64(heights.max())
output_shape = (self.output_size[0], self.output_size[1], 3)
        my_dtype = np.dtype('<u1')

    def generate_code(self) -> Callable:
jpg = IMAGE_MODES['jpg']
+ raw = IMAGE_MODES['raw']
+ png = IMAGE_MODES['png']
mem_read = self.memory_read
my_range = Compiler.get_iterator()
imdecode_c = Compiler.compile(imdecode)
+ cv_imdecode_c = Compiler.compile(cv_imdecode)
resize_crop_c = Compiler.compile(resize_crop)
+ imcropresizedecode_c = Compiler.compile(imcropresizedecode)
get_crop_c = Compiler.compile(self.get_crop_generator)
scale = self.scale
ratio = self.ratio
+ use_crop_decode = self.use_crop_decode
+ interpolation = self.interpolation
if isinstance(scale, tuple):
scale = np.array(scale)
if isinstance(ratio, tuple):
@@ -191,22 +217,37 @@ def decode(batch_indices, my_storage, metadata, storage_state):
height = np.uint32(field['height'])
width = np.uint32(field['width'])
+ i, j, h, w = get_crop_c(height, width, scale, ratio)
+
if field['mode'] == jpg:
temp_buffer = temp_storage[dst_ix]
- imdecode_c(image_data, temp_buffer,
- height, width, height, width, 0, 0, 1, 1, False, False)
- selected_size = 3 * height * width
- temp_buffer = temp_buffer.reshape(-1)[:selected_size]
- temp_buffer = temp_buffer.reshape(height, width, 3)
-
- else:
- temp_buffer = image_data.reshape(height, width, 3)
-
- i, j, h, w = get_crop_c(height, width, scale, ratio)
-
- resize_crop_c(temp_buffer, i, i + h, j, j + w,
+ if use_crop_decode:
+ imcropresizedecode_c(image_data, destination[dst_ix],
+ h,w,
+ i, j, interpolation)
+ else:
+ ## decode the whole image
+ imdecode_c(image_data, temp_buffer,
+ height, width, height, width, 0, 0, 1, 1, False, False)
+ ## crop and resize the image
+ selected_size = 3 * height * width
+ temp_buffer = temp_buffer.reshape(-1)[:selected_size]
+ temp_buffer = temp_buffer.reshape(height, width, 3)
+ resize_crop_c(temp_buffer, i, i + h, j, j + w,
destination[dst_ix])
-
+ elif field['mode'] == raw:
+ temp_buffer = image_data.reshape(height, width, 3)
+ resize_crop_c(temp_buffer, i, i + h, j, j + w,
+ destination[dst_ix])
+ elif field['mode'] == png:
+ temp_buffer = temp_storage[dst_ix]
+ cv_imdecode_c(image_data, temp_buffer)
+ buffer = temp_buffer[:height*width*3].reshape(height,width,3)
+ resize_crop_c(buffer, i, i + h, j, j + w,
+ destination[dst_ix])
+ else:
+ pass
+
return destination[:len(batch_indices)]
decode.is_parallel = True
return decode
@@ -231,8 +272,8 @@ class RandomResizedCropRGBImageDecoder(ResizedCropRGBImageDecoder):
ratio : Tuple[float]
The range of potential aspect ratios that can be randomly sampled
"""
- def __init__(self, output_size, scale=(0.08, 1.0), ratio=(0.75, 4/3)):
- super().__init__(output_size)
+ def __init__(self, output_size, scale=(0.08, 1.0), ratio=(0.75, 4/3), interpolation=cv2.INTER_CUBIC):
+ super().__init__(output_size, interpolation=interpolation)
self.scale = scale
self.ratio = ratio
self.output_size = output_size
@@ -255,8 +296,8 @@ class CenterCropRGBImageDecoder(ResizedCropRGBImageDecoder):
ratio of (crop size) / (min side length)
"""
# output size: resize crop size -> output size
- def __init__(self, output_size, ratio):
- super().__init__(output_size)
+ def __init__(self, output_size, ratio, interpolation=cv2.INTER_AREA):
+ super().__init__(output_size,interpolation=interpolation)
self.scale = None
self.ratio = ratio
@@ -311,10 +352,10 @@ def get_decoder_class(self) -> Type[Operation]:
return SimpleRGBImageDecoder
@staticmethod
- def from_binary(binary: ARG_TYPE) -> Field:
+ def from_binary(binary: ARG_TYPE) -> Field: # type: ignore
return RGBImageField()
- def to_binary(self) -> ARG_TYPE:
+ def to_binary(self) -> ARG_TYPE: # type: ignore
return np.zeros(1, dtype=ARG_TYPE)[0]
def encode(self, destination, image, malloc):
@@ -335,10 +376,9 @@ def encode(self, destination, image, malloc):
image = resizer(image, self.max_resolution)
write_mode = self.write_mode
- as_jpg = None
+ ccode = None # compressed code
if write_mode == 'smart':
- as_jpg = encode_jpeg(image, self.jpeg_quality)
write_mode = 'raw'
if self.smart_threshold is not None:
if image.nbytes > self.smart_threshold:
@@ -353,13 +393,16 @@ def encode(self, destination, image, malloc):
destination['height'], destination['width'] = image.shape[:2]
if write_mode == 'jpg':
- if as_jpg is None:
- as_jpg = encode_jpeg(image, self.jpeg_quality)
- destination['data_ptr'], storage = malloc(as_jpg.nbytes)
- storage[:] = as_jpg
+ ccode = encode_jpeg(image, self.jpeg_quality)
+ destination['data_ptr'], storage = malloc(ccode.nbytes)
+ storage[:] = ccode
elif write_mode == 'raw':
+            image_bytes = np.ascontiguousarray(image).view('<u1')
+
+        if self._fd >= 0:
+ os.close(self._fd)
+ self._fd = -1
+
+class SharedMemoryContext(MemoryContext):
+    def __init__(self, manager: MemoryManager):
+        self.manager = manager
+        file_name = self.manager.reader.file_name
+        name = file_name.split('/')[-1]
+
+        mmap = np.memmap(file_name, 'uint8', mode='r')
+        size = len(mmap)
+
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ print_args = {'force':True} if dist.is_initialized() else {}
+ file = os.path.join('/dev/shm',name)
+
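+        # rank 0 creates the shared-memory block under /dev/shm and copies the
+        # dataset file into it; the other ranks wait at the barrier and attach.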
+ if rank == 0:
+ create = not (os.path.exists(file) and filecmp.cmp(file, file_name))
+ self.mem = MasterSharedMemory(name=name, create=create, size=size)
+ print(f"[rank {rank}] copying file to shared memory",**print_args)
+ shared_mmap = np.frombuffer(self.mem.buf, dtype=np.uint8)
+ shared_mmap[:] = mmap[:]
+ if dist.is_initialized():
+ dist.barrier()
+ if rank != 0:
+ self.mem = MasterSharedMemory(name=name, create=False, size=size)
+ shared_mmap = np.frombuffer(self.mem.buf, dtype=np.uint8)
+ print(f"[rank {rank}] opening shared memory",**print_args)
+ self.mmap = shared_mmap
+
+ @property
+ def state(self):
+ return (self.mmap, self.manager.ptrs, self.manager.sizes)
+
+
+ def __enter__(self):
+ res = super().__enter__()
+ return res
+
+    def __exit__(self, __exc_type, __exc_value, __traceback):
+        # Numpy doesn't have an API to close memory maps yet.
+        # The only thing one can do is flush it, but since we never
+        # write to it, that would be pointless. Moreover, we want to
+        # avoid reopening the memmap over and over anyway.
+        return super().__exit__(__exc_type, __exc_value, __traceback)
+
+
+class SharedMemoryManager(MemoryManager):
+
+ def __init__(self, reader: 'Reader'):
+ super().__init__(reader)
+ self.context = SharedMemoryContext(self)
+
+ def schedule_epoch(self, schedule):
+ return self.context
+
+    @property
+    def state_type(self):
+        t1 = nb.uint8[::1]
+        t1.mutable = False
+        t2 = nb.uint64[::1]
+        t2.mutable = False
+        return nb.types.Tuple([t1, t2, t2])
+
+ def compile_reader(self):
+ def read(address, mem_state):
+ mmap, ptrs, sizes = mem_state
+ size = sizes[np.searchsorted(ptrs, address)]
+ ref_data = mmap[address:address + size]
+ return ref_data
+
+ return Compiler.compile(read, nb.uint8[::1](nb.uint64, self.state_type))
+
+
diff --git a/ffcv/transforms/color_jitter.py b/ffcv/transforms/color_jitter.py
index a79b72fd..9270e805 100644
--- a/ffcv/transforms/color_jitter.py
+++ b/ffcv/transforms/color_jitter.py
@@ -3,8 +3,13 @@
Reference : https://github.com/pytorch/vision/blob/main/torchvision/transforms/functional_tensor.py
'''
+from typing import Callable, Optional, Tuple
+import numbers
+import math
+import random
import numpy as np
-
+from numba import njit
+import numba as nb
from dataclasses import replace
from ..pipeline.allocation_query import AllocationQuery
from ..pipeline.operation import Operation
@@ -137,3 +142,217 @@ def blend(img1, img2, ratio): return (ratio*img1 + (1-ratio)*img2).clip(0, 255).
def declare_state_and_memory(self, previous_state):
return (replace(previous_state, jit_mode=True), AllocationQuery(previous_state.shape, previous_state.dtype))
+
+# copy from https://github.com/facebookresearch/FFCV-SSL/blob/main/ffcv/transforms/colorjitter.py
+@njit(parallel=False, fastmath=True, inline="always")
+def apply_cj(
+ im,
+ apply_bri,
+ bri_ratio,
+ apply_cont,
+ cont_ratio:np.float32,
+ apply_sat,
+ sat_ratio:np.float32,
+ apply_hue,
+ hue_factor:np.float32,
+):
+
+
+ gray = (
+ np.float32(0.2989) * im[..., 0]
+ + np.float32(0.5870) * im[..., 1]
+ + np.float32(0.1140) * im[..., 2]
+ )
+ one = np.float32(1)
+ # Brightness
+ if apply_bri:
+ im = im * bri_ratio
+
+ # Contrast
+ if apply_cont:
+ im = cont_ratio * im + (one - cont_ratio) * np.float32(gray.mean())
+
+ # Saturation
+ if apply_sat:
+ im[..., 0] = sat_ratio * im[..., 0] + (one - sat_ratio) * gray
+ im[..., 1] = sat_ratio * im[..., 1] + (one - sat_ratio) * gray
+ im[..., 2] = sat_ratio * im[..., 2] + (one - sat_ratio) * gray
+
+ # Hue
+ if apply_hue:
+ hue_factor_radians = hue_factor * 2.0 * np.pi
+ cosA = np.cos(hue_factor_radians)
+ sinA = np.sin(hue_factor_radians)
+ v1, v2, v3 = 1.0 / 3.0, np.sqrt(1.0 / 3.0), (1.0 - cosA)
+ hue_matrix = [
+ [
+ cosA + v3 / 3.0,
+ v1 * v3 - v2 * sinA,
+ v1 * v3 + v2 * sinA,
+ ],
+ [
+ v1 * v3 + v2 * sinA,
+ cosA + v1 * v3,
+ v1 * v3 - v2 * sinA,
+ ],
+ [
+ v1 * v3 - v2 * sinA,
+ v1 * v3 + v2 * sinA,
+ cosA + v1 * v3,
+ ],
+ ]
+ hue_matrix = np.array(hue_matrix, dtype=np.float32).T
+ for row in nb.prange(im.shape[0]):
+ im[row] = im[row] @ hue_matrix
+ return np.clip(im, 0, 255).astype(np.uint8)
+
+
+class RandomColorJitter(Operation):
+ """Add ColorJitter with probability jitter_prob.
+ Operates on raw arrays (not tensors).
+
+ see https://github.com/pytorch/vision/blob/28557e0cfe9113a5285330542264f03e4ba74535/torchvision/transforms/functional_tensor.py#L165
+ and https://sanje2v.wordpress.com/2021/01/11/accelerating-data-transforms/
+ Parameters
+ ----------
+ jitter_prob : float, The probability with which to apply ColorJitter.
+ brightness (float or tuple of float (min, max)): How much to jitter brightness.
+ brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
+ or the given [min, max]. Should be non negative numbers.
+ contrast (float or tuple of float (min, max)): How much to jitter contrast.
+ contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
+ or the given [min, max]. Should be non negative numbers.
+ saturation (float or tuple of float (min, max)): How much to jitter saturation.
+ saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
+ or the given [min, max]. Should be non negative numbers.
+ hue (float or tuple of float (min, max)): How much to jitter hue.
+ hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
+ Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
+ """
+
+ def __init__(
+ self,
+ brightness=0.8,
+ contrast=0.4,
+ saturation=0.4,
+ hue=0.2,
+ p=0.5,
+ seed=None,
+ ):
+ super().__init__()
+ self.jitter_prob = p
+
+ self.brightness = self._check_input(brightness, "brightness")
+ self.contrast = self._check_input(contrast, "contrast")
+ self.saturation = self._check_input(saturation, "saturation")
+ self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5))
+ self.seed = seed
+ assert self.jitter_prob >= 0 and self.jitter_prob <= 1
+
+ def _check_input(
+ self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
+ ):
+ if isinstance(value, numbers.Number):
+ if value < 0:
+ raise ValueError(
+ f"If {name} is a single number, it must be non negative."
+ )
+ value = [center - float(value), center + float(value)]
+ if clip_first_on_zero:
+ value[0] = max(value[0], 0.0)
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
+ raise ValueError(f"{name} values should be between {bound}")
+ else:
+ raise TypeError(
+ f"{name} should be a single number or a list/tuple with length 2."
+ )
+
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
+ # or (0., 0.) for hue, do nothing
+ if value[0] == value[1] == center:
+ setattr(self, f"apply_{name}", False)
+ else:
+ setattr(self, f"apply_{name}", True)
+ return tuple(value)
+
+ def generate_code(self) -> Callable:
+ my_range = Compiler.get_iterator()
+
+ jitter_prob = self.jitter_prob
+
+ apply_bri = self.apply_brightness
+ bri = self.brightness
+
+ apply_cont = self.apply_contrast
+ cont = self.contrast
+
+ apply_sat = self.apply_saturation
+ sat = self.saturation
+
+ apply_hue = self.apply_hue
+ hue = self.hue
+
+ seed = self.seed
+ if seed is None:
+
+ def color_jitter(images, _):
+ for i in my_range(images.shape[0]):
+ if np.random.rand() > jitter_prob:
+ continue
+
+ images[i] = apply_cj(
+ images[i].astype("float32"),
+ apply_bri,
+ np.float32(np.random.uniform(bri[0], bri[1])),
+ apply_cont,
+ np.float32(np.random.uniform(cont[0], cont[1])),
+ apply_sat,
+ np.float32(np.random.uniform(sat[0], sat[1])),
+ apply_hue,
+ np.float32(np.random.uniform(hue[0], hue[1])),
+ )
+ return images
+
+ color_jitter.is_parallel = True
+ return color_jitter
+
+ def color_jitter(images, _, counter):
+
+ random.seed(seed + counter)
+ N = images.shape[0]
+ values = np.zeros(N)
+ bris = np.zeros(N)
+ conts = np.zeros(N)
+ sats = np.zeros(N)
+ hues = np.zeros(N)
+ for i in range(N):
+ values[i] = np.float32(random.uniform(0, 1))
+ bris[i] = np.float32(random.uniform(bri[0], bri[1]))
+ conts[i] = np.float32(random.uniform(cont[0], cont[1]))
+ sats[i] = np.float32(random.uniform(sat[0], sat[1]))
+ hues[i] = np.float32(random.uniform(hue[0], hue[1]))
+ for i in my_range(N):
+ if values[i] > jitter_prob:
+ continue
+ images[i] = apply_cj(
+ images[i].astype("float32"),
+ apply_bri,
+ bris[i],
+ apply_cont,
+ conts[i],
+ apply_sat,
+ sats[i],
+ apply_hue,
+ hues[i],
+ )
+ return images
+
+ color_jitter.is_parallel = True
+ color_jitter.with_counter = True
+ return color_jitter
+
+ def declare_state_and_memory(
+ self, previous_state: State
+ ) -> Tuple[State, Optional[AllocationQuery]]:
+ return (replace(previous_state, jit_mode=True), None)
diff --git a/ffcv/transforms/gaussian_blur.py b/ffcv/transforms/gaussian_blur.py
new file mode 100644
index 00000000..9aeaf30e
--- /dev/null
+++ b/ffcv/transforms/gaussian_blur.py
@@ -0,0 +1,91 @@
+# copy from https://github.com/facebookresearch/FFCV-SSL/blob/main/ffcv/transforms/gaussian_blur.py
+import numpy as np
+from typing import Callable, Optional, Tuple
+from dataclasses import replace
+from ffcv.pipeline.allocation_query import AllocationQuery
+from ffcv.pipeline.operation import Operation
+from ffcv.pipeline.state import State
+from ffcv.pipeline.compiler import Compiler
+from scipy.signal import convolve2d
+
+
+def apply_blur(img, kernel_size, w):
+ pad = (kernel_size - 1) // 2
+ H, W, _ = img.shape
+ tmp = np.zeros(img.shape, dtype=np.float32)
+ for k in range(kernel_size):
+ start = max(0, pad - k)
+ stop = min(W, pad - k + W)
+ window = (img[:, start:stop] / 255) * w[k]
+ tmp[:, np.abs(stop - W) : W - start] += window
+ tmp2 = tmp + 0.0
+ for k in range(kernel_size):
+ start = max(0, pad - k)
+ stop = min(H, pad - k + H)
+ window = (tmp[start:stop] * w[k]).astype(np.uint8)
+ tmp2[np.abs(stop - H) : H - start] += window
+ return np.clip(tmp2 * 255.0, 0, 255).astype(np.uint8)
+
+
+class GaussianBlur(Operation):
+ """Blurs image with randomly chosen Gaussian blur.
+ If the image is torch Tensor, it is expected
+ to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+ Args:
+        p (float): probability to apply blurring to each input.
+ kernel_size (int or sequence): Size of the Gaussian kernel.
+ sigma (float or tuple of float (min, max)): Standard deviation to be used for
+ creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
+ of float (min, max), sigma is chosen uniformly at random to lie in the
+ given range.
+ """
+
+ def __init__(self, kernel_size=5, sigma=(0.1, 2.0), p=0.5):
+ super().__init__()
+ self.blur_prob = p
+ self.kernel_size = kernel_size
+ assert sigma[1] > sigma[0]
+ self.sigmas = np.linspace(sigma[0], sigma[1], 10)
+ from scipy import signal
+
+ self.weights = np.stack(
+ [
+ signal.gaussian(kernel_size, s)
+ for s in np.linspace(sigma[0], sigma[1], 10)
+ ]
+ )
+ self.weights /= self.weights.sum(1, keepdims=True)
+
+ def generate_code(self) -> Callable:
+ my_range = Compiler.get_iterator()
+ blur_prob = self.blur_prob
+ kernel_size = self.kernel_size
+ weights = self.weights
+ apply_blur_c = Compiler.compile(apply_blur)
+
+        def blur(images, indices):
+            for i in my_range(images.shape[0]):
+                if np.random.rand() < blur_prob:
+                    k = np.random.randint(low=0, high=10)
+                    for ch in range(images.shape[-1]):
+                        images[i, ..., ch] = convolve2d(
+                            images[i, ..., ch],
+                            np.outer(weights[k], weights[k]),
+                            mode="same",
+                        )
+                    # Numba-friendly separable alternative (avoids scipy):
+                    # images[i] = apply_blur_c(images[i], kernel_size, weights[k])
+            return images
+
+ blur.is_parallel = True
+ blur.with_indices = True
+ return blur
+
+ def declare_state_and_memory(
+ self, previous_state: State
+ ) -> Tuple[State, Optional[AllocationQuery]]:
+ return (
+ replace(previous_state, jit_mode=False),
+ None,
+ )
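As a usage note: because `declare_state_and_memory` sets `jit_mode=False`, `GaussianBlur` runs in plain Python through scipy's `convolve2d` rather than through Numba. A minimal pipeline sketch (the `.beton` path and decoder choice are illustrative):

```python
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import SimpleRGBImageDecoder
from ffcv.transforms import ToTensor
from ffcv.transforms.gaussian_blur import GaussianBlur

loader = Loader(
    "train.beton",               # hypothetical dataset file
    batch_size=64,
    num_workers=8,
    order=OrderOption.RANDOM,
    pipelines={
        "image": [
            SimpleRGBImageDecoder(),
            GaussianBlur(kernel_size=5, sigma=(0.1, 2.0), p=0.5),
            ToTensor(),
        ],
    },
)
```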
diff --git a/ffcv/transforms/grayscale.py b/ffcv/transforms/grayscale.py
new file mode 100644
index 00000000..e98823a1
--- /dev/null
+++ b/ffcv/transforms/grayscale.py
@@ -0,0 +1,125 @@
+"""
+# Copied from https://github.com/facebookresearch/FFCV-SSL/blob/main/ffcv/transforms/grayscale.py
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+
+
+from typing import Callable, Optional, Tuple
+from ffcv.pipeline.allocation_query import AllocationQuery
+from ffcv.pipeline.operation import Operation
+from ffcv.pipeline.state import State
+from ffcv.pipeline.compiler import Compiler
+from dataclasses import replace
+import numpy as np
+import random
+
+
+class RandomGrayscale(Operation):
+ """Add Gaussian Blur with probability blur_prob.
+ Operates on raw arrays (not tensors).
+
+ Parameters
+ ----------
+ blur_prob : float
+ The probability with which to flip each image in the batch
+ horizontally.
+ """
+
+ def __init__(self, p: float = 0.2, seed: int = None):
+ super().__init__()
+ self.gray_prob = p
+ self.seed = seed
+
+ def generate_code(self) -> Callable:
+ my_range = Compiler.get_iterator()
+ gray_prob = self.gray_prob
+ seed = self.seed
+
+ if seed is None:
+
+ def grayscale(images, _):
+ for i in my_range(images.shape[0]):
+ if np.random.rand() > gray_prob:
+ continue
+ images[i] = (
+ 0.2989 * images[i, ..., 0:1]
+ + 0.5870 * images[i, ..., 1:2]
+ + 0.1140 * images[i, ..., 2:3]
+ )
+ return images
+
+ grayscale.is_parallel = True
+ return grayscale
+
+ def grayscale(images, _, counter):
+ random.seed(seed + counter)
+ values = np.zeros(images.shape[0])
+ for i in range(images.shape[0]):
+ values[i] = random.uniform(0, 1)
+ for i in my_range(images.shape[0]):
+ if values[i] > gray_prob:
+ continue
+ images[i] = (
+ 0.2989 * images[i, ..., 0:1]
+ + 0.5870 * images[i, ..., 1:2]
+ + 0.1140 * images[i, ..., 2:3]
+ )
+ return images
+
+ grayscale.with_counter = True
+ grayscale.is_parallel = True
+ return grayscale
+
+ def declare_state_and_memory(
+ self, previous_state: State
+ ) -> Tuple[State, Optional[AllocationQuery]]:
+ assert previous_state.jit_mode
+ return (previous_state, None)
+
+
+class LabelGrayscale(Operation):
+ """ColorJitter info added to the labels. Should be initialized in exactly the same way as
+ :cla:`ffcv.transforms.ColorJitter`.
+ """
+
+ def __init__(self, gray_prob: float = 0.2, seed: int = None):
+ super().__init__()
+ self.gray_prob = gray_prob
+ self.seed = np.random.RandomState(seed).randint(0, 2**32 - 1)
+
+ def generate_code(self) -> Callable:
+ my_range = Compiler.get_iterator()
+ gray_prob = self.gray_prob
+ seed = self.seed
+
+ def grayscale(labels, temp_array, indices):
+ rep = ""
+ for i in indices:
+ rep += str(i)
+ local_seed = (hash(rep) + seed) % 2**31
+ temp_array[:, :-1] = labels
+ for i in my_range(temp_array.shape[0]):
+ np.random.seed(local_seed + i)
+ if np.random.rand() < gray_prob:
+ temp_array[i, -1] = 0.0
+ else:
+ temp_array[i, -1] = 1.0
+ return temp_array
+
+ grayscale.is_parallel = True
+ grayscale.with_indices = True
+
+ return grayscale
+
+ def declare_state_and_memory(
+ self, previous_state: State
+ ) -> Tuple[State, Optional[AllocationQuery]]:
+ previous_shape = previous_state.shape
+ new_shape = (previous_shape[0] + 1,)
+ return (
+ replace(previous_state, shape=new_shape, dtype=np.float32),
+ AllocationQuery(new_shape, dtype=np.float32),
+ )
\ No newline at end of file
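`LabelGrayscale` widens the label vector by one float32 column carrying a per-sample flag (as written above, 0.0 when the grayscale event was drawn, 1.0 otherwise). Downstream code can split the flag back off the batch; a minimal sketch (the batch variable is illustrative):

```python
import numpy as np

batch_labels = np.zeros((64, 2), dtype=np.float32)  # hypothetical [B, K+1] batch
labels = batch_labels[:, :-1]     # the original label column(s)
gray_flag = batch_labels[:, -1]   # 0.0 = grayscale drawn, 1.0 = kept in color
```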
diff --git a/ffcv/transforms/solarization.py b/ffcv/transforms/solarization.py
new file mode 100644
index 00000000..25719536
--- /dev/null
+++ b/ffcv/transforms/solarization.py
@@ -0,0 +1,119 @@
+# Copied from https://github.com/facebookresearch/FFCV-SSL/blob/main/ffcv/transforms/solarization.py
+"""
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+
+
+from typing import Callable, Optional, Tuple
+from ffcv.pipeline.allocation_query import AllocationQuery
+from ffcv.pipeline.operation import Operation
+from ffcv.pipeline.state import State
+from ffcv.pipeline.compiler import Compiler
+import numpy as np
+from dataclasses import replace
+import random
+
+
+class RandomSolarization(Operation):
+ """Solarize the image randomly with a given probability by inverting all pixel
+ values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+ where ... means it can have an arbitrary number of leading dimensions.
+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
+ Parameters
+ ----------
+ solarization_prob (float): probability of the image being solarized. Default value is 0.5
+ threshold (float): all pixels equal or above this value are inverted.
+ """
+
+ def __init__(
+ self, threshold: float = 128, p: float = 0.5, seed: int = None
+ ):
+ super().__init__()
+ self.sol_prob = p
+ self.threshold = threshold
+ self.seed = seed
+
+ def generate_code(self) -> Callable:
+ my_range = Compiler.get_iterator()
+ sol_prob = self.sol_prob
+ threshold = self.threshold
+ seed = self.seed
+
+ if seed is None:
+
+ def solarize(images, _):
+ for i in my_range(images.shape[0]):
+ if np.random.rand() < sol_prob:
+ mask = images[i] >= threshold
+ images[i] = np.where(mask, 255 - images[i], images[i])
+ return images
+
+ solarize.is_parallel = True
+ return solarize
+
+ def solarize(images, _, counter):
+ random.seed(seed + counter)
+ values = np.zeros(len(images))
+ for i in range(len(images)):
+ values[i] = random.uniform(0, 1)
+ for i in my_range(images.shape[0]):
+ if values[i] < sol_prob:
+ mask = images[i] >= threshold
+ images[i] = np.where(mask, 255 - images[i], images[i])
+ return images
+
+ solarize.with_counter = True
+ solarize.is_parallel = True
+ return solarize
+
+ def declare_state_and_memory(
+ self, previous_state: State
+ ) -> Tuple[State, Optional[AllocationQuery]]:
+ return (replace(previous_state, jit_mode=True), None)
+
+
+class LabelSolarization(Operation):
+ """ColorJitter info added to the labels. Should be initialized in exactly the same way as
+ :cla:`ffcv.transforms.ColorJitter`.
+ """
+
+    def __init__(
+        self, solarization_prob: float = 0.5, threshold: float = 128, seed: int = None
+    ):
+        super().__init__()
+        self.solarization_prob = solarization_prob
+        self.threshold = threshold
+        # Draw a concrete integer seed so that `seed=None` still works
+        # (mirrors LabelGrayscale above).
+        self.seed = np.random.RandomState(seed).randint(0, 2**32 - 1)
+
+ def generate_code(self) -> Callable:
+ my_range = Compiler.get_iterator()
+ solarization_prob = self.solarization_prob
+ seed = self.seed
+
+        def solarize(labels, temp_array, indices):
+            temp_array[:, :-1] = labels
+            # `indices` is an array, and random.seed cannot take an array;
+            # seed from its first element so the draws are deterministic
+            # per batch.
+            random.seed(seed + indices[0])
+            for i in my_range(labels.shape[0]):
+                if random.uniform(0, 1) < solarization_prob:
+                    temp_array[i, -1] = 1
+                else:
+                    temp_array[i, -1] = 0
+            return temp_array
+
+ solarize.is_parallel = True
+ solarize.with_indices = True
+
+ return solarize
+
+ def declare_state_and_memory(
+ self, previous_state: State
+ ) -> Tuple[State, Optional[AllocationQuery]]:
+ previous_shape = previous_state.shape
+ new_shape = (previous_shape[0] + 1,)
+ return (
+ replace(previous_state, shape=new_shape, dtype=np.float32),
+ AllocationQuery(new_shape, dtype=np.float32),
+ )
\ No newline at end of file
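Solarization itself is a one-liner on uint8 arrays: every pixel at or above the threshold is inverted, exactly as in the `np.where` call above. A standalone NumPy sketch with a worked example:

```python
import numpy as np

def solarize(img: np.ndarray, threshold: int = 128) -> np.ndarray:
    # Invert only the pixels at or above the threshold.
    return np.where(img >= threshold, 255 - img, img)

img = np.array([[0, 127, 128, 255]], dtype=np.uint8)
print(solarize(img))  # [[  0 127 127   0]]
```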
diff --git a/ffcv/writer.py b/ffcv/writer.py
index 3615dd0a..20b2574c 100644
--- a/ffcv/writer.py
+++ b/ffcv/writer.py
@@ -118,6 +118,34 @@ def worker_job_indexed_dataset(input_queue, metadata_sm, metadata_type, fields,
allocations_queue.put(allocator.allocations)
+def worker_folder_dataset(input_queue, metadata_sm, metadata_type, fields,
+                          allocator, done_number, allocations_queue, dataset):
+
+ metadata = np.frombuffer(metadata_sm.buf, dtype=metadata_type)
+ field_names = metadata_type.names
+
+ # This `with` block ensures that all the pages allocated have been written
+ # onto the file
+ with allocator:
+ while True:
+ chunk = input_queue.get()
+
+ if chunk is None:
+ # No more work left to do
+ break
+
+ # For each sample in the chunk
+ for dest_ix, source_ix in chunk:
+ sample = dataset[source_ix]
+ handle_sample(sample, dest_ix, field_names, metadata, allocator, fields)
+
+ # We warn the main thread of our progress
+ with done_number.get_lock():
+ done_number.value += len(chunk)
+
+ allocations_queue.put(allocator.allocations)
+
+
class DatasetWriter():
"""Writes given dataset into FFCV format (.beton).
@@ -318,6 +346,23 @@ def from_webdataset(self, shards: List[str], pipeline: Callable):
todos = zip(shards, offsets)
self._write_common(total_len, todos, worker_job_webdataset, (pipeline, ))
+    def from_image_folder(self, root: str, chunksize=100):
+        """Write a dataset laid out as an image folder.
+        Images belonging to the same class must be in the same sub-folder
+        of `root`.
+
+        Parameters
+        ----------
+        root: str
+            Path to the folder containing the images.
+        chunksize: int
+            Size of the chunks of work distributed to the workers.
+        """
+        from torchvision.datasets.folder import (IMG_EXTENSIONS,
+                                                 find_classes, make_dataset)
+        classes, class_to_idx = find_classes(root)
+        samples = make_dataset(root, class_to_idx, extensions=IMG_EXTENSIONS)
+        total_len = len(samples)
+        # The workers consume chunks of (dest_ix, source_ix) pairs and
+        # index into the list of (path, class_index) samples.
+        todos = chunks(list(enumerate(range(total_len))), chunksize)
+        self._write_common(total_len, todos, worker_folder_dataset, (samples, ))
+
+
def finalize(self, allocations) :
# Writing metadata
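With this hook, an ImageFolder-style directory tree can be written to `.beton` directly. Note that the worker hands each field the raw `(path, class_index)` pair from `make_dataset`, so the image field used must be able to load from a file path. A minimal sketch (output path, root, and field settings are illustrative):

```python
from ffcv.writer import DatasetWriter
from ffcv.fields import IntField, RGBImageField

writer = DatasetWriter(
    "train.beton",                                # hypothetical output file
    {"image": RGBImageField(max_resolution=256), "label": IntField()},
)
writer.from_image_folder("/data/imagenet/train")  # hypothetical root
```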
diff --git a/libffcv/libffcv.cpp b/libffcv/libffcv.cpp
index db4798d1..93a5e52e 100644
--- a/libffcv/libffcv.cpp
+++ b/libffcv/libffcv.cpp
@@ -7,7 +7,15 @@
#include <cstdint>
#include <cstdio>
#include <turbojpeg.h>
+#include <iostream>   // std::cout, used by DBOUT below
#include <pthread.h>
+#include <cstring>    // memcpy for buffer copies
+
+#include <opencv2/core.hpp>       // cv::Mat
+#include <opencv2/imgcodecs.hpp>  // cv::imdecode
+#include <opencv2/imgproc.hpp>
+#include <opencv2/highgui.hpp>
+
#ifdef _WIN32
typedef unsigned __int32 __uint32_t;
typedef unsigned __int64 __uint64_t;
@@ -16,11 +24,42 @@
#define EXPORT
#endif
+// #define _DEBUG
+#ifdef _DEBUG
+#define DBOUT std::cout // or any other ostream
+#else
+#define DBOUT 0 && std::cout
+#endif
+
+#include <utility> // For std::pair
+
+
+// Clamp a crop coordinate into the image and align it down onto the
+// JPEG MCU block grid.
+int axis_to_image_boundaries(int a, int img_boundary, int mcuBlock) {
+    int img_b = img_boundary - (img_boundary % mcuBlock);
+    int delta_a = a % mcuBlock;
+    if (a > img_b) {
+        a = img_b;
+    } else {
+        a -= delta_a;
+    }
+    return a;
+}
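`axis_to_image_boundaries` snaps a crop coordinate onto the JPEG MCU grid (blocks of 8 or 16 pixels, depending on chroma subsampling) and clamps it inside the last full block of the image. The same arithmetic in Python, with a worked example:

```python
def axis_to_image_boundaries(a: int, img_boundary: int, mcu_block: int) -> int:
    img_b = img_boundary - (img_boundary % mcu_block)  # last full MCU boundary
    if a > img_b:
        return img_b                # clamp into the image
    return a - (a % mcu_block)      # align down onto the MCU grid

# 16-px MCUs on a 100-px axis: 37 aligns down to 32; 99 clamps to 96.
assert axis_to_image_boundaries(37, 100, 16) == 32
assert axis_to_image_boundaries(99, 100, 16) == 96
```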
+
+struct Boundaries {
+ int x;
+ int y;
+ int h;
+ int w;
+};
+
extern "C" {
// a key use to point to the tjtransform instance
static pthread_key_t key_tj_transformer;
// a key use to point to the tjdecompressor instance
static pthread_key_t key_tj_decompressor;
+ static pthread_key_t key_share_buffer;
static pthread_once_t key_once = PTHREAD_ONCE_INIT;
// will make the keys to access the tj instances
@@ -28,17 +67,39 @@ extern "C" {
{
pthread_key_create(&key_tj_decompressor, NULL);
pthread_key_create(&key_tj_transformer, NULL);
+ pthread_key_create(&key_share_buffer, NULL);
+ }
+
+ EXPORT int cv_imdecode(uint8_t* buf,
+ uint64_t buf_size,
+ int64_t flag,
+ uint8_t* output_buffer){
+ DBOUT << "imdecode called" << std::endl;
+ cv::Mat bufArray(1, buf_size, CV_8UC1, buf);
+ cv::Mat image;
+ image = cv::imdecode(bufArray, flag);
+ // Check for failure
+ if (image.empty()) {
+ std::cout << "Could not decode the image" << std::endl;
+ return -1;
+        } else {
+            DBOUT << "Image decoded " << image.rows << "," << image.cols << std::endl;