Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions train_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataloaders import davis_2016 as db
from dataloaders import custom_transforms as tr
from util import visualize as viz
import scipy.misc as sm
import imageio
import networks.vgg_osvos as vo
from layers.osvos_layers import class_balanced_cross_entropy_loss
from dataloaders.helpers import *
Expand Down Expand Up @@ -180,13 +180,15 @@

outputs = net.forward(inputs)

for jj in range(int(inputs.size()[0])):
pred = np.transpose(outputs[-1].cpu().data.numpy()[jj, :, :, :], (1, 2, 0))
pred = 1 / (1 + np.exp(-pred))
pred = np.squeeze(pred)
pred = torch.sigmoid(outputs[-1])
pred = pred >= 0.5
pred = 255 * pred
pred = pred.cpu().numpy()

# Save the result, attention to the index jj
sm.imsave(os.path.join(save_dir_res, os.path.basename(fname[jj]) + '.png'), pred)
for jj in range(inputs.size(0)):
imageio.imsave(
os.path.join(save_dir_res, os.path.basename(fname[jj]) + '.png'),
np.transpose(pred[jj], (1, 2, 0)).astype(np.uint8))

if vis_res:
img_ = np.transpose(img.numpy()[jj, :, :, :], (1, 2, 0))
Expand Down