Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions pix2pix/src/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def lambda_output(input_shape):
return input_shape[:2]


# def conv_block_unet(x, f, name, bn_mode, bn_axis, bn=True, dropout=False, strides=(2,2)):
# def conv_block_unet(x, f, name, bn_axis, bn=True, dropout=False, strides=(2,2)):

# x = Conv2D(f, (3, 3), strides=strides, name=name, padding="same")(x)
# if bn:
Expand All @@ -33,7 +33,7 @@ def lambda_output(input_shape):
# return x


# def up_conv_block_unet(x1, x2, f, name, bn_mode, bn_axis, bn=True, dropout=False):
# def up_conv_block_unet(x1, x2, f, name, bn_axis, bn=True, dropout=False):

# x1 = UpSampling2D(size=(2, 2))(x1)
# x = merge([x1, x2], mode="concat", concat_axis=bn_axis)
Expand All @@ -47,7 +47,7 @@ def lambda_output(input_shape):

# return x

def conv_block_unet(x, f, name, bn_mode, bn_axis, bn=True, strides=(2,2)):
def conv_block_unet(x, f, name, bn_axis, bn=True, strides=(2,2)):

x = LeakyReLU(0.2)(x)
x = Conv2D(f, (3, 3), strides=strides, name=name, padding="same")(x)
Expand All @@ -57,7 +57,7 @@ def conv_block_unet(x, f, name, bn_mode, bn_axis, bn=True, strides=(2,2)):
return x


def up_conv_block_unet(x, x2, f, name, bn_mode, bn_axis, bn=True, dropout=False):
def up_conv_block_unet(x, x2, f, name, bn_axis, bn=True, dropout=False):

x = Activation("relu")(x)
x = UpSampling2D(size=(2, 2))(x)
Expand All @@ -71,7 +71,7 @@ def up_conv_block_unet(x, x2, f, name, bn_mode, bn_axis, bn=True, dropout=False)
return x


def deconv_block_unet(x, x2, f, h, w, batch_size, name, bn_mode, bn_axis, bn=True, dropout=False):
def deconv_block_unet(x, x2, f, h, w, batch_size, name, bn_axis, bn=True, dropout=False):

o_shape = (batch_size, h * 2, w * 2, f)
x = Activation("relu")(x)
Expand All @@ -85,7 +85,7 @@ def deconv_block_unet(x, x2, f, h, w, batch_size, name, bn_mode, bn_axis, bn=Tru
return x


def generator_unet_upsampling(img_dim, bn_mode, model_name="generator_unet_upsampling"):
def generator_unet_upsampling(img_dim, model_name="generator_unet_upsampling"):

nb_filters = 64

Expand All @@ -109,7 +109,7 @@ def generator_unet_upsampling(img_dim, bn_mode, model_name="generator_unet_upsam
strides=(2, 2), name="unet_conv2D_1", padding="same")(unet_input)]
for i, f in enumerate(list_nb_filters[1:]):
name = "unet_conv2D_%s" % (i + 2)
conv = conv_block_unet(list_encoder[-1], f, name, bn_mode, bn_axis)
conv = conv_block_unet(list_encoder[-1], f, name, bn_axis)
list_encoder.append(conv)

# Prepare decoder filters
Expand All @@ -119,15 +119,15 @@ def generator_unet_upsampling(img_dim, bn_mode, model_name="generator_unet_upsam

# Decoder
list_decoder = [up_conv_block_unet(list_encoder[-1], list_encoder[-2],
list_nb_filters[0], "unet_upconv2D_1", bn_mode, bn_axis, dropout=True)]
list_nb_filters[0], "unet_upconv2D_1", bn_axis, dropout=True)]
for i, f in enumerate(list_nb_filters[1:]):
name = "unet_upconv2D_%s" % (i + 2)
# Dropout only on first few layers
if i < 2:
d = True
else:
d = False
conv = up_conv_block_unet(list_decoder[-1], list_encoder[-(i + 3)], f, name, bn_mode, bn_axis, dropout=d)
conv = up_conv_block_unet(list_decoder[-1], list_encoder[-(i + 3)], f, name, bn_axis, dropout=d)
list_decoder.append(conv)

x = Activation("relu")(list_decoder[-1])
Expand All @@ -140,7 +140,7 @@ def generator_unet_upsampling(img_dim, bn_mode, model_name="generator_unet_upsam
return generator_unet


def generator_unet_deconv(img_dim, bn_mode, batch_size, model_name="generator_unet_deconv"):
def generator_unet_deconv(img_dim, batch_size, model_name="generator_unet_deconv"):

assert K.backend() == "tensorflow", "Not implemented with theano backend"

Expand All @@ -162,7 +162,7 @@ def generator_unet_deconv(img_dim, bn_mode, batch_size, model_name="generator_un
h, w = h / 2, w / 2
for i, f in enumerate(list_nb_filters[1:]):
name = "unet_conv2D_%s" % (i + 2)
conv = conv_block_unet(list_encoder[-1], f, name, bn_mode, bn_axis)
conv = conv_block_unet(list_encoder[-1], f, name, bn_axis)
list_encoder.append(conv)
h, w = h / 2, w / 2

Expand All @@ -174,7 +174,7 @@ def generator_unet_deconv(img_dim, bn_mode, batch_size, model_name="generator_un
# Decoder
list_decoder = [deconv_block_unet(list_encoder[-1], list_encoder[-2],
list_nb_filters[0], h, w, batch_size,
"unet_upconv2D_1", bn_mode, bn_axis, dropout=True)]
"unet_upconv2D_1", bn_axis, dropout=True)]
h, w = h * 2, w * 2
for i, f in enumerate(list_nb_filters[1:]):
name = "unet_upconv2D_%s" % (i + 2)
Expand All @@ -184,7 +184,7 @@ def generator_unet_deconv(img_dim, bn_mode, batch_size, model_name="generator_un
else:
d = False
conv = deconv_block_unet(list_decoder[-1], list_encoder[-(i + 3)], f, h,
w, batch_size, name, bn_mode, bn_axis, dropout=d)
w, batch_size, name, bn_axis, dropout=d)
list_decoder.append(conv)
h, w = h * 2, w * 2

Expand All @@ -198,7 +198,7 @@ def generator_unet_deconv(img_dim, bn_mode, batch_size, model_name="generator_un
return generator_unet


def DCGAN_discriminator(img_dim, nb_patch, bn_mode, model_name="DCGAN_discriminator", use_mbd=True):
def DCGAN_discriminator(img_dim, nb_patch, model_name="DCGAN_discriminator", use_mbd=True):
"""
Discriminator model of the DCGAN

Expand Down Expand Up @@ -304,24 +304,24 @@ def DCGAN(generator, discriminator_model, img_dim, patch_size, image_dim_orderin
return DCGAN


def load(model_name, img_dim, nb_patch, bn_mode, use_mbd, batch_size):
def load(model_name, img_dim, nb_patch, use_mbd, batch_size):

if model_name == "generator_unet_upsampling":
model = generator_unet_upsampling(img_dim, bn_mode, model_name=model_name)
model = generator_unet_upsampling(img_dim, model_name=model_name)
model.summary()
from keras.utils import plot_model
plot_model(model, to_file="../../figures/%s.png" % model_name, show_shapes=True, show_layer_names=True)
return model

if model_name == "generator_unet_deconv":
model = generator_unet_deconv(img_dim, bn_mode, batch_size, model_name=model_name)
model = generator_unet_deconv(img_dim, batch_size, model_name=model_name)
model.summary()
from keras.utils import plot_model
plot_model(model, to_file="../../figures/%s.png" % model_name, show_shapes=True, show_layer_names=True)
return model

if model_name == "DCGAN_discriminator":
model = DCGAN_discriminator(img_dim, nb_patch, bn_mode, model_name=model_name, use_mbd=use_mbd)
model = DCGAN_discriminator(img_dim, nb_patch, model_name=model_name, use_mbd=use_mbd)
model.summary()
from keras.utils import plot_model
plot_model(model, to_file="../../figures/%s.png" % model_name, show_shapes=True, show_layer_names=True)
Expand All @@ -331,4 +331,4 @@ def load(model_name, img_dim, nb_patch, bn_mode, use_mbd, batch_size):
if __name__ == "__main__":

# load("generator_unet_deconv", (256, 256, 3), 16, 2, False, 32)
load("generator_unet_upsampling", (256, 256, 3), 16, 2, False, 32)
load("generator_unet_upsampling", (256, 256, 3), 16, False, 32)