Skip to content

Commit

Permalink
Merge branch 'develop' into agTest
Browse files Browse the repository at this point in the history
  • Loading branch information
emi05h authored Feb 1, 2024
2 parents 435bbda + acd9ccf commit 45090d2
Show file tree
Hide file tree
Showing 13 changed files with 709 additions and 35 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/agrinet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: AgriNet CI

on:
push:
branches:
- main
- develop

pull_request:
branches:
- main
- develop

jobs:
runtests:
name: pip
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Run tests
run: python -m unittest discover -q agrinet/tests
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# AI-utils

[![AgriNet CI](https://github.com/ULAS-HiPR/AI-utils/actions/workflows/agrinet.yml/badge.svg)](https://github.com/ULAS-HiPR/AI-utils/actions/workflows/agrinet.yml)

Central repository for all AI/ML scripts, notebooks, and experiments.

## Getting Started
Expand All @@ -15,6 +17,11 @@ conda env create -f environment.yml # if you don't have the conda env created
conda activate ai-utils
```

```bash
# if you need to use pip
pip install -r requirements.txt
```

_Note: Conda is used because it let's use pip packages also and we can control the python version. It is a bit more powerful and is easier to work with in the long run. Install either [anaconda](https://www.anaconda.com/download/) or [miniconda](https://docs.conda.io/projects/miniconda/en/latest/miniconda-install.html) to use it._

## Directory Structure
Expand Down
52 changes: 52 additions & 0 deletions agrinet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# AgriNet

[![AgriNet CI](https://github.com/ULAS-HiPR/AI-utils/actions/workflows/agrinet.yml/badge.svg)](https://github.com/ULAS-HiPR/AI-utils/actions/workflows/agrinet.yml)

A RGB to NIR image translation model for agricultural aerial evaluation on vegetation and moisture data.

## Datasets

nirscene0 - 477 images
multiple scenes ranging in conditions, this was used as it is a good balance for generalisation.

## Tools

### Training

CLI tool available via `python train.py --help` for more information.

```man
usage: train.py [-h] --name NAME --data_dir DATA_DIR --batch_size BATCH_SIZE
--epochs EPOCHS [--lr LR] [--ext EXT] [--seed SEED]
optional arguments:
-h, --help show this help message and exit
--name NAME Name of experiment, used for logging and saving
checkpoints and weights in one directory
--data_dir DATA_DIR Path to data directory. must contain train and test
folders with images (see --ext)
--batch_size BATCH_SIZE
Batch size for training
--epochs EPOCHS Number of epochs to train
--lr LR Learning rate for training
--ext EXT Extension of the images
--seed SEED Random seed for training
```

### Tensorboard

For monitoring loss during training and viewing input and output space

```bash
agrinet $ tensorboard --logdir={NAME}/logs
```

**Example view**

![TB example view](../assets/image.png)

### Unit tests

```bash
agrinet $ python -m unittest discover -q tests
```
Empty file added agrinet/__init__.py
Empty file.
Empty file added agrinet/tests/__init__.py
Empty file.
57 changes: 57 additions & 0 deletions agrinet/tests/testDataLoader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import sys
import os

# fixes "ModuleNotFoundError: No module named 'utils'"
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

# flake8: noqa
import unittest
import tensorflow as tf
from utils.DataLoader import load_image_test, load_image_train, resize


class TestImageDataFunctions(unittest.TestCase):
def setUp(self):
# Set up any necessary resources or configurations for tests
pass

def tearDown(self):
# Clean up after tests
pass

def test_load_image_train(self):
# Condition: input_image.shape != real_image.shape from the train dataset
image_file = "assets/green-field.jpg"
input_image, real_image = load_image_train(image_file)

self.assertIsInstance(input_image, tf.Tensor)
self.assertIsInstance(real_image, tf.Tensor)
self.assertEqual(input_image.shape, real_image.shape)
self.assertEqual(input_image.shape, (256, 256, 3))

def test_load_image_test(self):
# Condition: input_image.shape == real_image.shape from the test dataset
image_file = "assets/green-field.jpg"

input_image, real_image = load_image_test(image_file)

self.assertIsInstance(input_image, tf.Tensor)
self.assertIsInstance(real_image, tf.Tensor)
self.assertEqual(input_image.shape, real_image.shape)
self.assertEqual(input_image.shape, (256, 256, 3))

def test_resize(self):
# Condition: input_image.shape == real_image.shape
input_image = tf.constant([[[1, 2, 3], [4, 5, 6]]], dtype=tf.float32)
real_image = tf.constant([[[7, 8, 9], [10, 11, 12]]], dtype=tf.float32)
height, width = 128, 128
resized_input, resized_real = resize(input_image, real_image, height, width)

self.assertIsInstance(resized_input, tf.Tensor)
self.assertIsInstance(resized_real, tf.Tensor)
self.assertEqual(resized_input.shape, resized_real.shape)
self.assertEqual(resized_input.shape, (128, 128, 3))


if __name__ == "__main__":
unittest.main()
39 changes: 39 additions & 0 deletions agrinet/tests/testLogManager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import sys
import os

# fixes "ModuleNotFoundError: No module named 'utils'"
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

# flake8: noqa
import unittest
from io import StringIO
from unittest.mock import patch
from utils.LogManager import LogManager


class TestLogManager(unittest.TestCase):
def setUp(self):
# Redirect stdout to capture log messages
self.mock_stdout = StringIO()
patch("sys.stdout", self.mock_stdout).start()

def tearDown(self):
# Clean up and restore stdout
patch.stopall()

def test_singleton_instance(self):
with self.assertRaises(Exception) as context:
log_manager1 = LogManager()
log_manager2 = LogManager()
del log_manager1, log_manager2

self.assertEqual(str(context.exception), "This class is a singleton!")

def test_get_logger(self):
log_manager = LogManager()
logger = log_manager.get_logger("test_logger")
self.assertEqual(logger.name, "utils.LogManager")


if __name__ == "__main__":
unittest.main()
58 changes: 40 additions & 18 deletions agrinet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import time

import tensorflow as tf
from utils.DataLoader import load_image_train, load_image_test
from tensorflow.summary import create_file_writer
from utils.DataLoader import load_image_test, load_image_train
from utils.LogManager import LogManager
from utils.Model import (
Discriminator,
Expand All @@ -20,6 +21,10 @@
def main(args):
logger = LogManager.get_logger("AGRINET TRAIN")

# training logs
log_dir = os.path.join("./", args.name, "logs")
summary_writer = create_file_writer(log_dir)

# Gathering training
logger.info("Building data pipeline...")

Expand Down Expand Up @@ -51,7 +56,7 @@ def train_step(input_image, target, step):
)

gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(
disc_generated_output, gen_output, target
disc_generated_output, gen_output, target, args.lr
)
disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

Expand All @@ -74,12 +79,17 @@ def train_step(input_image, target, step):
txt_gen_l1_loss = tf.convert_to_tensor(gen_l1_loss)
txt_disc_loss = tf.convert_to_tensor(disc_loss)

tf.summary.scalar("gen_total_loss", txt_gen_total_loss, step=step)
tf.summary.scalar("gen_gan_loss", txt_gen_gan_loss, step=step)
tf.summary.scalar("gen_l1_loss", txt_gen_l1_loss, step=step)
tf.summary.scalar("disc_loss", txt_disc_loss, step=step)
with summary_writer.as_default():
tf.summary.scalar("gen_total_loss", txt_gen_total_loss, step=step)
tf.summary.scalar("gen_gan_loss", txt_gen_gan_loss, step=step)
tf.summary.scalar("gen_l1_loss", txt_gen_l1_loss, step=step)
tf.summary.scalar("disc_loss", txt_disc_loss, step=step)

tf.summary.image("input_image", input_image, step=step)
tf.summary.image("target", target, step=step)
tf.summary.image("gen_output", gen_output, step=step)

checkpoint_dir = "./training_checkpoints"
checkpoint_dir = f"./{args.name}/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
generator_optimizer=generator_optimizer,
Expand Down Expand Up @@ -114,30 +124,42 @@ def fit(train_ds, test_ds, steps):

# day-month-year-hour-minute
filetime = time.strftime("%d%m%y_%H%M")
tf.saved_model.save(generator, "./generator_{}".format(filetime))
tf.saved_model.save(discriminator, "./discriminator_{}".format(filetime))

try:
generator.save_weights("./generator_weights_{}".format(filetime))
discriminator.save_weights("./discriminator_weights_{}".format(filetime))

except Exception as e:
logger.error("Error while saving weights : {}".format(e))
tf.saved_model.save(generator, "./{}/generator_{}".format(args.name, filetime))
tf.saved_model.save(
discriminator, "./{}/discriminator_{}".format(args.name, filetime)
)

logger.debug(
"Model and weights saved at {} and {} respectively".format(
"./generator_{} ".format(filetime), " ./discriminator_{}".format(filetime)
"./{}/generator_{} ".format(args.name, filetime),
" ./{}/discriminator_{}".format(args.name, filetime),
)
)

try:
generator.save_weights("./{}/generator_weights_{}".format(args.name, filetime))
discriminator.save_weights(
"./{}/discriminator_weights_{}".format(args.name, filetime)
)

except Exception as e:
logger.error("Error while saving weights : {}".format(e))


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--name",
type=str,
help="Name of experiment, used for logging and saving checkpoints and weights",
required=True,
)
parser.add_argument(
"--data_dir",
type=str,
default="data",
help="Path to data directory",
help="Path to data directory. must contain train and test folders with images",
required=True,
)
parser.add_argument(
Expand All @@ -155,7 +177,7 @@ def parse_args():
required=True,
)
parser.add_argument(
"--lr", type=float, default=1e-3, help="Learning rate for training"
"--lr", type=float, default=100, help="Learning rate for training"
)
parser.add_argument(
"--ext", type=str, default="jpg", help="Extension of the images"
Expand Down
43 changes: 30 additions & 13 deletions agrinet/utils/LogManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,8 @@ class LogManager:
def get_logger(name=None):
if LogManager.__instance is None:
LogManager(name)
return LogManager.__instance.log

def __init__(self, name=None):
if LogManager.__instance is not None:
raise Exception("This class is a singleton!")
else:
LogManager.__instance = self
self.log = logging.getLogger(name if name else __name__)
self.log.setLevel(logging.DEBUG)

self.log.handlers.clear() # Clear existing handlers

elif not LogManager.__instance.log.handlers:
# Add a new handler only if no handlers are present
formatter = colorlog.ColoredFormatter(
"%(asctime)s [%(name)s] %(log_color)s%(levelname)s%(reset)s - %(message)s",
log_colors={
Expand All @@ -33,7 +23,34 @@ def __init__(self, name=None):
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
self.log.addHandler(handler)
LogManager.__instance.log.addHandler(handler)

return LogManager.__instance.log

def __init__(self, name=None):
if LogManager.__instance is not None:
raise Exception("This class is a singleton!")
else:
LogManager.__instance = self
self.log = logging.getLogger(name if name else __name__)

self.log.setLevel(logging.DEBUG)

# Only set up handlers if they are not already configured
if not self.log.handlers:
formatter = colorlog.ColoredFormatter(
"%(asctime)s [%(name)s] %(log_color)s%(levelname)s%(reset)s - %(message)s",
log_colors={
"DEBUG": "cyan",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "red,bg_white",
},
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
self.log.addHandler(handler)

def set_level(self, level):
self.log.setLevel(level)
Loading

0 comments on commit 45090d2

Please sign in to comment.