
Commit

Merge branch 'feature/parser' into developer
raulbonet committed Aug 2, 2020
2 parents 55f9239 + a481e7f commit f5874f1
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions src/main-transfer.py
@@ -11,6 +11,12 @@ def parse_f():
     parser = argparse.ArgumentParser(description='Neural Style Transfer')
     parser.add_argument('--content_im_path', default=c.EX_IM_PATH, help='path of input image, including filename')
     parser.add_argument('--output_im_path', default=c.EX_OUT_PATH, help='path of the folder in which the output image will be stored')
+    parser.add_argument('--content_layer', default=c.CONTENT_LAYER_NAME, help='content layer name whose feature maps will be used for loss computation. Read Readme for further information')
+    parser.add_argument('--style_layer', default=c.STYLE_LAYER_NAMES, help='style layer name(s) whose feature maps will be used for loss computation. Read Readme for further information')
+    parser.add_argument('--style_layer_weights', default=c.STYLE_WEIGHTS, help='weights for each style layer')
+    parser.add_argument('--content_loss_weight', default=c.ALPHA, help='weight of content loss on loss function')
+    parser.add_argument('--style_loss_weight', default=c.BETA, help='weight of style loss on loss function')
+    parser.add_argument('--epochs', default=c.EPOCHS, help='number of iterations when training')
     return parser


@@ -23,15 +29,15 @@ def layer_dims_f(model, layer_names):
     return num_kernels, dim_kernels
 
 
-def loss_f(content_ext_model, style_ext_models, content_ref_fmap, style_ref_grams, num_kernels, dim_kernels, gen_im):
+def loss_f(content_ext_model, style_ext_models, content_ref_fmap, style_ref_grams, num_kernels, dim_kernels, gen_im,
+           style_layer_weights, content_loss_weight, style_loss_weight):
 
     def _loss_f():
         gen_im_preproc = tf.keras.applications.vgg19.preprocess_input(gen_im)
 
         content_gen_fmap = content_ext_model(gen_im_preproc)
         content_loss = 0.5 * tf.math.reduce_sum(tf.math.square(content_ref_fmap - content_gen_fmap))
 
-        # style_gen_grams = style_grams_f(model_vgg, gen_im_preproc)
         style_gen_grams = style_grams_f(style_ext_models, gen_im_preproc)
         diff_styles = [tf.math.square(a - b) for a, b in zip(style_ref_grams, style_gen_grams)]
         diff_styles_reduced = tf.TensorArray(dtype='float32', size=len(diff_styles))
@@ -40,9 +46,9 @@ def _loss_f():
         diff_styles_reduced = diff_styles_reduced.stack()
 
         style_loss = 1. / ((2 * num_kernels * dim_kernels) ** 2) * diff_styles_reduced
-        style_loss = tf.tensordot(c.STYLE_WEIGHTS, style_loss, axes=1)
+        style_loss = tf.tensordot(style_layer_weights, style_loss, axes=1)
 
-        total_loss = c.ALPHA * content_loss + c.BETA * style_loss
+        total_loss = content_loss_weight * content_loss + style_loss_weight * style_loss
         return total_loss
 
     return _loss_f
@@ -70,22 +76,22 @@ def load_image_tf_f(path):
     return im
 
 
-def main():
+def main(**kwargs):
 
     # Load content image
-    content_im = load_image_tf_f(c.EX_IM_PATH)
+    content_im = load_image_tf_f(kwargs['content_im_path'])
     content_im_preproc = tf.keras.applications.vgg19.preprocess_input(content_im)
 
     # VGG model
     _, width, height, channels = content_im.shape
     model_vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_tensor=None,
                                             input_shape=(width, height, 3), pooling='max')
     model_vgg.trainable = False
-    num_kernels, dim_kernels = layer_dims_f(model_vgg, c.STYLE_LAYER_NAMES)
+    num_kernels, dim_kernels = layer_dims_f(model_vgg, kwargs['style_layer'])
 
     # Content loss: model to compute feature maps
     content_ext_model = tf.keras.Model(inputs=model_vgg.inputs,
-                                       outputs=model_vgg.get_layer(c.CONTENT_LAYER_NAME).output)
+                                       outputs=model_vgg.get_layer(kwargs['content_layer']).output)
 
     # Content loss: feature maps of input (reference) image
     content_ref_fmap = content_ext_model(content_im_preproc)
@@ -96,7 +102,7 @@
 
     # Style loss: gram matrix of input (reference) image
     style_ext_models = [tf.keras.Model(inputs=model_vgg.inputs, outputs=model_vgg.get_layer(layer).output) \
-                        for layer in c.STYLE_LAYER_NAMES]
+                        for layer in kwargs['style_layer']]
     style_ref_grams = style_grams_f(style_ext_models, style_im_preproc)
 
     # Create noise image
@@ -110,26 +116,21 @@
     # Optimizer
     opt = tf.keras.optimizers.Adam(learning_rate=1)
 
-    for epoch in range(c.EPOCHS):
+    for epoch in range(kwargs['epochs']):
         opt.minimize(
-            loss=loss_f(content_ext_model, style_ext_models, content_ref_fmap, style_ref_grams, num_kernels, dim_kernels, gen_im),
+            loss=loss_f(content_ext_model, style_ext_models, content_ref_fmap, style_ref_grams, num_kernels, dim_kernels, gen_im,
+                        kwargs['style_layer_weights'], kwargs['content_loss_weight'], kwargs['style_loss_weight']),
             var_list=[gen_im])
         if epoch % 5 == 0:
             save_im = gen_im.numpy()[0, :, :, :]
             save_im = np.where(save_im <= 255, save_im, 255)
             save_im = np.where(save_im >= 0, save_im, 0)
             save_im = save_im.astype(np.uint8)
             save_im = Image.fromarray(save_im)
-            save_im.save(os.path.join(c.EX_OUT_PATH, 'output-example' + str(epoch) + '.jpg'))
+            save_im.save(os.path.join(kwargs['output_im_path'], 'output-example' + str(epoch) + '.jpg'))
 
 
 if __name__ == '__main__':
-    # args = parse_f().parse_args()
-    # args_dict = args.__dict__
-    # args_dict['content_layer_name'] = c.CONTENT_LAYER_NAME
-    # args_dict['style_layer_names'] = c.STYLE_LAYER_NAMES
-    # args_dict['style_weights'] = c.STYLE_WEIGHTS
-    # args_dict['alpha'] = c.ALPHA
-    # args_dict['beta'] = c.BETA
-    # main(**args_dict)
-    main()
+    args = parse_f().parse_args()
+    kwargs = args.__dict__
+    main(**kwargs)
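
For reference, a minimal invocation sketch of the new command-line interface (not part of the commit; the two paths are placeholders). The remaining options are left at their defaults from c in this sketch, since the arguments declare no type= converters and numeric or list values overridden on the command line would arrive as plain strings:

    python src/main-transfer.py --content_im_path data/content.jpg --output_im_path outputs/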
