Skip to content

Commit

Permalink
outline, need to fix weights init and try and fix seed
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Byrne committed Jan 29, 2024
1 parent 165537a commit 94ec280
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,6 @@ discriminator_*/
generator_*/
discriminator_weights_*
generator_weights_*
checkpoint
checkpoint
*.jpg
*.png
50 changes: 48 additions & 2 deletions agrinet/inference.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,59 @@
import argparse

import tensorflow as tf
from utils.LogManager import LogManager
from utils.Model import Generator


def main(args):
logger = LogManager.get_logger("AGRINET INFERENCE")
logger.info("Warming up...")
logger.info("Loading generator...")

generator = Generator()

try:
saved_model = tf.saved_model.load(args.weights)
generator.load_weights(saved_model)
except Exception as e:
logger.critical("Error while loading weights: {}".format(e))

logger.info("Model loaded")

logger.info("Loading input image...")
input = None

try:
input = tf.io.read_file(args.input)
except Exception as e:
logger.critical("Error while reading input image: {}".format(e))

if input is None:
logger.critical("Input image is empty")
else:
logger.info("Input image loaded")

input = tf.image.decode_jpeg(input)
input = tf.image.resize(input, [256, 256])
input = (input / 127.5) - 1 # Normalize the images to [-1, 1]
input = tf.expand_dims(input, axis=0)
input = tf.cast(input, tf.float32)

output = generator(input, training=True)
output = tf.cast(output[0], tf.uint8)
output = tf.image.encode_jpeg(output)

try:
tf.io.write_file("results/output.jpg", output)
except Exception as e:
logger.critical("Error while saving utput image: {}".format(e))


def parse_args():
pass # TODO
parser = argparse.ArgumentParser()
parser.add_argument("--weights", type=str, required=True)
parser.add_argument("--input", type=str, required=True)
args = parser.parse_args()
return args


if __name__ == "__main__":
Expand Down

0 comments on commit 94ec280

Please sign in to comment.