diff --git a/README.md b/README.md
index ce79716..e2ee8da 100644
--- a/README.md
+++ b/README.md
@@ -43,6 +43,11 @@ We list a few projects that use DeepLab2.
* Colab notebook for off-the-shelf inference.
+## Gradio Demo
+
+* Gradio Web Demo (`gradio/demo.py`); see the run instructions below.
+
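+To run the demo locally (a minimal sketch, assuming the dependencies listed in
+`requirements.txt`):
+
+```bash
+pip install -r requirements.txt
+python gradio/demo.py
+```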
+
## Running DeepLab2
See [Getting Started](g3doc/setup/getting_started.md). In short, run the
diff --git a/gradio/demo.py b/gradio/demo.py
new file mode 100644
index 0000000..276d530
--- /dev/null
+++ b/gradio/demo.py
@@ -0,0 +1,273 @@
+import collections
+import os
+import sys
+import tempfile
+import urllib.request
+from subprocess import call
+
+import gradio as gr
+from matplotlib import gridspec
+from matplotlib import pyplot as plt
+import numpy as np
+from PIL import Image
+import requests
+import tensorflow as tf
+
+# Fetch two example street-scene images for the demo's example gallery.
+_EXAMPLE_IMAGES = {
+    'city1.jpg': 'https://cdn.pixabay.com/photo/2014/09/07/21/52/city-438393_1280.jpg',
+    'city2.jpg': 'https://cdn.pixabay.com/photo/2016/02/19/11/36/canal-1209808_1280.jpg',
+}
+for filename, url in _EXAMPLE_IMAGES.items():
+  r = requests.get(url, allow_redirects=True)
+  with open(filename, 'wb') as f:
+    f.write(r.content)
+
+
+DatasetInfo = collections.namedtuple(
+ 'DatasetInfo',
+ 'num_classes, label_divisor, thing_list, colormap, class_names')
+
+
+def _cityscapes_label_colormap():
+ """Creates a label colormap used in CITYSCAPES segmentation benchmark.
+
+ See more about CITYSCAPES dataset at https://www.cityscapes-dataset.com/
+ M. Cordts, et al. "The Cityscapes Dataset for Semantic Urban Scene Understanding." CVPR. 2016.
+
+ Returns:
+ A 2-D numpy array with each row being mapped RGB color (in uint8 range).
+ """
+ colormap = np.zeros((256, 3), dtype=np.uint8)
+ colormap[0] = [128, 64, 128]
+ colormap[1] = [244, 35, 232]
+ colormap[2] = [70, 70, 70]
+ colormap[3] = [102, 102, 156]
+ colormap[4] = [190, 153, 153]
+ colormap[5] = [153, 153, 153]
+ colormap[6] = [250, 170, 30]
+ colormap[7] = [220, 220, 0]
+ colormap[8] = [107, 142, 35]
+ colormap[9] = [152, 251, 152]
+ colormap[10] = [70, 130, 180]
+ colormap[11] = [220, 20, 60]
+ colormap[12] = [255, 0, 0]
+ colormap[13] = [0, 0, 142]
+ colormap[14] = [0, 0, 70]
+ colormap[15] = [0, 60, 100]
+ colormap[16] = [0, 80, 100]
+ colormap[17] = [0, 0, 230]
+ colormap[18] = [119, 11, 32]
+ return colormap
+
+
+def _cityscapes_class_names():
+ return ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle')
+
+
+def cityscapes_dataset_information():
+ return DatasetInfo(
+ num_classes=19,
+ label_divisor=1000,
+ thing_list=tuple(range(11, 19)),
+ colormap=_cityscapes_label_colormap(),
+ class_names=_cityscapes_class_names())
+
+
+def perturb_color(color, noise, used_colors, max_trials=50, random_state=None):
+  """Perturbs the color with some noise.
+
+ If `used_colors` is not None, we will return the color that has
+ not appeared before in it.
+
+ Args:
+ color: A numpy array with three elements [R, G, B].
+ noise: Integer, specifying the amount of perturbing noise (in uint8 range).
+ used_colors: A set, used to keep track of used colors.
+ max_trials: An integer, maximum trials to generate random color.
+ random_state: An optional np.random.RandomState. If passed, will be used to
+ generate random numbers.
+
+ Returns:
+ A perturbed color that has not appeared in used_colors.
+ """
+ if random_state is None:
+ random_state = np.random
+
+ for _ in range(max_trials):
+ random_color = color + random_state.randint(
+ low=-noise, high=noise + 1, size=3)
+ random_color = np.clip(random_color, 0, 255)
+
+ if tuple(random_color) not in used_colors:
+ used_colors.add(tuple(random_color))
+ return random_color
+
+  print('Max trials reached; a duplicate color will be used. Consider '
+        'increasing `noise` in `perturb_color()`.')
+ return random_color
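+
+# Illustrative usage (hypothetical values): repeated calls with the same base
+# color yield distinct perturbed colors because `used_colors` tracks collisions:
+#   used = set()
+#   c1 = perturb_color(np.array([0, 0, 142]), 60, used)  # base color of 'car'
+#   c2 = perturb_color(np.array([0, 0, 142]), 60, used)  # differs from c1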
+
+
+def color_panoptic_map(panoptic_prediction, dataset_info, perturb_noise):
+ """Helper method to colorize output panoptic map.
+
+ Args:
+ panoptic_prediction: A 2D numpy array, panoptic prediction from deeplab
+ model.
+ dataset_info: A DatasetInfo object, dataset associated to the model.
+ perturb_noise: Integer, the amount of noise (in uint8 range) added to each
+ instance of the same semantic class.
+
+ Returns:
+ colored_panoptic_map: A 3D numpy array with last dimension of 3, colored
+ panoptic prediction map.
+ used_colors: A dictionary mapping semantic_ids to a set of colors used
+ in `colored_panoptic_map`.
+ """
+ if panoptic_prediction.ndim != 2:
+ raise ValueError('Expect 2-D panoptic prediction. Got {}'.format(
+ panoptic_prediction.shape))
+
+ semantic_map = panoptic_prediction // dataset_info.label_divisor
+ instance_map = panoptic_prediction % dataset_info.label_divisor
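+  # Example: with label_divisor=1000, a panoptic id of 13012 decodes to
+  # semantic id 13 ('car', a "thing" class) and instance id 12.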
+ height, width = panoptic_prediction.shape
+ colored_panoptic_map = np.zeros((height, width, 3), dtype=np.uint8)
+
+ used_colors = collections.defaultdict(set)
+ # Use a fixed seed to reproduce the same visualization.
+ random_state = np.random.RandomState(0)
+
+ unique_semantic_ids = np.unique(semantic_map)
+ for semantic_id in unique_semantic_ids:
+ semantic_mask = semantic_map == semantic_id
+ if semantic_id in dataset_info.thing_list:
+ # For `thing` class, we will add a small amount of random noise to its
+ # correspondingly predefined semantic segmentation colormap.
+ unique_instance_ids = np.unique(instance_map[semantic_mask])
+ for instance_id in unique_instance_ids:
+ instance_mask = np.logical_and(semantic_mask,
+ instance_map == instance_id)
+ random_color = perturb_color(
+ dataset_info.colormap[semantic_id],
+ perturb_noise,
+ used_colors[semantic_id],
+ random_state=random_state)
+ colored_panoptic_map[instance_mask] = random_color
+ else:
+ # For `stuff` class, we use the defined semantic color.
+ colored_panoptic_map[semantic_mask] = dataset_info.colormap[semantic_id]
+ used_colors[semantic_id].add(tuple(dataset_info.colormap[semantic_id]))
+ return colored_panoptic_map, used_colors
+
+
+def vis_segmentation(image,
+ panoptic_prediction,
+ dataset_info,
+ perturb_noise=60):
+ """Visualizes input image, segmentation map and overlay view."""
+ plt.figure(figsize=(30, 20))
+ grid_spec = gridspec.GridSpec(2, 2)
+
+ ax = plt.subplot(grid_spec[0])
+ plt.imshow(image)
+ plt.axis('off')
+ ax.set_title('input image', fontsize=20)
+
+ ax = plt.subplot(grid_spec[1])
+ panoptic_map, used_colors = color_panoptic_map(panoptic_prediction,
+ dataset_info, perturb_noise)
+ plt.imshow(panoptic_map)
+ plt.axis('off')
+ ax.set_title('panoptic map', fontsize=20)
+
+ ax = plt.subplot(grid_spec[2])
+ plt.imshow(image)
+ plt.imshow(panoptic_map, alpha=0.7)
+ plt.axis('off')
+ ax.set_title('panoptic overlay', fontsize=20)
+
+ ax = plt.subplot(grid_spec[3])
+ max_num_instances = max(len(color) for color in used_colors.values())
+  # Build an RGBA legend: one row per semantic class, one column per instance
+  # color; unused cells keep alpha 0 and render as transparent.
+ legend = np.zeros((len(used_colors), max_num_instances, 4), dtype=np.uint8)
+ class_names = []
+ for i, semantic_id in enumerate(sorted(used_colors)):
+ legend[i, :len(used_colors[semantic_id]), :3] = np.array(
+ list(used_colors[semantic_id]))
+ legend[i, :len(used_colors[semantic_id]), 3] = 255
+ if semantic_id < dataset_info.num_classes:
+ class_names.append(dataset_info.class_names[semantic_id])
+ else:
+ class_names.append('ignore')
+
+ plt.imshow(legend, interpolation='nearest')
+ ax.yaxis.tick_left()
+ plt.yticks(range(len(legend)), class_names, fontsize=15)
+ plt.xticks([], [])
+ ax.tick_params(width=0.0, grid_linewidth=0.0)
+ plt.grid('off')
+  # Return the pyplot module so gradio's "plot" output type can render the
+  # current figure.
+  return plt
+
+
+def run_cmd(command):
+  """Prints and runs a shell command, exiting cleanly on interrupt."""
+  try:
+    print(command)
+    call(command, shell=True)
+  except KeyboardInterrupt:
+    print("Process interrupted")
+    sys.exit(1)
+
+
+MODEL_NAME = 'resnet50_os32_panoptic_deeplab_cityscapes_crowd_trainfine_saved_model'
+
+
+_MODELS = ('resnet50_os32_panoptic_deeplab_cityscapes_crowd_trainfine_saved_model',
+ 'resnet50_beta_os32_panoptic_deeplab_cityscapes_trainfine_saved_model',
+ 'wide_resnet41_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
+ 'swidernet_sac_1_1_1_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
+ 'swidernet_sac_1_1_3_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
+ 'swidernet_sac_1_1_4.5_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
+ 'axial_swidernet_1_1_1_os16_axial_deeplab_cityscapes_trainfine_saved_model',
+ 'axial_swidernet_1_1_3_os16_axial_deeplab_cityscapes_trainfine_saved_model',
+ 'axial_swidernet_1_1_4.5_os16_axial_deeplab_cityscapes_trainfine_saved_model',
+ 'max_deeplab_s_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model',
+ 'max_deeplab_l_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model')
+_DOWNLOAD_URL_PATTERN = 'https://storage.googleapis.com/gresearch/tf-deeplab/saved_model/%s.tar.gz'
+
+_MODEL_NAME_TO_URL_AND_DATASET = {
+ model: (_DOWNLOAD_URL_PATTERN % model, cityscapes_dataset_information())
+ for model in _MODELS
+}
+
+MODEL_URL, DATASET_INFO = _MODEL_NAME_TO_URL_AND_DATASET[MODEL_NAME]
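+# Any other name in _MODELS can be assigned to MODEL_NAME above to serve a
+# different pretrained checkpoint; its URL follows the same download pattern.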
+
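+# Download and extract the exported SavedModel into a temporary directory, then
+# load it once at startup so each request only pays for inference.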
+model_dir = tempfile.mkdtemp()
+
+download_path = os.path.join(model_dir, MODEL_NAME + '.gz')
+urllib.request.urlretrieve(MODEL_URL, download_path)
+
+run_cmd("tar -xzvf " + download_path + " -C " + model_dir)
+
+LOADED_MODEL = tf.saved_model.load(os.path.join(model_dir, MODEL_NAME))
+
+
+def inference(image):
+  # Resize to the demo's fixed input size (this may distort the aspect ratio).
+  image = image.resize(size=(512, 512))
+  im = np.array(image)
+  # The SavedModel returns a dict of tensors; the first batch element of
+  # 'panoptic_pred' is the 2-D panoptic map that vis_segmentation expects.
+  output = LOADED_MODEL(tf.cast(im, tf.uint8))
+  return vis_segmentation(im, output['panoptic_pred'][0], DATASET_INFO)
+
+title = "DeepLab2"
+description = "Panoptic segmentation demo for DeepLab2 (Panoptic-DeepLab trained on Cityscapes). To use it, upload an image or click one of the examples to load it. Read more at the links below."
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2106.09748'>DeepLab2: A TensorFlow Library for Deep Labeling</a> | <a href='https://github.com/google-research/deeplab2'>Github Repo</a></p>"
+
+gr.Interface(
+    inference,
+    [gr.inputs.Image(type="pil", label="Input")],
+    gr.outputs.Image(type="plot", label="Output"),
+    title=title,
+    description=description,
+    article=article,
+    examples=[
+        ["city1.jpg"],
+        ["city2.jpg"]
+    ]).launch()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..17b9290
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+matplotlib
+numpy
+Pillow
+requests
+tensorflow
+gradio