From 0e69cd9548e17a2448d2fc387305be9a34189c80 Mon Sep 17 00:00:00 2001 From: Unknown Date: Mon, 24 Aug 2020 02:26:54 +0800 Subject: [PATCH] Add Training Data Cache Improve training speed about 2x to 4x faster. --- dataloaders/davis_2016.py | 34 +++++++++++++++++++++++----------- train_online.py | 7 +++---- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/dataloaders/davis_2016.py b/dataloaders/davis_2016.py index 87a1a70..15608c0 100644 --- a/dataloaders/davis_2016.py +++ b/dataloaders/davis_2016.py @@ -7,6 +7,7 @@ from dataloaders.helpers import * from torch.utils.data import Dataset +import torch class DAVIS2016(Dataset): @@ -27,6 +28,7 @@ def __init__(self, train=True, self.transform = transform self.meanval = meanval self.seq_name = seq_name + self.data_cache = {} if self.train: fname = 'train_seqs' @@ -70,16 +72,9 @@ def __len__(self): return len(self.img_list) def __getitem__(self, idx): - img, gt = self.make_img_gt_pair(idx) + sample = self.make_img_gt_pair(idx) - sample = {'image': img, 'gt': gt} - - if self.seq_name is not None: - fname = os.path.join(self.seq_name, "%05d" % idx) - sample['fname'] = fname - if self.transform is not None: - sample = self.transform(sample) return sample @@ -87,6 +82,10 @@ def make_img_gt_pair(self, idx): """ Make the image-ground-truth pair """ + + if idx in self.data_cache: + return self.data_cache[idx] + img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[idx])) if self.labels[idx] is not None: label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), 0) @@ -102,10 +101,23 @@ def make_img_gt_pair(self, idx): img = np.subtract(img, np.array(self.meanval, dtype=np.float32)) if self.labels[idx] is not None: - gt = np.array(label, dtype=np.float32) - gt = gt/np.max([gt.max(), 1e-8]) + gt = np.array(label, dtype=np.float32) + gt = gt / np.max([gt.max(), 1e-8]) + + sample = {'image': img, 'gt': gt} + + if self.seq_name is not None: + fname = os.path.join(self.seq_name, "%05d" % idx) + sample['fname'] = fname - return img, gt + if self.transform is not None: + sample = self.transform(sample) + + sample['image'] = torch.tensor(sample['image']).cuda() + sample['gt'] = torch.tensor(sample['gt']).cuda() + + self.data_cache[idx] = sample + return sample def get_img_size(self): img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0])) diff --git a/train_online.py b/train_online.py index 00b36cb..34f5fe2 100644 --- a/train_online.py +++ b/train_online.py @@ -94,12 +94,11 @@ tr.ToTensor()]) # Training dataset and its iterator db_train = db.DAVIS2016(train=True, db_root_dir=db_root_dir, transform=composed_transforms, seq_name=seq_name) -trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True, num_workers=1) +trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True) # Testing dataset and its iterator db_test = db.DAVIS2016(train=False, db_root_dir=db_root_dir, transform=tr.ToTensor(), seq_name=seq_name) -testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) - +testloader = DataLoader(db_test, batch_size=1, shuffle=False) num_img_tr = len(trainloader) num_img_ts = len(testloader) @@ -119,7 +118,7 @@ # Forward-Backward of the mini-batch inputs.requires_grad_() - inputs, gts = inputs.to(device), gts.to(device) + # inputs, gts = inputs.to(device), gts.to(device) outputs = net.forward(inputs)