From 5b8748d6f459c976393433572fb1ec48dd8ae777 Mon Sep 17 00:00:00 2001 From: Tim Meinhardt Date: Fri, 21 Jun 2019 14:11:44 +0200 Subject: [PATCH] Fix deprecation warning of scipy.misc.imsave and greyscale seg masks --- train_online.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/train_online.py b/train_online.py index 00b36cb..9ab2dea 100644 --- a/train_online.py +++ b/train_online.py @@ -17,7 +17,7 @@ from dataloaders import davis_2016 as db from dataloaders import custom_transforms as tr from util import visualize as viz -import scipy.misc as sm +import imageio import networks.vgg_osvos as vo from layers.osvos_layers import class_balanced_cross_entropy_loss from dataloaders.helpers import * @@ -180,13 +180,15 @@ outputs = net.forward(inputs) - for jj in range(int(inputs.size()[0])): - pred = np.transpose(outputs[-1].cpu().data.numpy()[jj, :, :, :], (1, 2, 0)) - pred = 1 / (1 + np.exp(-pred)) - pred = np.squeeze(pred) + pred = torch.sigmoid(outputs[-1]) + pred = pred >= 0.5 + pred = 255 * pred + pred = pred.cpu().numpy() - # Save the result, attention to the index jj - sm.imsave(os.path.join(save_dir_res, os.path.basename(fname[jj]) + '.png'), pred) + for jj in range(inputs.size(0)): + imageio.imsave( + os.path.join(save_dir_res, os.path.basename(fname[jj]) + '.png'), + np.transpose(pred[jj], (1, 2, 0)).astype(np.uint8)) if vis_res: img_ = np.transpose(img.numpy()[jj, :, :, :], (1, 2, 0))