Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ sh scripts/download_model.sh
## Instance Prediction
Please run the command below to predict the bounding boxes for all the images in the `example` folder.
```
./bbox.sh   # or, equivalently:
python inference_bbox.py --test_img_dir example
```
All the prediction results will be saved in the `example_bbox` folder.

## Colorize Images
Please run the command below to colorize all the images in the `example` folder.
```
./color.sh   # or, equivalently:
python test_fusion.py --name test_fusion --sample_p 1.0 --model fusion --fineSize 256 --test_img_dir example --results_img_dir results
```
All the colorized results will be saved in the `results` folder.
Expand All @@ -89,3 +91,6 @@ If you find our code/models useful, please consider citing our paper:

## Acknowledgments
Our code borrows heavily from the amazing [colorization-pytorch](https://github.com/richzhang/colorization-pytorch) repository.


ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation
2 changes: 2 additions & 0 deletions bbox.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Predict bounding boxes for every image in the `example` folder.
# Per the README, the prediction results are saved in the `example_bbox` folder.
python inference_bbox.py --test_img_dir example

7 changes: 7 additions & 0 deletions color.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Colorize every image in the `example` folder using the fusion model.
# Colorized outputs are written to the `results` folder (--results_img_dir).
python test_fusion.py \
--name test_fusion \
--sample_p 1.0 \
--model fusion \
--fineSize 256 \
--test_img_dir example \
--results_img_dir results
12 changes: 12 additions & 0 deletions inference_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@
import torch
from tqdm import tqdm

import pdb

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))

cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
# https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl
cfg.MODEL.WEIGHTS = "checkpoints/model_final_2d9806.pkl"

predictor = DefaultPredictor(cfg)

parser = ArgumentParser()
Expand All @@ -42,6 +48,9 @@
lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
l_channel, a_channel, b_channel = cv2.split(lab_image)
l_stack = np.stack([l_channel, l_channel, l_channel], axis=2)
# l_channel.shape -- (320, 480)
# l_stack.shape -- (320, 480, 3)

outputs = predictor(l_stack)
save_path = join(output_npz_dir, image_path.split('.')[0])
pred_bbox = outputs["instances"].pred_boxes.to(torch.device('cpu')).tensor.numpy()
Expand All @@ -50,4 +59,7 @@
print('delete {0}'.format(image_path))
os.remove(join(input_dir, image_path))
continue

# (Pdb) pred_bbox.shape -- (10, 4)
# (Pdb) pred_scores.shape -- (10,)
np.savez(save_path, bbox = pred_bbox, scores = pred_scores)
28 changes: 20 additions & 8 deletions models/fusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,32 @@ def save_current_imgs(self, path):
def setup_to_test(self, fusion_weight_path):
GF_path = 'checkpoints/{0}/latest_net_GF.pth'.format(fusion_weight_path)
print('load Fusion model from %s' % GF_path)
GF_state_dict = torch.load(GF_path)
GF_state_dict = torch.load(GF_path, map_location=torch.device('cpu'))

# G_path = 'checkpoints/coco_finetuned_mask_256/latest_net_G.pth' # fine tuned on cocostuff
G_path = 'checkpoints/{0}/latest_net_G.pth'.format(fusion_weight_path)
G_state_dict = torch.load(G_path)
G_state_dict = torch.load(G_path, map_location=torch.device('cpu'))

# GComp_path = 'checkpoints/siggraph_retrained/latest_net_G.pth' # original net
# GComp_path = 'checkpoints/coco_finetuned_mask_256/latest_net_GComp.pth' # fine tuned on cocostuff
GComp_path = 'checkpoints/{0}/latest_net_GComp.pth'.format(fusion_weight_path)
GComp_state_dict = torch.load(GComp_path)

self.netGF.load_state_dict(GF_state_dict, strict=False)
self.netG.module.load_state_dict(G_state_dict, strict=False)
self.netGComp.module.load_state_dict(GComp_state_dict, strict=False)
GComp_state_dict = torch.load(GComp_path, map_location=torch.device('cpu'))

if (len(self.gpu_ids) > 0):
self.netGF.load_state_dict(GF_state_dict, strict=False)
self.netG.module.load_state_dict(G_state_dict, strict=False)
self.netGComp.module.load_state_dict(GComp_state_dict, strict=False)
else:
# self.netGF
target_state_dict = self.netGF.state_dict()
for n, p in GF_state_dict.items():
n = n.replace('module.', '')
if n in target_state_dict.keys():
target_state_dict[n].copy_(p)
else:
raise KeyError(n)
self.netG.load_state_dict(G_state_dict, strict=False)
self.netGComp.load_state_dict(GComp_state_dict, strict=False)
self.netGF.eval()
self.netG.eval()
self.netGComp.eval()
self.netGComp.eval()
3 changes: 3 additions & 0 deletions project/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@


ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation
21 changes: 17 additions & 4 deletions test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
import multiprocessing

import pdb

multiprocessing.set_start_method('spawn', True)

torch.backends.cudnn.benchmark = True
Expand All @@ -36,26 +39,36 @@
model = create_model(opt)
# model.setup_to_test('coco_finetuned_mask_256')
model.setup_to_test('coco_finetuned_mask_256_ffs')
model.eval()

count_empty = 0
for data_raw in tqdm(dataset_loader, dynamic_ncols=True):
# if os.path.isfile(join(save_img_path, data_raw['file_id'][0] + '.png')) is True:
# continue
data_raw['full_img'][0] = data_raw['full_img'][0].cuda()
if (len(opt.gpu_ids) > 0):
data_raw['full_img'][0] = data_raw['full_img'][0].cuda()

if data_raw['empty_box'][0] == 0:
data_raw['cropped_img'][0] = data_raw['cropped_img'][0].cuda()
if (len(opt.gpu_ids) > 0):
data_raw['cropped_img'][0] = data_raw['cropped_img'][0].cuda()
box_info = data_raw['box_info'][0]
box_info_2x = data_raw['box_info_2x'][0]
box_info_4x = data_raw['box_info_4x'][0]
box_info_8x = data_raw['box_info_8x'][0]
cropped_data = util.get_colorization_data(data_raw['cropped_img'], opt, ab_thresh=0, p=opt.sample_p)
full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p)

model.set_input(cropped_data)
model.set_fusion_input(full_img_data, [box_info, box_info_2x, box_info_4x, box_info_8x])
model.forward()

with torch.no_grad():
model.forward()
else:
count_empty += 1
full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p)
model.set_forward_without_box(full_img_data)

with torch.no_grad():
model.set_forward_without_box(full_img_data)

model.save_current_imgs(join(save_img_path, data_raw['file_id'][0] + '.png'))
print('{0} images without bounding boxes'.format(count_empty))