Skip to content

Commit

Permalink
Merge pull request #14 from ULAS-HiPR/develop
Browse files Browse the repository at this point in the history
tests and strengthened conversions
  • Loading branch information
bxrne authored Feb 28, 2024
2 parents f8b91b9 + 020aa97 commit 8c0a06f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
52 changes: 52 additions & 0 deletions agrinet/tests/testMetrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import sys
import os

# fixes "ModuleNotFoundError: No module named 'utils'"
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

# flake8: noqa
import unittest
import numpy as np
import pandas as pd
from utils.Metrics import CGANMetrics


class TestCGANMetrics(unittest.TestCase):
def setUp(self):
self.cgan_metrics = CGANMetrics()

def tearDown(self):
del self.cgan_metrics

def test_psnr(self):
x = np.random.randint(0, 255, size=(256, 256, 3)).astype(np.float32)
y = np.random.randint(0, 255, size=(256, 256, 3)).astype(np.float32)
psnr_score = self.cgan_metrics.psnr(x, y)
self.assertIsInstance(psnr_score, float)

def test_mmd(self):
x = np.random.randn(16, 256, 256, 3).astype(np.float32)
y = np.random.randn(16, 256, 256, 3).astype(np.float32)
mmd_value = self.cgan_metrics.mmd(x, y)
self.assertIsInstance(mmd_value, float)

def test_update(self):
disc_out = np.random.randn(16, 256, 256, 3).astype(np.float32)
gen_out = np.random.randn(16, 256, 256, 3).astype(np.float32)
truth_in = np.random.randn(16, 256, 256, 3).astype(np.float32)
truth_out = np.random.randn(16, 256, 256, 3).astype(np.float32)
self.cgan_metrics.update(disc_out, gen_out, truth_in, truth_out)
self.assertEqual(len(self.cgan_metrics.results), 1)

def test_get_metric(self):
self.cgan_metrics.results = pd.DataFrame(
{"MMD": [0.5, 0.6, 0.7], "PSNR": [20, 25, 30]}
)
mmd_mean = self.cgan_metrics.get_metric("MMD")
psnr_mean = self.cgan_metrics.get_metric("PSNR")
self.assertAlmostEqual(mmd_mean, 0.6, places=2)
self.assertAlmostEqual(psnr_mean, 25, places=2)


if __name__ == "__main__":
unittest.main()
8 changes: 4 additions & 4 deletions agrinet/utils/Metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def update(self, disc_out, gen_out, truth_in, truth_out):

def psnr(self, x, y):
"""Compute the Peak Signal to Noise Ratio between two images"""
x = x.numpy()
y = y.numpy()
x = x.numpy() if hasattr(x, "numpy") else x
y = y.numpy() if hasattr(y, "numpy") else y
mse = np.mean((x - y) ** 2)
return 20 * np.log10(255) - 10 * np.log10(mse)

def mmd(self, x, y):
"""Compute the Maximum Mean Discrepancy between two sets of samples"""
x = x.numpy()
y = y.numpy()
x = x.numpy() if hasattr(x, "numpy") else x
y = y.numpy() if hasattr(y, "numpy") else y
x_flat = np.reshape(x, (x.shape[0], -1))
y_flat = np.reshape(y, (y.shape[0], -1))

Expand Down

0 comments on commit 8c0a06f

Please sign in to comment.