From 94ec280b360e43cbe5b63e65556f42847e757d35 Mon Sep 17 00:00:00 2001 From: Adam Byrne Date: Mon, 29 Jan 2024 23:22:43 +0000 Subject: [PATCH] outline, need to fix weights init and try and fix seed --- .gitignore | 4 +++- agrinet/inference.py | 50 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index a20d967..83e11ca 100644 --- a/.gitignore +++ b/.gitignore @@ -172,4 +172,6 @@ discriminator_*/ generator_*/ discriminator_weights_* generator_weights_* -checkpoint \ No newline at end of file +checkpoint +*.jpg +*.png \ No newline at end of file diff --git a/agrinet/inference.py b/agrinet/inference.py index f257eca..d90f994 100644 --- a/agrinet/inference.py +++ b/agrinet/inference.py @@ -1,13 +1,59 @@ +import argparse + +import tensorflow as tf from utils.LogManager import LogManager +from utils.Model import Generator def main(args): logger = LogManager.get_logger("AGRINET INFERENCE") - logger.info("Warming up...") + logger.info("Loading generator...") + + generator = Generator() + + try: + saved_model = tf.saved_model.load(args.weights) + generator.load_weights(saved_model) + except Exception as e: + logger.critical("Error while loading weights: {}".format(e)) + + logger.info("Model loaded") + + logger.info("Loading input image...") + input = None + + try: + input = tf.io.read_file(args.input) + except Exception as e: + logger.critical("Error while reading input image: {}".format(e)) + + if input is None: + logger.critical("Input image is empty") + else: + logger.info("Input image loaded") + + input = tf.image.decode_jpeg(input) + input = tf.image.resize(input, [256, 256]) + input = (input / 127.5) - 1 # Normalize the images to [-1, 1] + input = tf.expand_dims(input, axis=0) + input = tf.cast(input, tf.float32) + + output = generator(input, training=True) + output = tf.cast(output[0], tf.uint8) + output = tf.image.encode_jpeg(output) + + try: + tf.io.write_file("results/output.jpg", output) + except Exception as e: + logger.critical("Error while saving utput image: {}".format(e)) def parse_args(): - pass # TODO + parser = argparse.ArgumentParser() + parser.add_argument("--weights", type=str, required=True) + parser.add_argument("--input", type=str, required=True) + args = parser.parse_args() + return args if __name__ == "__main__":