diff --git a/.gitignore b/.gitignore index 954f6dfb..c60a7810 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ src __pycache__/ *.py[cod] *$py.class - +outputs/ .vscode # C extensions diff --git a/.gitmodules b/.gitmodules index 21f138b4..e69de29b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "examples/imagenet-example"] - path = examples/imagenet-example - url = git@github.com:libffcv/ffcv-imagenet.git diff --git a/.nojekyll b/.nojekyll deleted file mode 100644 index e69de29b..00000000 diff --git a/README.md b/README.md index 74c403a3..1b6a4b5b 100644 --- a/README.md +++ b/README.md @@ -1,133 +1,23 @@ -

-Fast Forward Computer Vision: train models at a fraction of the cost with accelerated data loading! -

- +# Fast Forward Computer Vision for Pretraining +

[install] -[quickstart] -[features] +[new features] [docs] -[support slack] -[homepage] [paper] -
-Maintainers: -Guillaume Leclerc, -Andrew Ilyas and -Logan Engstrom

-`ffcv` is a drop-in data loading system that dramatically increases data throughput in model training: - -- [Train an ImageNet model](#prepackaged-computer-vision-benchmarks) -on one GPU in 35 minutes (98¢/model on AWS) -- [Train a CIFAR-10 model](https://docs.ffcv.io/ffcv_examples/cifar10.html) -on one GPU in 36 seconds (2¢/model on AWS) -- Train a `$YOUR_DATASET` model `$REALLY_FAST` (for `$WAY_LESS`) - -Keep your training algorithm the same, just replace the data loader! Look at these speedups: - - - -`ffcv` also comes prepacked with [fast, simple code](https://github.com/libffcv/imagenet-example) for [standard vision benchmarks]((https://docs.ffcv.io/benchmarks.html)): - - +This library is derived from [FFCV](https://github.com/libffcv/ffcv) to optimize the memory usage and accelerate data loading. ## Installation -### Linux +### Running Environment ``` -conda create -y -n ffcv python=3.9 cupy pkg-config libjpeg-turbo opencv pytorch torchvision cudatoolkit=11.3 numba -c pytorch -c conda-forge +conda create -y -n ffcv "python>=3.9" cupy pkg-config "libjpeg-turbo>=3.0.0" "opencv=4.10.0" numba -c conda-forge conda activate ffcv -pip install ffcv -``` -Troubleshooting note 1: if the above commands result in a package conflict error, try running ``conda config --env --set channel_priority flexible`` in the environment and rerunning the installation command. - -Troubleshooting note 2: on some systems (but rarely), you'll need to add the ``compilers`` package to the first command above. - -Troubleshooting note 3: courtesy of @kschuerholt, here is a [Dockerfile](https://github.com/kschuerholt/pytorch_cuda_opencv_ffcv_docker) that may help with conda-free installation - -### Windows -* Install opencv4 - * Add `..../opencv/build/x64/vc15/bin` to PATH environment variable -* Install libjpeg-turbo, download libjpeg-turbo-x.x.x-vc64.exe, not gcc64 - * Add `..../libjpeg-turbo64/bin` to PATH environment variable -* Install pthread, download last release.zip - * After unzip, rename Pre-build.2 folder to pthread - * Open `pthread/include/pthread.h`, and add the code below to the top of the file. - ```cpp - #define HAVE_STRUCT_TIMESPEC - ``` - * Add `..../pthread/dll` to PATH environment variable -* Install cupy depending on your CUDA Toolkit version. -* `pip install ffcv` - -## Citation -If you use FFCV, please cite it as: - -``` -@inproceedings{leclerc2023ffcv, - author = {Guillaume Leclerc and Andrew Ilyas and Logan Engstrom and Sung Min Park and Hadi Salman and Aleksander Madry}, - title = {{FFCV}: Accelerating Training by Removing Data Bottlenecks}, - year = {2023}, - booktitle = {Computer Vision and Pattern Recognition (CVPR)}, - note = {\url{https://github.com/libffcv/ffcv/}. commit xxxxxxx} -} -``` -(Make sure to replace xxxxxxx above with the hash of the commit used!) - -## Quickstart -Accelerate *any* learning system with `ffcv`. 
-
-First,
-convert your dataset into `ffcv` format (`ffcv` converts both indexed PyTorch datasets and
-WebDatasets):
-```python
-from ffcv.writer import DatasetWriter
-from ffcv.fields import RGBImageField, IntField
-
-# Your dataset (`torch.utils.data.Dataset`) of (image, label) pairs
-my_dataset = make_my_dataset()
-write_path = '/output/path/for/converted/ds.beton'
-
-# Pass a type for each data field
-writer = DatasetWriter(write_path, {
-    # Tune options to optimize dataset size, throughput at train-time
-    'image': RGBImageField(max_resolution=256),
-    'label': IntField()
-})
-
-# Write dataset
-writer.from_indexed_dataset(my_dataset)
+conda install pytorch-cuda=11.3 torchvision -c pytorch -c nvidia
+pip install .
 ```
-Then replace your old loader with the `ffcv` loader at train time (in PyTorch,
-no other changes required!):
-```python
-from ffcv.loader import Loader, OrderOption
-from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Cutout
-from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder
-
-# Random resized crop
-decoder = RandomResizedCropRGBImageDecoder((224, 224))
-
-# Data decoding and augmentation
-image_pipeline = [decoder, Cutout(), ToTensor(), ToTorchImage(), ToDevice(0)]
-label_pipeline = [IntDecoder(), ToTensor(), ToDevice(0)]
-
-# Pipeline for each data field
-pipelines = {
-  'image': image_pipeline,
-  'label': label_pipeline
-}
-
-# Replaces PyTorch data loader (`torch.utils.data.Dataloader`)
-loader = Loader(write_path, batch_size=bs, num_workers=num_workers,
-                order=OrderOption.RANDOM, pipelines=pipelines)
-
-# rest of training / validation proceeds identically
-for epoch in range(epochs):
-    ...
-```
-[See here](https://docs.ffcv.io/basics.html) for a more detailed guide to deploying `ffcv` for your dataset.
 
 ## Prepackaged Computer Vision Benchmarks
 From gridding to benchmarking to fast research iteration, there are many reasons
@@ -135,12 +25,27 @@
 to want faster model training. Below we present premade codebases for training
 on ImageNet and CIFAR, including both (a) extensible codebases and (b) numerous
 premade training configurations.
 
+## Make Dataset
+We provide a script, `examples/write_dataset.py`, to build the dataset; it supports five write modes:
+- `jpg`: compresses every image to JPEG.
+- `png`: compresses every image to PNG. This mode is lossless but slow to decode.
+- `raw`: stores the images uncompressed.
+- `smart`: JPEG-compresses only the images whose raw size exceeds `threshold` bytes.
+- `proportion`: JPEG-compresses a random subset of the images; the fraction is set by the `compress_probability` argument.
+
+```
+python examples/write_dataset.py --cfg.write_mode=smart --cfg.threshold=206432 --cfg.jpeg_quality=90 \
+    --cfg.num_workers=40 --cfg.max_resolution=500 \
+    --cfg.data_dir=$IMAGENET_DIR/train \
+    --cfg.write_path=$write_path
+```
 
 ### ImageNet
 We provide a self-contained script for training ImageNet fast.
 Above we plot the training time versus accuracy frontier, and the dataloading speeds, for
 1-GPU ResNet-18 and 8-GPU ResNet-50 alongside a few baselines.
 
+TODO:
 
 | Link to Config | top_1 | top_5 | # Epochs | Time (mins) | Architecture | Setup |
 |:---------------|------:|------:|---------:|------------:|:--------------|:------|
@@ -167,69 +72,36 @@
 potential to raise the accuracy even further). You can find the training script
 here.
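To consume a written `.ffcv` file, only the data loader needs to change. The following is a minimal loading sketch assembled from the calls used in `examples/benchmark.py` and `examples/vis_loader.py`; `write_path`, the batch size, and the pipelines are placeholders to adapt, and `cache_type` follows the cache-strategy list under Features below:

```python
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder
from ffcv.transforms import ToTensor, ToTorchImage, View

# write_path points at the .ffcv file produced by examples/write_dataset.py
loader = Loader(write_path,
                batch_size=512,
                num_workers=10,
                order=OrderOption.QUASI_RANDOM,
                cache_type=0,   # 0: OS cache, 1: process cache, 2: shared memory
                batches_ahead=4,
                drop_last=True,
                pipelines={
                    'image': [RandomResizedCropRGBImageDecoder((224, 224)),
                              ToTensor(), ToTorchImage()],
                    'label': [IntDecoder(), ToTensor(), View(-1)],
                })

for images, labels in loader:
    ...  # training / benchmarking proceeds as usual
```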
## Features - - -Computer vision or not, FFCV can help make training faster in a variety of -resource-constrained settings! -Our performance guide -has a more detailed account of the ways in which FFCV can adapt to different -performance bottlenecks. - - -- **Plug-and-play with any existing training code**: Rather than changing - aspects of model training itself, FFCV focuses on removing *data bottlenecks*, - which turn out to be a problem everywhere from neural network training to - linear regression. This means that: - - - FFCV can be introduced into any existing training code in just a few - lines of code (e.g., just swapping out the data loader and optionally the - augmentation pipeline); - - You don't have to change the model itself to make it faster (e.g., feel - free to analyze models *without* CutMix, Dropout, momentum scheduling, etc.); - - FFCV can speed up a lot more beyond just neural network training---in - fact, the more data-bottlenecked the application (e.g., linear regression, - bulk inference, etc.), the faster FFCV will make it! - - See our [Getting started](https://docs.ffcv.io/basics.html) guide, - [Example walkthroughs](https://docs.ffcv.io/examples.html), and - [Code examples](https://github.com/libffcv/ffcv/tree/main/examples) - to see how easy it is to get started! -- **Fast data processing without the pain**: FFCV automatically handles data - reading, pre-fetching, caching, and transfer between devices in an extremely - efficiently way, so that users don't have to think about it. -- **Automatically fused-and-compiled data processing**: By either using - [pre-written](https://docs.ffcv.io/api/transforms.html) FFCV transformations - or - [easily writing custom ones](https://docs.ffcv.io/ffcv_examples/custom_transforms.html), - users can - take advantage of FFCV's compilation and pipelining abilities, which will - automatically fuse and compile simple Python augmentations to machine code - using [Numba](https://numba.pydata.org), and schedule them asynchronously to avoid - loading delays. -- **Load data fast from RAM, SSD, or networked disk**: FFCV exposes - user-friendly options that can be adjusted based on the resources - available. For example, if a dataset fits into memory, FFCV can cache it - at the OS level and ensure that multiple concurrent processes all get fast - data access. Otherwise, FFCV can use fast process-level caching and will - optimize data loading to minimize the underlying number of disk reads. See - [The Bottleneck Doctor](https://docs.ffcv.io/bottleneck_doctor.html) - guide for more information. -- **Training multiple models per GPU**: Thanks to fully asynchronous - thread-based data loading, you can now interleave training multiple models on - the same GPU efficiently, without any data-loading overhead. See - [this guide](https://docs.ffcv.io/parameter_tuning.html) for more info. -- **Dedicated tools for image handling**: All the features above work are - equally applicable to all sorts of machine learning models, but FFCV also - offers some vision-specific features, such as fast JPEG encoding and decoding, - storing datasets as mixtures of raw and compressed images to trade off I/O - overhead and compute overhead, etc. See the - [Working with images](https://docs.ffcv.io/working_with_images.html) guide for - more information. 
-
-# Contributors
-
-- [Guillaume Leclerc](https://github.com/GuillaumeLeclerc)
-- [Logan Engstrom](http://loganengstrom.com/)
-- [Andrew Ilyas](http://andrewilyas.com/)
-- [Sam Park](http://sungminpark.com/)
-- [Hadi Salman](http://hadisalman.com/)
+
+Compared to the original FFCV, this library has the following new features:
+
+- **crop decode**: the RandomResizedCrop and CenterCrop decoders now decode only the cropped region, which saves memory and speeds up decoding.
+
+- **cache strategy**: the OS page cache can be swapped out under memory pressure, so the caching strategy is now selectable via the `FFCV_DEFAULT_CACHE_PROCESS` environment variable (or the `cache_type` argument of `Loader`):
+  - `0`: OS cache
+  - `1`: process cache
+  - `2`: shared memory
+
+- **lossless compression**: PNG is supported for lossless compression; pass `RGBImageField(write_mode='png')` to enable it.
+
+- **lower memory usage**: decode buffers are sized by the largest image area instead of (max height × max width), which shrinks the loader's memory footprint (see the tables below).
+
+Comparison of throughput (samples/s, higher is better):
+
+| img\_size    |     112 |     160 |     192 |     224 |     224 |     224 |     224 |     224 |    512 |
+|--------------|--------:|--------:|--------:|--------:|--------:|--------:|--------:|--------:|-------:|
+| batch\_size  |     512 |     512 |     512 |     128 |     256 |     512 |     512 |     512 |    512 |
+| num\_workers |      10 |      10 |      10 |      10 |      10 |       5 |      10 |      20 |     10 |
+| ours         | 23024.0 | 19396.5 | 16503.6 | 16536.1 | 16338.5 | 12369.7 | 14521.4 | 14854.6 | 4260.3 |
+| ffcv         | 16853.2 | 13906.3 | 13598.4 | 12192.7 | 11960.2 |  9112.7 | 12539.4 | 12601.8 | 3577.8 |
+
+Comparison of memory usage (GB, lower is better):
+
+| img\_size    |  112 |  160 |  192 |  224 |  224 |  224 |  224 |  224 |  512 |
+|--------------|-----:|-----:|-----:|-----:|-----:|-----:|-----:|-----:|-----:|
+| batch\_size  |  512 |  512 |  512 |  128 |  256 |  512 |  512 |  512 |  512 |
+| num\_workers |   10 |   10 |   10 |   10 |   10 |    5 |   10 |   20 |   10 |
+| ours         |  9.0 |  9.8 | 11.4 |  5.8 |  7.7 | 11.4 | 11.4 | 11.4 | 34.0 |
+| ffcv         | 13.4 | 14.8 | 17.7 |  7.6 | 11.0 | 17.7 | 17.7 | 17.7 | 56.6 |
+
diff --git a/examples/benchmark.py b/examples/benchmark.py
new file mode 100644
index 00000000..a8658945
--- /dev/null
+++ b/examples/benchmark.py
@@ -0,0 +1,263 @@
+import argparse
+import builtins
+import datetime
+import json
+import math
+import os
+import sys
+import time
+from pathlib import Path
+
+from ffcv.loader import Loader, OrderOption
+import gin
+import numpy as np
+import timm
+import torch.backends.cudnn as cudnn
+from PIL import Image  # a trick to solve loading lib problem
+from tqdm import tqdm
+
+assert timm.__version__ >= "0.6.12"  # version check
+from torchvision import datasets
+import ffcv
+
+from psutil import Process, net_io_counters
+import socket
+from os import getpid
+
+from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage, RandomHorizontalFlip, View, Convert
+from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder, CenterCropRGBImageDecoder
+
+import torch
+
+IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
+IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
+
+@gin.configurable
+def SimplePipeline(img_size=224, scale=(0.2, 1), ratio=(3.0/4.0, 4.0/3.0), device='cuda'):
+    device = torch.device(device)
+    image_pipeline = [
+        RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale, ratio=ratio),
+        RandomHorizontalFlip(),
+        NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32),
+        ToTensor(), ToTorchImage(),
+    ]
+    label_pipeline = [IntDecoder(), ToTensor(), ToDevice(device), View(-1)]
+    # Pipeline for each data field
+    pipelines = {
+        'image':
image_pipeline, + 'label': label_pipeline + } + return pipelines + + +def get_args_parser(): + parser = argparse.ArgumentParser('Data loading benchmark', add_help=False) + parser.add_argument('--batch_size', default=64, type=int, + help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)') + parser.add_argument('--epochs', default=5, type=int) + parser.add_argument('--img_size', default=224,type=int) + + # Dataset parameters + parser.add_argument('--data_set', default='ffcv') + parser.add_argument("--cache_type",type=int, default=0,) + parser.add_argument('--data_path', default=os.getenv("IMAGENET_DIR"), type=str, + help='dataset path') + + parser.add_argument('--output_dir', default=None, type=str, + help='path where to save, empty for no saving') + + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--seed', default=0, type=int) + + parser.add_argument('--num_workers', default=10, type=int) + parser.add_argument('--pin_mem', action='store_true', + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') + parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') + parser.set_defaults(pin_mem=True) + + # distributed training parameters + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--local-rank','--local_rank', default=-1, type=int) + parser.add_argument('--dist_on_itp', action='store_true') + parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + + return parser + + +class ramqdm(tqdm): + """tqdm progress bar that reports RAM usage with each update""" + _empty_desc = "using ? GB RAM; ? CPU ? 
IO" + _desc = "{:.2f} GB RAM; {:.2f} % CPU {:.2f} MB IO" + _GB = 10**9 + """""" + def __init__(self, *args, **kwargs): + """Override desc and get reference to current process""" + if "desc" in kwargs: + # prepend desc to the reporter mask: + self._empty_desc = kwargs["desc"] + " " + self._empty_desc + self._desc = kwargs["desc"] + " " + self._desc + del kwargs["desc"] + else: + # nothing to prepend, reporter mask is at start of sentence: + self._empty_desc = self._empty_desc.capitalize() + self._desc = self._desc.capitalize() + super().__init__(*args, desc=self._empty_desc, **kwargs) + self._process = Process(getpid()) + self.metrics = [] + """""" + def update(self, n=1): + """Calculate RAM usage and update progress bar""" + rss = self._process.memory_info().rss + ps = self._process.cpu_percent() + io_counters = self._process.io_counters().read_bytes + # net_io = net_io_counters().bytes_recv + # io_counters += net_io + + current_desc = self._desc.format(rss/self._GB, ps, io_counters/1e6) + self.set_description(current_desc) + self.metrics.append({'mem':rss/self._GB, 'cpu':ps, 'io':io_counters/1e6}) + super().update(n) + + def summary(self): + res = {} + for key in self.metrics[0].keys(): + res[key] = np.mean([i[key] for i in self.metrics]) + return res + +@gin.configurable(denylist=["args"]) +def build_dataset(args, transform_fn=SimplePipeline): + transform_train = transform_fn(img_size=args.img_size) + if args.data_set == 'IF': + # simple augmentation + dataset_train = datasets.ImageFolder(args.data_path, transform=transform_train) + elif args.data_set == 'cifar10': + dataset_train = datasets.CIFAR10(args.data_path, transform=transform_train) + elif args.data_set == 'ffcv': + order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM + dataset_train = Loader(args.data_path, pipelines=transform_train, + batch_size=args.batch_size, num_workers=args.num_workers, + batches_ahead=4, #cache_type=args.cache_type, + order=order, distributed=args.distributed,seed=args.seed,drop_last=True) + else: + raise ValueError("Wrong dataset: ", args.data_set) + return dataset_train + +def load_one_epoch(args,loader): + start = time.time() + l=ramqdm(loader,disable=args.rank>0) + + for x1,y in l: + x1.mean() + torch.cuda.synchronize() + + end = time.time() + + if args.rank ==0: + res = l.summary() + throughput=loader.reader.num_samples/(end-start) + res['throughput'] = throughput + return res + +import torch + +def main(args): + init_distributed_mode(args) + + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(', ', ',\n')) + + cudnn.benchmark = True + + # build dataset + dataset_train = build_dataset(args) + + num_tasks = args.world_size + global_rank = args.rank + if args.data_set != "ffcv": + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + print("Sampler_train = %s" % str(sampler_train)) + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + else: + data_loader_train = dataset_train + + for epoch in range(args.epochs): + res = load_one_epoch(args,data_loader_train) + if res: + throughput = res['throughput'] + print(f"Throughput: {throughput:.2f} samples/s for {args.data_path}.") + res.update(args.__dict__) + res['version'] = ffcv.__version__ + res['hostname'] = socket.gethostname() + res['epoch'] = epoch + 
if args.output_dir: + with open(os.path.join(args.output_dir,"data_loading.txt"),"a") as file: + file.write(json.dumps(res)+"\n") + + +def init_distributed_mode(args): + if hasattr(args,'dist_on_itp') and args.dist_on_itp: + args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) + os.environ['LOCAL_RANK'] = str(args.gpu) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), *args, **kwargs) # print with time stamp + + builtins.print = print + + + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + main(args) diff --git a/examples/docs_examples/linear_regression.py b/examples/docs_examples/linear_regression.py index f9a6e81c..5431448a 100644 --- a/examples/docs_examples/linear_regression.py +++ b/examples/docs_examples/linear_regression.py @@ -57,7 +57,7 @@ def __len__(self): train_loader = DataLoader(dataset, batch_size=2048, num_workers=8, shuffle=True) else: train_loader = Loader('/tmp/linreg_data.beton', batch_size=2048, - num_workers=8, order=OrderOption.QUASI_RANDOM, os_cache=False, + num_workers=8, order=OrderOption.QUASI_RANDOM, cache_type=1, pipelines={ 'covariate': [NDArrayDecoder(), ToTensor(), ToDevice(ch.device('cuda:0'))], 'label': [NDArrayDecoder(), ToTensor(), Squeeze(), ToDevice(ch.device('cuda:0'))] diff --git a/examples/imagenet-example b/examples/imagenet-example deleted file mode 160000 index f134cbff..00000000 --- a/examples/imagenet-example +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f134cbfff7f590954edc5c24275444b7dd2f57f6 diff --git a/examples/profiler.py b/examples/profiler.py new file mode 100644 index 00000000..c959baee --- /dev/null +++ b/examples/profiler.py @@ -0,0 +1,155 @@ +#%% + +import time +from PIL import Image# a trick to solve loading lib problem +from ffcv.fields.rgb_image import * +from ffcv.transforms import RandomHorizontalFlip, NormalizeImage, ToTensor, ToTorchImage, ToDevice +import numpy as np +import torchvision + +from ffcv import Loader +import ffcv +import argparse +from tqdm.auto import tqdm,trange +import torch.nn as nn +import torch 
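+# psutil is used by ramqdm below to sample this process's RAM, CPU, and disk I/O per batch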
+from psutil import Process, net_io_counters +import json +from os import getpid + +from ffcv.transforms.ops import Convert + +IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255 +IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255 + +class ramqdm(tqdm): + """tqdm progress bar that reports RAM usage with each update""" + _empty_desc = "using ? GB RAM; ? CPU ? IO" + _desc = "{:.2f} GB RAM; {:.2f} % CPU {:.2f} MB IO" + _GB = 10**9 + """""" + def __init__(self, *args, **kwargs): + """Override desc and get reference to current process""" + if "desc" in kwargs: + # prepend desc to the reporter mask: + self._empty_desc = kwargs["desc"] + " " + self._empty_desc + self._desc = kwargs["desc"] + " " + self._desc + del kwargs["desc"] + else: + # nothing to prepend, reporter mask is at start of sentence: + self._empty_desc = self._empty_desc.capitalize() + self._desc = self._desc.capitalize() + super().__init__(*args, desc=self._empty_desc, **kwargs) + self._process = Process(getpid()) + self.metrics = [] + """""" + def update(self, n=1): + """Calculate RAM usage and update progress bar""" + rss = self._process.memory_info().rss + ps = self._process.cpu_percent() + io_counters = self._process.io_counters().read_bytes + # net_io = net_io_counters().bytes_recv + # io_counters += net_io + + current_desc = self._desc.format(rss/self._GB, ps, io_counters/1e6) + self.set_description(current_desc) + self.metrics.append({'mem':rss/self._GB, 'cpu':ps, 'io':io_counters/1e6}) + super().update(n) + + def summary(self): + res = {} + for key in self.metrics[0].keys(): + res[key] = np.mean([i[key] for i in self.metrics]) + return res + + +def load_one_epoch(args,loader): + start = time.time() + l=ramqdm(loader) + + for x1,y in l: + pass + end = time.time() + res = l.summary() + try: + throughput=loader.reader.num_samples/(end-start) + except: + throughput=len(loader.dataset)/(end-start) + res['throughput'] = throughput + x1 = x1.float() + print("Mean: ", x1.mean().item(), "Std: ", x1.std().item()) + return res + +def main(args): + if args.no_ffcv: + tfms = torchvision.transforms.Compose([ + torchvision.transforms.RandomResizedCrop(args.img_size), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize(IMAGENET_MEAN/255, IMAGENET_STD/255), + ]) + dataset = torchvision.datasets.ImageFolder(args.data_path, transform=tfms) + loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True) + else: + pipe = { + 'image': [RandomResizedCropRGBImageDecoder((args.img_size,args.img_size)), + RandomHorizontalFlip(), + ToTensor(), + # ToDevice(torch.device('cuda')), + ToTorchImage(), + # NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16), + # Convert(torch.float16), + ] + } + loader = Loader(args.data_path, batch_size=args.batch_size, num_workers=args.num_workers, + pipelines=pipe, + batches_ahead=2, distributed=False,seed=0,drop_last=True) + + + # warmup + load_one_epoch(args,loader) + + for _ in range(args.repeat): + res = load_one_epoch(args,loader) + yield res + +#%% +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="FFCV Profiler") + parser.add_argument("-r", "--repeat", type=int, default=3, help="number of samples to record one step for profile.") + parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size") + parser.add_argument("-p", "--data_path", type=str, help="data path", required=True) + 
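+    # --no_ffcv switches the benchmark to a plain torchvision ImageFolder loader for comparison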
parser.add_argument("--no_ffcv",default=False,action="store_true") + parser.add_argument("--num_workers", type=int, default=10, help="number of workers") + parser.add_argument("--exp", default=False, action="store_true", help="run experiments") + parser.add_argument("--img_size", type=int, default=224, help="image size") + parser.add_argument("--write_path", type=str, help='path to write result',default=None) + args = parser.parse_args() + if args.exp == False: + for res in main(args): + throughput = res['throughput'] + print(f"Throughput: {throughput:.2f} samples/s for {args.data_path}.") + res.update(args.__dict__) + if args.write_path: + with open(args.write_path,"a") as file: + file.write(json.dumps(res)+"\n") + else: + data = [] + with open(args.write_path,"a") as file: + for num_workers in [10,20,40]: + for use_ffcv in [False,True]: + for bs in [128,256,512]: + args.num_workers=num_workers + args.batch_size = bs + args.use_ffcv=use_ffcv + row = args.__dict__ + for res in main(args): + row.update(res) + file.write(json.dumps(row)+"\n") + file.flush() + print(row) + data.append(row) + import pandas as pd + df = pd.DataFrame(data) + print(df) + exit(0) \ No newline at end of file diff --git a/examples/vis_loader.py b/examples/vis_loader.py new file mode 100644 index 00000000..770f51a0 --- /dev/null +++ b/examples/vis_loader.py @@ -0,0 +1,47 @@ +import argparse +import time +from PIL import Image # a trick to solve loading lib problem +from ffcv import Loader +from ffcv.transforms import * +from ffcv.fields.decoders import CenterCropRGBImageDecoder, RandomResizedCropRGBImageDecoder + + +import numpy as np + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='FFCV Profiler') + parser.add_argument('data_path', type=str, default='data/imagenet', help='Path to the dataset') + parser.add_argument('--batch_size', type=int, default=16, help='Batch size') + parser.add_argument('--write_path', type=str, default='viz.png', help='Path to write result') + args = parser.parse_args() + + loader = Loader(args.data_path, batch_size=args.batch_size, num_workers=10, cache_type=0, pipelines={ + 'image':[CenterCropRGBImageDecoder((224, 224),224/256), + ToTensor(), + ToTorchImage()] + }, batches_ahead=0,) + + print("num samples: ", loader.reader.num_samples, "fields: ", loader.reader.field_names) + for x,_ in loader: + x1 = x.float() + print("Mean: ", x1.mean().item(), "Std: ", x1.std().item()) + break + + print('Done') + num = int(np.sqrt(args.batch_size)) + import cv2 + + image = np.zeros((224*num, 224*num, 3), dtype=np.uint8) + for i in range(num): + for j in range(num): + if i*num+j >= args.batch_size: + break + img = x[i*num+j].numpy().transpose(1,2,0) + # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + image[i*224:(i+1)*224, j*224:(j+1)*224] = (img).astype(np.uint8) + + if args.write_path: + Image.fromarray(image).save(args.write_path) + + diff --git a/examples/write_dataset.py b/examples/write_dataset.py new file mode 100644 index 00000000..a9b38bef --- /dev/null +++ b/examples/write_dataset.py @@ -0,0 +1,106 @@ +"""example usage: +export IMAGENET_DIR=/path/to/pytorch/format/imagenet/directory/ +export WRITE_DIR=/your/path/here/ +write_dataset train 500 0.50 90 +write_path=$WRITE_DIR/train500_0.5_90.ffcv +echo "Writing ImageNet train dataset to ${write_path}" +python examples/write_dataset.py \ + --cfg.data_dir=$IMAGENET_DIR \ + --cfg.write_path=$write_path \ + --cfg.max_resolution=500 \ + --cfg.write_mode=smart \ + --cfg.compress_probability=0.50 \ + --cfg.jpeg_quality=90 +""" 
+from PIL import Image +from torch.utils.data import Subset +from ffcv.writer import DatasetWriter +from ffcv.fields import IntField, RGBImageField +import torchvision +from torchvision.datasets import ImageFolder +import torchvision.datasets as torch_datasets + +from argparse import ArgumentParser +from fastargs import Section, Param +from fastargs.validation import And, OneOf +from fastargs.decorators import param, section +from fastargs import get_current_config +import cv2 +import numpy as np + +# hack resizer +# def resizer(image, target_resolution): +# if target_resolution is None: +# return image +# original_size = np.array([image.shape[1], image.shape[0]]) +# ratio = target_resolution / original_size.min() +# if ratio < 1: +# new_size = (ratio * original_size).astype(int) +# image = cv2.resize(image, tuple(new_size), interpolation=cv2.INTER_AREA) +# return image +# from ffcv.fields import rgb_image +# rgb_image.resizer = resizer + +Section('cfg', 'arguments to give the writer').params( + dataset=Param(And(str, OneOf(['cifar', 'imagenet'])), 'Which dataset to write', default='imagenet'), + data_dir=Param(str, 'Where to find the PyTorch dataset', required=True), + write_path=Param(str, 'Where to write the new dataset', required=True), + write_mode=Param(str, 'Mode: raw, smart or jpg', required=False, default='smart'), + max_resolution=Param(int, 'Max image side length. 0 any size.', required=False,default=0), + num_workers=Param(int, 'Number of workers to use', default=16), + chunk_size=Param(int, 'Chunk size for writing', default=100), + jpeg_quality=Param(float, 'Quality of jpeg images', default=90), + subset=Param(int, 'How many images to use (-1 for all)', default=-1), + compress_probability=Param(float, 'compress probability', default=0.5), + threshold=Param(int, 'threshold for smart mode to compress by jpeg', default=286432), +) + +@section('cfg') +@param('dataset') +@param('data_dir') +@param('write_path') +@param('max_resolution') +@param('num_workers') +@param('chunk_size') +@param('subset') +@param('jpeg_quality') +@param('write_mode') +@param('compress_probability') +@param('threshold') +def main(dataset, data_dir, write_path, max_resolution, num_workers, + chunk_size, subset, jpeg_quality, write_mode, + compress_probability, threshold): + if dataset == 'imagenet': + my_dataset = ImageFolder(root=data_dir) + elif dataset == 'cifar': + tfms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()]) + my_dataset = torch_datasets.CIFAR10(root=data_dir, train=True, download=True) + else: + raise ValueError('Unknown dataset') + + + if subset > 0: my_dataset = Subset(my_dataset, range(subset)) + writer = DatasetWriter(write_path, { + 'image': RGBImageField(write_mode=write_mode, + max_resolution=None if max_resolution==0 else max_resolution, + compress_probability=compress_probability, + jpeg_quality=jpeg_quality, + smart_threshold=threshold), + 'label': IntField(), + }, num_workers=num_workers) + + writer.from_indexed_dataset(my_dataset, chunksize=chunk_size,shuffle_indices=False) + +if __name__ == '__main__': + config = get_current_config() + parser = ArgumentParser() + config.augment_argparse(parser) + config.collect_argparse_args(parser) + config.validate(mode='stderr') + config.summary() + + args=config.get().cfg + assert args.write_path.endswith('.ffcv'), 'write_path must end with .ffcv' + file=open(args.write_path.replace(".ffcv",".meta"), 'w') + file.write(str(args.__dict__)) + main() diff --git a/ffcv-conda.yml b/ffcv-conda.yml index f332ea51..f3f39ca9 100644 
--- a/ffcv-conda.yml +++ b/ffcv-conda.yml @@ -1,4 +1,4 @@ -name: ffcv19 +name: ffcv channels: - pytorch - defaults diff --git a/ffcv/.DS_Store b/ffcv/.DS_Store deleted file mode 100644 index f1fad9b3..00000000 Binary files a/ffcv/.DS_Store and /dev/null differ diff --git a/ffcv/benchmarks/decorator.py b/ffcv/benchmarks/decorator.py index 4e8d75a7..9651eab3 100644 --- a/ffcv/benchmarks/decorator.py +++ b/ffcv/benchmarks/decorator.py @@ -1,3 +1,4 @@ +import tracemalloc from itertools import product from time import time from collections import defaultdict @@ -46,6 +47,8 @@ def run_all(runs=3, warm_up=1, pattern='*'): for args in it_args: # with redirect_stderr(FakeSink()): + # Start tracing memory allocations + tracemalloc.start() if True: benchmark: Benchmark = cls(**args) with benchmark: @@ -57,7 +60,9 @@ def run_all(runs=3, warm_up=1, pattern='*'): start = time() benchmark.run() timings.append(time() - start) - + # Stop tracing memory allocations + current, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() median_time = np.median(timings) throughput = None @@ -66,16 +71,13 @@ def run_all(runs=3, warm_up=1, pattern='*'): throughput = args['n'] / median_time unit = 'it/sec' - if throughput < 1: - unit = 'sec/it' - throughput = 1 /throughput - - throughput = np.round(throughput * 10) / 10 results[suite_name].append({ **args, 'time': median_time, - 'throughput': str(throughput) + ' ' + unit + f'throughput ({unit})': f"{throughput:.2f}", + 'current_memory (MB)': current / 10**6, + 'peak_memory (MB)': peak / 10**6, }) it_args.close() it_suite.close() diff --git a/ffcv/benchmarks/suites/image_read.py b/ffcv/benchmarks/suites/image_read.py index 89f09f46..7ce54d57 100644 --- a/ffcv/benchmarks/suites/image_read.py +++ b/ffcv/benchmarks/suites/image_read.py @@ -42,18 +42,21 @@ def __getitem__(self, index): 'length': [3000], 'mode': [ 'raw', - 'jpg' + 'jpg', + 'png', ], 'num_workers': [ 1, 8, - 16 + 16, + 32, ], 'batch_size': [ 500 ], 'size': [ (32, 32), # CIFAR + (224,224), (300, 500), # ImageNet ], 'compile': [ @@ -83,13 +86,12 @@ def __enter__(self): self.handle.__enter__() name = self.handle.name - writer = DatasetWriter(self.length, name, { + writer = DatasetWriter(name, { 'index': IntField(), 'value': RGBImageField(write_mode=self.mode) }) - with writer: - writer.write_pytorch_dataset(self.dataset, num_workers=-1, chunksize=100) + writer.from_indexed_dataset(self.dataset, chunksize=100) reader = Reader(name) manager = OSCacheManager(reader) diff --git a/ffcv/benchmarks/suites/jpeg_decode.py b/ffcv/benchmarks/suites/jpeg_decode.py index 31fc7860..51c74449 100644 --- a/ffcv/benchmarks/suites/jpeg_decode.py +++ b/ffcv/benchmarks/suites/jpeg_decode.py @@ -14,9 +14,9 @@ @benchmark({ 'n': [500], 'source_image': ['../../../test_data/pig.png'], - 'image_width': [500, 256, 1024], - 'quality': [50, 90], - 'compile': [True] + 'image_width': [224, 500, 1024], + 'quality': [50, 80, 90, 95], + 'compile': [True], }) class JPEGDecodeBenchmark(Benchmark): diff --git a/ffcv/benchmarks/suites/memory_read.py b/ffcv/benchmarks/suites/memory_read.py index e6072516..2666fa34 100644 --- a/ffcv/benchmarks/suites/memory_read.py +++ b/ffcv/benchmarks/suites/memory_read.py @@ -59,13 +59,12 @@ def __enter__(self): handle = self.handle.__enter__() name = handle.name dataset = DummyDataset(self.num_samples, self.size_bytes) - writer = DatasetWriter(self.num_samples, name, { + writer = DatasetWriter(name, { 'index': IntField(), 'value': BytesField() - }) + }, num_workers=-1) - with writer: - 
writer.write_pytorch_dataset(dataset, num_workers=-1, chunksize=100) + writer.from_indexed_dataset(dataset, chunksize=100) reader = Reader(name) manager = OSCacheManager(reader) diff --git a/ffcv/fields/rgb_image.py b/ffcv/fields/rgb_image.py index b6420f11..b90dbde8 100644 --- a/ffcv/fields/rgb_image.py +++ b/ffcv/fields/rgb_image.py @@ -12,7 +12,7 @@ from ..pipeline.state import State from ..pipeline.compiler import Compiler from ..pipeline.allocation_query import AllocationQuery -from ..libffcv import imdecode, memcpy, resize_crop +from ..libffcv import * if TYPE_CHECKING: from ..memory_managers.base import MemoryManager @@ -21,6 +21,7 @@ IMAGE_MODES = Dict() IMAGE_MODES['jpg'] = 0 IMAGE_MODES['raw'] = 1 +IMAGE_MODES['png'] = 2 def encode_jpeg(numpy_image, quality): @@ -33,6 +34,11 @@ def encode_jpeg(numpy_image, quality): return result.reshape(-1) +def encode_png(numpy_image): + # x=cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR) + result = cv2.imencode('.png', numpy_image)[1] + result = np.frombuffer(result, np.uint8) + return result.reshape(-1) def resizer(image, target_resolution): if target_resolution is None: @@ -41,7 +47,7 @@ def resizer(image, target_resolution): ratio = target_resolution / original_size.max() if ratio < 1: new_size = (ratio * original_size).astype(int) - image = cv2.resize(image, tuple(new_size), interpolation=cv2.INTER_AREA) + image = cv2.resize(image, tuple(new_size), interpolation=cv2.INTER_CUBIC) return image @@ -101,22 +107,24 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca consider RandomResizedCropRGBImageDecoder or CenterCropRGBImageDecoder instead.""" raise TypeError(msg) - - biggest_shape = (max_height, max_width, 3) + + max_shape = ((np.uint64(widths)*np.uint64(heights)*3).max(),) my_dtype = np.dtype(' Callable: mem_read = self.memory_read imdecode_c = Compiler.compile(imdecode) + cv_imdecode_c = Compiler.compile(cv_imdecode) jpg = IMAGE_MODES['jpg'] raw = IMAGE_MODES['raw'] + png = IMAGE_MODES['png'] my_range = Compiler.get_iterator() my_memcpy = Compiler.compile(memcpy) @@ -130,8 +138,12 @@ def decode(batch_indices, destination, metadata, storage_state): if field['mode'] == jpg: imdecode_c(image_data, destination[dst_ix], height, width, height, width, 0, 0, 1, 1, False, False) - else: + elif field['mode'] == raw: my_memcpy(image_data, destination[dst_ix]) + elif field['mode'] == png: + cv_imdecode_c(image_data, destination[dst_ix]) + else: + pass return destination[:len(batch_indices)] @@ -144,9 +156,14 @@ class ResizedCropRGBImageDecoder(SimpleRGBImageDecoder, metaclass=ABCMeta): It supports both variable and constant resolution datasets. 
""" - def __init__(self, output_size): + def __init__(self, output_size,interpolation): super().__init__() self.output_size = output_size + self.interpolation = interpolation + self.use_crop_decode = True + + def use_crop_decode_(self, value): + self.use_crop_decode = value def declare_state_and_memory(self, previous_state: State) -> Tuple[State, AllocationQuery]: widths = self.metadata['width'] @@ -156,27 +173,36 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca self.max_height = np.uint64(heights.max()) output_shape = (self.output_size[0], self.output_size[1], 3) my_dtype = np.dtype(' Callable: jpg = IMAGE_MODES['jpg'] + raw = IMAGE_MODES['raw'] + png = IMAGE_MODES['png'] mem_read = self.memory_read my_range = Compiler.get_iterator() imdecode_c = Compiler.compile(imdecode) + cv_imdecode_c = Compiler.compile(cv_imdecode) resize_crop_c = Compiler.compile(resize_crop) + imcropresizedecode_c = Compiler.compile(imcropresizedecode) get_crop_c = Compiler.compile(self.get_crop_generator) scale = self.scale ratio = self.ratio + use_crop_decode = self.use_crop_decode + interpolation = self.interpolation if isinstance(scale, tuple): scale = np.array(scale) if isinstance(ratio, tuple): @@ -191,22 +217,37 @@ def decode(batch_indices, my_storage, metadata, storage_state): height = np.uint32(field['height']) width = np.uint32(field['width']) + i, j, h, w = get_crop_c(height, width, scale, ratio) + if field['mode'] == jpg: temp_buffer = temp_storage[dst_ix] - imdecode_c(image_data, temp_buffer, - height, width, height, width, 0, 0, 1, 1, False, False) - selected_size = 3 * height * width - temp_buffer = temp_buffer.reshape(-1)[:selected_size] - temp_buffer = temp_buffer.reshape(height, width, 3) - - else: - temp_buffer = image_data.reshape(height, width, 3) - - i, j, h, w = get_crop_c(height, width, scale, ratio) - - resize_crop_c(temp_buffer, i, i + h, j, j + w, + if use_crop_decode: + imcropresizedecode_c(image_data, destination[dst_ix], + h,w, + i, j, interpolation) + else: + ## decode the whole image + imdecode_c(image_data, temp_buffer, + height, width, height, width, 0, 0, 1, 1, False, False) + ## crop and resize the image + selected_size = 3 * height * width + temp_buffer = temp_buffer.reshape(-1)[:selected_size] + temp_buffer = temp_buffer.reshape(height, width, 3) + resize_crop_c(temp_buffer, i, i + h, j, j + w, destination[dst_ix]) - + elif field['mode'] == raw: + temp_buffer = image_data.reshape(height, width, 3) + resize_crop_c(temp_buffer, i, i + h, j, j + w, + destination[dst_ix]) + elif field['mode'] == png: + temp_buffer = temp_storage[dst_ix] + cv_imdecode_c(image_data, temp_buffer) + buffer = temp_buffer[:height*width*3].reshape(height,width,3) + resize_crop_c(buffer, i, i + h, j, j + w, + destination[dst_ix]) + else: + pass + return destination[:len(batch_indices)] decode.is_parallel = True return decode @@ -231,8 +272,8 @@ class RandomResizedCropRGBImageDecoder(ResizedCropRGBImageDecoder): ratio : Tuple[float] The range of potential aspect ratios that can be randomly sampled """ - def __init__(self, output_size, scale=(0.08, 1.0), ratio=(0.75, 4/3)): - super().__init__(output_size) + def __init__(self, output_size, scale=(0.08, 1.0), ratio=(0.75, 4/3), interpolation=cv2.INTER_CUBIC): + super().__init__(output_size, interpolation=interpolation) self.scale = scale self.ratio = ratio self.output_size = output_size @@ -255,8 +296,8 @@ class CenterCropRGBImageDecoder(ResizedCropRGBImageDecoder): ratio of (crop size) / (min side length) """ # output size: 
resize crop size -> output size
-    def __init__(self, output_size, ratio):
-        super().__init__(output_size)
+    def __init__(self, output_size, ratio, interpolation=cv2.INTER_AREA):
+        super().__init__(output_size, interpolation=interpolation)
         self.scale = None
         self.ratio = ratio
@@ -311,10 +352,10 @@ def get_decoder_class(self) -> Type[Operation]:
         return SimpleRGBImageDecoder
 
     @staticmethod
-    def from_binary(binary: ARG_TYPE) -> Field:
+    def from_binary(binary: ARG_TYPE) -> Field:  # type: ignore
         return RGBImageField()
 
-    def to_binary(self) -> ARG_TYPE:
+    def to_binary(self) -> ARG_TYPE:  # type: ignore
         return np.zeros(1, dtype=ARG_TYPE)[0]
 
     def encode(self, destination, image, malloc):
@@ -335,10 +376,9 @@ def encode(self, destination, image, malloc):
         image = resizer(image, self.max_resolution)
 
         write_mode = self.write_mode
-        as_jpg = None
+        ccode = None  # compressed code
 
         if write_mode == 'smart':
-            as_jpg = encode_jpeg(image, self.jpeg_quality)
             write_mode = 'raw'
             if self.smart_threshold is not None:
                 if image.nbytes > self.smart_threshold:
@@ -353,13 +393,16 @@ def encode(self, destination, image, malloc):
         destination['height'], destination['width'] = image.shape[:2]
 
         if write_mode == 'jpg':
-            if as_jpg is None:
-                as_jpg = encode_jpeg(image, self.jpeg_quality)
-            destination['data_ptr'], storage = malloc(as_jpg.nbytes)
-            storage[:] = as_jpg
+            ccode = encode_jpeg(image, self.jpeg_quality)
+            destination['data_ptr'], storage = malloc(ccode.nbytes)
+            storage[:] = ccode
         elif write_mode == 'raw':
             image_bytes = np.ascontiguousarray(image).view('<u1')
+        if self._fd >= 0:
+            os.close(self._fd)
+            self._fd = -1
+
+class SharedMemoryContext(MemoryContext):
+    def __init__(self, manager: MemoryManager):
+        self.manager = manager
+        file_name = self.manager.reader.file_name
+        name = file_name.split('/')[-1]
+
+        mmap = np.memmap(file_name, 'uint8', mode='r')
+        size = len(mmap)
+
+        rank = dist.get_rank() if dist.is_initialized() else 0
+        print_args = {'force': True} if dist.is_initialized() else {}
+        file = os.path.join('/dev/shm', name)
+
+        if rank == 0:
+            create = not (os.path.exists(file) and filecmp.cmp(file, file_name))
+            self.mem = MasterSharedMemory(name=name, create=create, size=size)
+            print(f"[rank {rank}] copying file to shared memory", **print_args)
+            shared_mmap = np.frombuffer(self.mem.buf, dtype=np.uint8)
+            shared_mmap[:] = mmap[:]
+        if dist.is_initialized():
+            dist.barrier()
+        if rank != 0:
+            self.mem = MasterSharedMemory(name=name, create=False, size=size)
+            shared_mmap = np.frombuffer(self.mem.buf, dtype=np.uint8)
+            print(f"[rank {rank}] opening shared memory", **print_args)
+        self.mmap = shared_mmap
+
+    @property
+    def state(self):
+        return (self.mmap, self.manager.ptrs, self.manager.sizes)
+
+    def __enter__(self):
+        res = super().__enter__()
+        return res
+
+    def __exit__(self, __exc_type, __exc_value, __traceback):
+        # Numpy doesn't have an API to close memory maps yet.
+        # The only thing one can do is flush, but since we are not
+        # writing to the map that is pointless; moreover we want to
+        # avoid reopening the memmap over and over anyway.
+        return super().__exit__(__exc_type, __exc_value, __traceback)
+
+
+class SharedMemoryManager(MemoryManager):
+
+    def __init__(self, reader: 'Reader'):
+        super().__init__(reader)
+        self.context = SharedMemoryContext(self)
+
+    def schedule_epoch(self, schedule):
+        return self.context
+
+    @property
+    def state_type(self):
+        t1 = nb.uint8[::1]
+        t1.mutable = False
+        t2 = nb.uint64[::1]
+        t2.mutable = False
+        return nb.types.Tuple([t1, t2, t2])
+
+    def compile_reader(self):
+        def read(address, mem_state):
+            mmap, ptrs, sizes = mem_state
+            size = sizes[np.searchsorted(ptrs, address)]
+            ref_data = mmap[address:address + size]
+            return ref_data
+
+        return Compiler.compile(read, nb.uint8[::1](nb.uint64, self.state_type))
+
diff --git a/ffcv/transforms/color_jitter.py b/ffcv/transforms/color_jitter.py
index a79b72fd..9270e805 100644
--- a/ffcv/transforms/color_jitter.py
+++ b/ffcv/transforms/color_jitter.py
@@ -3,8 +3,13 @@
 Reference : https://github.com/pytorch/vision/blob/main/torchvision/transforms/functional_tensor.py
 '''
+from typing import Callable, Optional, Tuple
+import numbers
+import math
+import random
 import numpy as np
-
+from numba import njit
+import numba as nb
 from dataclasses import replace
 from ..pipeline.allocation_query import AllocationQuery
 from ..pipeline.operation import Operation
@@ -137,3 +142,217 @@ def blend(img1, img2, ratio): return (ratio*img1 + (1-ratio)*img2).clip(0, 255).
     def declare_state_and_memory(self, previous_state):
         return (replace(previous_state, jit_mode=True), AllocationQuery(previous_state.shape, previous_state.dtype))
+
+# copy from https://github.com/facebookresearch/FFCV-SSL/blob/main/ffcv/transforms/colorjitter.py
+@njit(parallel=False, fastmath=True, inline="always")
+def apply_cj(
+    im,
+    apply_bri,
+    bri_ratio,
+    apply_cont,
+    cont_ratio: np.float32,
+    apply_sat,
+    sat_ratio: np.float32,
+    apply_hue,
+    hue_factor: np.float32,
+):
+    gray = (
+        np.float32(0.2989) * im[..., 0]
+        + np.float32(0.5870) * im[..., 1]
+        + np.float32(0.1140) * im[..., 2]
+    )
+    one = np.float32(1)
+    # Brightness
+    if apply_bri:
+        im = im * bri_ratio
+
+    # Contrast
+    if apply_cont:
+        im = cont_ratio * im + (one - cont_ratio) * np.float32(gray.mean())
+
+    # Saturation
+    if apply_sat:
+        im[..., 0] = sat_ratio * im[..., 0] + (one - sat_ratio) * gray
+        im[..., 1] = sat_ratio * im[..., 1] + (one - sat_ratio) * gray
+        im[..., 2] = sat_ratio * im[..., 2] + (one - sat_ratio) * gray
+
+    # Hue
+    if apply_hue:
+        hue_factor_radians = hue_factor * 2.0 * np.pi
+        cosA = np.cos(hue_factor_radians)
+        sinA = np.sin(hue_factor_radians)
+        v1, v2, v3 = 1.0 / 3.0, np.sqrt(1.0 / 3.0), (1.0 - cosA)
+        hue_matrix = [
+            [
+                cosA + v3 / 3.0,
+                v1 * v3 - v2 * sinA,
+                v1 * v3 + v2 * sinA,
+            ],
+            [
+                v1 * v3 + v2 * sinA,
+                cosA + v1 * v3,
+                v1 * v3 - v2 * sinA,
+            ],
+            [
+                v1 * v3 - v2 * sinA,
+                v1 * v3 + v2 * sinA,
+                cosA + v1 * v3,
+            ],
+        ]
+        hue_matrix = np.array(hue_matrix, dtype=np.float32).T
+        for row in nb.prange(im.shape[0]):
+            im[row] = im[row] @ hue_matrix
+    return np.clip(im, 0, 255).astype(np.uint8)
+
+
+class RandomColorJitter(Operation):
+    """Add ColorJitter with probability jitter_prob.
+    Operates on raw arrays (not tensors).
+
+    see https://github.com/pytorch/vision/blob/28557e0cfe9113a5285330542264f03e4ba74535/torchvision/transforms/functional_tensor.py#L165
+    and https://sanje2v.wordpress.com/2021/01/11/accelerating-data-transforms/
+    Parameters
+    ----------
+    jitter_prob : float, The probability with which to apply ColorJitter.
+ brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + """ + + def __init__( + self, + brightness=0.8, + contrast=0.4, + saturation=0.4, + hue=0.2, + p=0.5, + seed=None, + ): + super().__init__() + self.jitter_prob = p + + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5)) + self.seed = seed + assert self.jitter_prob >= 0 and self.jitter_prob <= 1 + + def _check_input( + self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True + ): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError( + f"If {name} is a single number, it must be non negative." + ) + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"{name} values should be between {bound}") + else: + raise TypeError( + f"{name} should be a single number or a list/tuple with length 2." + ) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) 
for hue, do nothing + if value[0] == value[1] == center: + setattr(self, f"apply_{name}", False) + else: + setattr(self, f"apply_{name}", True) + return tuple(value) + + def generate_code(self) -> Callable: + my_range = Compiler.get_iterator() + + jitter_prob = self.jitter_prob + + apply_bri = self.apply_brightness + bri = self.brightness + + apply_cont = self.apply_contrast + cont = self.contrast + + apply_sat = self.apply_saturation + sat = self.saturation + + apply_hue = self.apply_hue + hue = self.hue + + seed = self.seed + if seed is None: + + def color_jitter(images, _): + for i in my_range(images.shape[0]): + if np.random.rand() > jitter_prob: + continue + + images[i] = apply_cj( + images[i].astype("float32"), + apply_bri, + np.float32(np.random.uniform(bri[0], bri[1])), + apply_cont, + np.float32(np.random.uniform(cont[0], cont[1])), + apply_sat, + np.float32(np.random.uniform(sat[0], sat[1])), + apply_hue, + np.float32(np.random.uniform(hue[0], hue[1])), + ) + return images + + color_jitter.is_parallel = True + return color_jitter + + def color_jitter(images, _, counter): + + random.seed(seed + counter) + N = images.shape[0] + values = np.zeros(N) + bris = np.zeros(N) + conts = np.zeros(N) + sats = np.zeros(N) + hues = np.zeros(N) + for i in range(N): + values[i] = np.float32(random.uniform(0, 1)) + bris[i] = np.float32(random.uniform(bri[0], bri[1])) + conts[i] = np.float32(random.uniform(cont[0], cont[1])) + sats[i] = np.float32(random.uniform(sat[0], sat[1])) + hues[i] = np.float32(random.uniform(hue[0], hue[1])) + for i in my_range(N): + if values[i] > jitter_prob: + continue + images[i] = apply_cj( + images[i].astype("float32"), + apply_bri, + bris[i], + apply_cont, + conts[i], + apply_sat, + sats[i], + apply_hue, + hues[i], + ) + return images + + color_jitter.is_parallel = True + color_jitter.with_counter = True + return color_jitter + + def declare_state_and_memory( + self, previous_state: State + ) -> Tuple[State, Optional[AllocationQuery]]: + return (replace(previous_state, jit_mode=True), None) diff --git a/ffcv/transforms/gaussian_blur.py b/ffcv/transforms/gaussian_blur.py new file mode 100644 index 00000000..9aeaf30e --- /dev/null +++ b/ffcv/transforms/gaussian_blur.py @@ -0,0 +1,91 @@ +# copy from https://github.com/facebookresearch/FFCV-SSL/blob/main/ffcv/transforms/gaussian_blur.py +import numpy as np +from typing import Callable, Optional, Tuple +from dataclasses import replace +from ffcv.pipeline.allocation_query import AllocationQuery +from ffcv.pipeline.operation import Operation +from ffcv.pipeline.state import State +from ffcv.pipeline.compiler import Compiler +from scipy.signal import convolve2d + + +def apply_blur(img, kernel_size, w): + pad = (kernel_size - 1) // 2 + H, W, _ = img.shape + tmp = np.zeros(img.shape, dtype=np.float32) + for k in range(kernel_size): + start = max(0, pad - k) + stop = min(W, pad - k + W) + window = (img[:, start:stop] / 255) * w[k] + tmp[:, np.abs(stop - W) : W - start] += window + tmp2 = tmp + 0.0 + for k in range(kernel_size): + start = max(0, pad - k) + stop = min(H, pad - k + H) + window = (tmp[start:stop] * w[k]).astype(np.uint8) + tmp2[np.abs(stop - H) : H - start] += window + return np.clip(tmp2 * 255.0, 0, 255).astype(np.uint8) + + +class GaussianBlur(Operation): + """Blurs image with randomly chosen Gaussian blur. + If the image is torch Tensor, it is expected + to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions. 
+
+    Args:
+        p (float): probability to apply blurring to each input
+        kernel_size (int or sequence): Size of the Gaussian kernel.
+        sigma (float or tuple of float (min, max)): Standard deviation to be used for
+            creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
+            of float (min, max), sigma is chosen uniformly at random to lie in the
+            given range.
+    """
+
+    def __init__(self, kernel_size=5, sigma=(0.1, 2.0), p=0.5):
+        super().__init__()
+        self.blur_prob = p
+        self.kernel_size = kernel_size
+        assert sigma[1] > sigma[0]
+        self.sigmas = np.linspace(sigma[0], sigma[1], 10)
+        from scipy import signal
+
+        self.weights = np.stack(
+            [
+                signal.gaussian(kernel_size, s)
+                for s in np.linspace(sigma[0], sigma[1], 10)
+            ]
+        )
+        self.weights /= self.weights.sum(1, keepdims=True)
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+        blur_prob = self.blur_prob
+        kernel_size = self.kernel_size
+        weights = self.weights
+        apply_blur_c = Compiler.compile(apply_blur)
+
+        def blur(images, indices):
+            for i in my_range(images.shape[0]):
+                if np.random.rand() < blur_prob:
+                    k = np.random.randint(low=0, high=10)
+                    for ch in range(images.shape[-1]):
+                        images[i, ..., ch] = convolve2d(
+                            images[i, ..., ch],
+                            np.outer(weights[k], weights[k]),
+                            mode="same",
+                        )
+                    # images[i] = apply_blur_c(images[i], kernel_size, weights[k])
+            return images
+
+        blur.is_parallel = True
+        blur.with_indices = True
+        return blur
+
+    def declare_state_and_memory(
+        self, previous_state: State
+    ) -> Tuple[State, Optional[AllocationQuery]]:
+        return (
+            replace(previous_state, jit_mode=False),
+            None,
+        )
diff --git a/ffcv/transforms/grayscale.py b/ffcv/transforms/grayscale.py
new file mode 100644
index 00000000..e98823a1
--- /dev/null
+++ b/ffcv/transforms/grayscale.py
@@ -0,0 +1,125 @@
+"""
+# copy from https://github.com/facebookresearch/FFCV-SSL/blob/main/ffcv/transforms/grayscale.py
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+
+
+from typing import Callable, Optional, Tuple
+from ffcv.pipeline.allocation_query import AllocationQuery
+from ffcv.pipeline.operation import Operation
+from ffcv.pipeline.state import State
+from ffcv.pipeline.compiler import Compiler
+from dataclasses import replace
+import numpy as np
+import random
+
+
+class RandomGrayscale(Operation):
+    """Convert images to grayscale with probability p.
+    Operates on raw arrays (not tensors).
+
+    Parameters
+    ----------
+    p : float
+        The probability with which to convert each image in the batch
+        to grayscale.
+class RandomGrayscale(Operation):
+    """Randomly convert each image to grayscale with probability `p`.
+    Operates on raw arrays (not tensors).
+
+    Parameters
+    ----------
+    p : float
+        The probability with which to convert each image in the batch to
+        grayscale.
+    seed : int, optional
+        Seed making the per-sample draws deterministic and reproducible.
+    """
+
+    def __init__(self, p: float = 0.2, seed: int = None):
+        super().__init__()
+        self.gray_prob = p
+        self.seed = seed
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+        gray_prob = self.gray_prob
+        seed = self.seed
+
+        if seed is None:
+
+            def grayscale(images, _):
+                for i in my_range(images.shape[0]):
+                    if np.random.rand() > gray_prob:
+                        continue
+                    # ITU-R BT.601 luma coefficients, written to all channels
+                    images[i] = (
+                        0.2989 * images[i, ..., 0:1]
+                        + 0.5870 * images[i, ..., 1:2]
+                        + 0.1140 * images[i, ..., 2:3]
+                    )
+                return images
+
+            grayscale.is_parallel = True
+            return grayscale
+
+        def grayscale(images, _, counter):
+            random.seed(seed + counter)
+            values = np.zeros(images.shape[0])
+            for i in range(images.shape[0]):
+                values[i] = random.uniform(0, 1)
+            for i in my_range(images.shape[0]):
+                if values[i] > gray_prob:
+                    continue
+                images[i] = (
+                    0.2989 * images[i, ..., 0:1]
+                    + 0.5870 * images[i, ..., 1:2]
+                    + 0.1140 * images[i, ..., 2:3]
+                )
+            return images
+
+        grayscale.with_counter = True
+        grayscale.is_parallel = True
+        return grayscale
+
+    def declare_state_and_memory(
+        self, previous_state: State
+    ) -> Tuple[State, Optional[AllocationQuery]]:
+        assert previous_state.jit_mode
+        return (previous_state, None)
+
+
+class LabelGrayscale(Operation):
+    """Grayscale info added to the labels. Should be initialized in exactly
+    the same way as :class:`~ffcv.transforms.RandomGrayscale`.
+    """
+
+    def __init__(self, gray_prob: float = 0.2, seed: int = None):
+        super().__init__()
+        self.gray_prob = gray_prob
+        self.seed = np.random.RandomState(seed).randint(0, 2**32 - 1)
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+        gray_prob = self.gray_prob
+        seed = self.seed
+
+        def grayscale(labels, temp_array, indices):
+            # Derive a batch-specific seed from the sample indices so the
+            # extra label column matches the transform applied to the images.
+            rep = ""
+            for i in indices:
+                rep += str(i)
+            local_seed = (hash(rep) + seed) % 2**31
+            temp_array[:, :-1] = labels
+            for i in my_range(temp_array.shape[0]):
+                np.random.seed(local_seed + i)
+                if np.random.rand() < gray_prob:
+                    temp_array[i, -1] = 0.0
+                else:
+                    temp_array[i, -1] = 1.0
+            return temp_array
+
+        grayscale.is_parallel = True
+        grayscale.with_indices = True
+
+        return grayscale
+
+    def declare_state_and_memory(
+        self, previous_state: State
+    ) -> Tuple[State, Optional[AllocationQuery]]:
+        # Append one float32 column to the labels for the grayscale flag.
+        previous_shape = previous_state.shape
+        new_shape = (previous_shape[0] + 1,)
+        return (
+            replace(previous_state, shape=new_shape, dtype=np.float32),
+            AllocationQuery(new_shape, dtype=np.float32),
+        )
\ No newline at end of file
diff --git a/ffcv/transforms/solarization.py b/ffcv/transforms/solarization.py
new file mode 100644
index 00000000..25719536
--- /dev/null
+++ b/ffcv/transforms/solarization.py
@@ -0,0 +1,119 @@
+# copied from https://github.com/facebookresearch/FFCV-SSL/blob/main/ffcv/transforms/solarization.py
+"""
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+
+
+from typing import Callable, Optional, Tuple
+from ffcv.pipeline.allocation_query import AllocationQuery
+from ffcv.pipeline.operation import Operation
+from ffcv.pipeline.state import State
+from ffcv.pipeline.compiler import Compiler
+import numpy as np
+from dataclasses import replace
+import random
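The `RandomSolarization` op defined below applies a simple pointwise rule: pixels at or above the threshold are inverted, the rest pass through unchanged. A minimal sketch of that rule in isolation (input shape hypothetical):

```python
import numpy as np

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)  # hypothetical input
threshold = 128
# Invert only the pixels at or above the threshold.
solarized = np.where(img >= threshold, 255 - img, img)
```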
+class RandomSolarization(Operation):
+    """Solarize each image with probability `p` by inverting all pixel
+    values at or above a threshold. Operates on raw arrays (not tensors).
+
+    Parameters
+    ----------
+    p : float
+        Probability of solarizing each image. Default value is 0.5.
+    threshold : float
+        All pixels at or above this value are inverted.
+    seed : int, optional
+        Seed making the per-sample draws deterministic and reproducible.
+    """
+
+    def __init__(self, threshold: float = 128, p: float = 0.5, seed: int = None):
+        super().__init__()
+        self.sol_prob = p
+        self.threshold = threshold
+        self.seed = seed
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+        sol_prob = self.sol_prob
+        threshold = self.threshold
+        seed = self.seed
+
+        if seed is None:
+
+            def solarize(images, _):
+                for i in my_range(images.shape[0]):
+                    if np.random.rand() < sol_prob:
+                        mask = images[i] >= threshold
+                        images[i] = np.where(mask, 255 - images[i], images[i])
+                return images
+
+            solarize.is_parallel = True
+            return solarize
+
+        def solarize(images, _, counter):
+            random.seed(seed + counter)
+            values = np.zeros(len(images))
+            for i in range(len(images)):
+                values[i] = random.uniform(0, 1)
+            for i in my_range(images.shape[0]):
+                if values[i] < sol_prob:
+                    mask = images[i] >= threshold
+                    images[i] = np.where(mask, 255 - images[i], images[i])
+            return images
+
+        solarize.with_counter = True
+        solarize.is_parallel = True
+        return solarize
+
+    def declare_state_and_memory(
+        self, previous_state: State
+    ) -> Tuple[State, Optional[AllocationQuery]]:
+        return (replace(previous_state, jit_mode=True), None)
+
+
+class LabelSolarization(Operation):
+    """Solarization info added to the labels. Should be initialized in exactly
+    the same way as :class:`~ffcv.transforms.RandomSolarization`.
+    """
+
+    def __init__(
+        self, solarization_prob: float = 0.5, threshold: float = 128, seed: int = None
+    ):
+        super().__init__()
+        self.solarization_prob = solarization_prob
+        self.threshold = threshold
+        self.seed = seed
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+        solarization_prob = self.solarization_prob
+        seed = self.seed
+
+        def solarize(labels, temp_array, indices):
+            temp_array[:, :-1] = labels
+            # Seed from the first sample index of the batch (`indices` is an
+            # array, so it cannot be passed to random.seed directly).
+            random.seed(seed + indices[0])
+            for i in my_range(labels.shape[0]):
+                if random.uniform(0, 1) < solarization_prob:
+                    temp_array[i, -1] = 1
+                else:
+                    temp_array[i, -1] = 0
+            return temp_array
+
+        solarize.is_parallel = True
+        solarize.with_indices = True
+
+        return solarize
+
+    def declare_state_and_memory(
+        self, previous_state: State
+    ) -> Tuple[State, Optional[AllocationQuery]]:
+        # Append one float32 column to the labels for the solarization flag.
+        previous_shape = previous_state.shape
+        new_shape = (previous_shape[0] + 1,)
+        return (
+            replace(previous_state, shape=new_shape, dtype=np.float32),
+            AllocationQuery(new_shape, dtype=np.float32),
+        )
\ No newline at end of file
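All of the transforms above share the same seeding pattern: when a `seed` is given, the per-sample random draws are made sequentially from a deterministic seed (`seed + counter`, or a value derived from the batch indices) *before* the parallel loop over samples, so augmentations are reproducible regardless of thread scheduling. A minimal sketch of the pattern, with hypothetical names, under the assumption that the loader supplies a monotonically increasing batch counter:

```python
import random
import numpy as np

def make_op(prob, seed):
    def op(images, _, counter):
        # Draw all per-sample randomness sequentially and deterministically
        # before the (potentially parallel) per-sample loop.
        random.seed(seed + counter)
        values = np.zeros(images.shape[0])
        for i in range(images.shape[0]):
            values[i] = random.uniform(0, 1)
        for i in range(images.shape[0]):  # this loop may run in parallel
            if values[i] <= prob:
                pass  # apply the augmentation to images[i] here
        return images
    op.with_counter = True
    op.is_parallel = True
    return op
```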
diff --git a/ffcv/writer.py b/ffcv/writer.py
index 3615dd0a..20b2574c 100644
--- a/ffcv/writer.py
+++ b/ffcv/writer.py
@@ -118,6 +118,34 @@ def worker_job_indexed_dataset(input_queue, metadata_sm, metadata_type, fields,
     allocations_queue.put(allocator.allocations)
 
 
+def worker_folder_dataset(input_queue, metadata_sm, metadata_type, fields,
+                          allocator, done_number, allocations_queue, dataset):
+
+    metadata = np.frombuffer(metadata_sm.buf, dtype=metadata_type)
+    field_names = metadata_type.names
+
+    # This `with` block ensures that all the pages allocated have been written
+    # onto the file
+    with allocator:
+        while True:
+            chunk = input_queue.get()
+
+            if chunk is None:
+                # No more work left to do
+                break
+
+            # For each sample in the chunk
+            for dest_ix, source_ix in chunk:
+                sample = dataset[source_ix]
+                handle_sample(sample, dest_ix, field_names, metadata, allocator, fields)
+
+            # We notify the main thread of our progress
+            with done_number.get_lock():
+                done_number.value += len(chunk)
+
+    allocations_queue.put(allocator.allocations)
+
+
 class DatasetWriter():
     """Writes given dataset into FFCV format (.beton).
@@ -318,6 +346,23 @@ def from_webdataset(self, shards: List[str], pipeline: Callable):
         todos = zip(shards, offsets)
         self._write_common(total_len, todos, worker_job_webdataset, (pipeline, ))
 
+    def from_image_folder(self, root: str):
+        """Read samples from an image folder.
+        Images belonging to the same class should live in the same subfolder.
+
+        Parameters
+        ----------
+        root: str
+            Path to the folder containing the images.
+        """
+        from torchvision.datasets.folder import (IMG_EXTENSIONS, find_classes,
+                                                 make_dataset)
+        classes, class_to_idx = find_classes(root)
+        samples = make_dataset(root, class_to_idx, extensions=IMG_EXTENSIONS)
+        total_len = len(samples)
+        # Queue sample indices and hand the (path, class) list to the workers,
+        # mirroring from_indexed_dataset.
+        self._write_common(total_len, np.arange(total_len),
+                           worker_folder_dataset, (samples, ))
+
+
     def finalize(self, allocations) :
 
         # Writing metadata
diff --git a/libffcv/libffcv.cpp b/libffcv/libffcv.cpp
index db4798d1..93a5e52e 100644
--- a/libffcv/libffcv.cpp
+++ b/libffcv/libffcv.cpp
@@ -7,7 +7,15 @@
 #include 
 #include 
 #include 
+#include 
 #include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
 #ifdef _WIN32
 typedef unsigned __int32 __uint32_t;
 typedef unsigned __int64 __uint64_t;
@@ -16,11 +24,42 @@
 #define EXPORT
 #endif
 
+// #define _DEBUG
+#ifdef _DEBUG
+#define DBOUT std::cout // or any other ostream
+#else
+#define DBOUT 0 && std::cout
+#endif
+
+#include <utility> // For std::pair
+
+
+// Clamp a crop coordinate so that it aligns with the JPEG MCU block grid.
+int axis_to_image_boundaries(int a, int img_boundary, int mcuBlock) {
+    int img_b = img_boundary - (img_boundary % mcuBlock);
+    int delta_a = a % mcuBlock;
+    // reduce a so that it aligns with the MCU block
+    if (a > img_b) {
+        a = img_b;
+    } else {
+        a -= delta_a;
+    }
+    return a;
+}
+
+struct Boundaries {
+    int x;
+    int y;
+    int h;
+    int w;
+};
+
 extern "C" {
     // a key use to point to the tjtransform instance
     static pthread_key_t key_tj_transformer;
     // a key use to point to the tjdecompressor instance
     static pthread_key_t key_tj_decompressor;
+    static pthread_key_t key_share_buffer;
     static pthread_once_t key_once = PTHREAD_ONCE_INIT;
 
     // will make the keys to access the tj instances
@@ -28,17 +67,39 @@ extern "C" {
     {
         pthread_key_create(&key_tj_decompressor, NULL);
         pthread_key_create(&key_tj_transformer, NULL);
+        pthread_key_create(&key_share_buffer, NULL);
+    }
+
+    EXPORT int cv_imdecode(uint8_t* buf,
+                           uint64_t buf_size,
+                           int64_t flag,
+                           uint8_t* output_buffer){
+        DBOUT << "imdecode called" << std::endl;
+        cv::Mat bufArray(1, buf_size, CV_8UC1, buf);
+        cv::Mat image;
+        image = cv::imdecode(bufArray, flag);
+        // Check for failure
+        if (image.empty()) {
+            std::cout << "Could not decode the image" << std::endl;
+            return -1;
+        } else {
+            DBOUT << "Image decoded" << image.rows<<","<
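For context, here is a minimal sketch of how the new `from_image_folder` entry point might be driven; the output path, folder layout, and field configuration are hypothetical (the fields must match how the `(path, label)` samples produced by the folder worker are encoded):

```python
from ffcv.writer import DatasetWriter
from ffcv.fields import IntField, RGBImageField

# Hypothetical layout: one subfolder per class under ./data/train.
writer = DatasetWriter(
    "train.beton",
    {"image": RGBImageField(max_resolution=256), "label": IntField()},
    num_workers=8,
)
writer.from_image_folder("./data/train")
```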