Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions dataloaders/davis_2016.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from dataloaders.helpers import *
from torch.utils.data import Dataset
import torch


class DAVIS2016(Dataset):
Expand All @@ -27,6 +28,7 @@ def __init__(self, train=True,
self.transform = transform
self.meanval = meanval
self.seq_name = seq_name
self.data_cache = {}

if self.train:
fname = 'train_seqs'
Expand Down Expand Up @@ -70,23 +72,20 @@ def __len__(self):
return len(self.img_list)

def __getitem__(self, idx):
img, gt = self.make_img_gt_pair(idx)
sample = self.make_img_gt_pair(idx)

sample = {'image': img, 'gt': gt}

if self.seq_name is not None:
fname = os.path.join(self.seq_name, "%05d" % idx)
sample['fname'] = fname

if self.transform is not None:
sample = self.transform(sample)

return sample

def make_img_gt_pair(self, idx):
"""
Make the image-ground-truth pair
"""

if idx in self.data_cache:
return self.data_cache[idx]

img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[idx]))
if self.labels[idx] is not None:
label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), 0)
Expand All @@ -102,10 +101,23 @@ def make_img_gt_pair(self, idx):
img = np.subtract(img, np.array(self.meanval, dtype=np.float32))

if self.labels[idx] is not None:
gt = np.array(label, dtype=np.float32)
gt = gt/np.max([gt.max(), 1e-8])
gt = np.array(label, dtype=np.float32)
gt = gt / np.max([gt.max(), 1e-8])

sample = {'image': img, 'gt': gt}

if self.seq_name is not None:
fname = os.path.join(self.seq_name, "%05d" % idx)
sample['fname'] = fname

return img, gt
if self.transform is not None:
sample = self.transform(sample)

sample['image'] = torch.tensor(sample['image']).cuda()
sample['gt'] = torch.tensor(sample['gt']).cuda()

self.data_cache[idx] = sample
return sample

def get_img_size(self):
img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0]))
Expand Down
7 changes: 3 additions & 4 deletions train_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,11 @@
tr.ToTensor()])
# Training dataset and its iterator
db_train = db.DAVIS2016(train=True, db_root_dir=db_root_dir, transform=composed_transforms, seq_name=seq_name)
trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True, num_workers=1)
trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True)

# Testing dataset and its iterator
db_test = db.DAVIS2016(train=False, db_root_dir=db_root_dir, transform=tr.ToTensor(), seq_name=seq_name)
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)

testloader = DataLoader(db_test, batch_size=1, shuffle=False)

num_img_tr = len(trainloader)
num_img_ts = len(testloader)
Expand All @@ -119,7 +118,7 @@

# Forward-Backward of the mini-batch
inputs.requires_grad_()
inputs, gts = inputs.to(device), gts.to(device)
# inputs, gts = inputs.to(device), gts.to(device)

outputs = net.forward(inputs)

Expand Down