Skip to content

Commit

Permalink
Add hausdorff metric
Browse files Browse the repository at this point in the history
Signed-off-by: aapozd <[email protected]>
  • Loading branch information
aaletov committed May 17, 2024
1 parent 9ac7751 commit 7acae03
Showing 1 changed file with 63 additions and 5 deletions.
68 changes: 63 additions & 5 deletions open_pcc_metric/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,39 @@ def __init__(
n=0,
)

@staticmethod
def scale_rgb(cloud: o3d.geometry.PointCloud):
scaled_colors = np.apply_along_axis(
func1d=lambda c: 255 * c,
axis=1,
arr=cloud.colors
)
cloud.colors = o3d.utility.Vector3dVector(scaled_colors)

@staticmethod
def convert_cloud_to_yuv(cloud: o3d.geometry.PointCloud):
"""Helper function to convert RGB to YUV (BT.709 or YCoCg-R)
"""
# only rgb supported
transform = np.array([
[0.25, 0.5, 0.25],
[1, 0, -1],
[-0.5, 1, -0.5]
])
def converter(c: np.ndarray) -> np.ndarray:
yuv_c = np.matmul(transform, c)
# offset 2^8 Co and Cg
yuv_c[1] += 256
yuv_c[2] += 256
return yuv_c

converted_colors = np.apply_along_axis(
func1d=converter,
axis=1,
arr=cloud.colors
)
cloud.colors = o3d.utility.Vector3dVector(converted_colors)

@staticmethod
def get_neighbour(
point: np.ndarray,
Expand Down Expand Up @@ -84,7 +117,7 @@ class PrimaryMetric(AbstractMetric):
def calculate(self, cloud_pair: CloudPair):
raise NotImplementedError("calculate is not implmented")

class OrderedMetric(PrimaryMetric):
class DirectionalMetric(PrimaryMetric):
is_left: bool

def __init__(self, is_left: bool):
Expand Down Expand Up @@ -141,7 +174,7 @@ def calculate(self, metrics: typing.List[AbstractMetric]) -> bool:
self.value = boundary_metric.value[1]
return True

class GeoMSE(OrderedMetric):
class GeoMSE(DirectionalMetric):
label = "GeoMSE"

def calculate(self, cloud_pair: CloudPair):
Expand All @@ -153,7 +186,7 @@ def calculate(self, cloud_pair: CloudPair):
n = cloud_pair._origin_neigh_dists.shape[0]
self.value = sse / n

class GeoPSNR(SecondaryMetric, OrderedMetric):
class GeoPSNR(SecondaryMetric, DirectionalMetric):
label = "GeoPSNR"

def calculate(self, metrics: typing.List[AbstractMetric]) -> bool:
Expand All @@ -170,7 +203,7 @@ def calculate(self, metrics: typing.List[AbstractMetric]) -> bool:
self.value = 10 * np.log10(max_neigh_dist**2 / geo_mse.value)
return True

class ColorPSNR(SecondaryMetric, OrderedMetric):
class ColorPSNR(SecondaryMetric, DirectionalMetric):
label = "ColorPSNR"

def calculate(self, metrics: typing.List[AbstractMetric]) -> bool:
Expand All @@ -182,7 +215,7 @@ def calculate(self, metrics: typing.List[AbstractMetric]) -> bool:
self.value = 10 * np.log10(peak**2 / geo_mse.value)
return True

class ColorMSE(OrderedMetric):
class ColorMSE(DirectionalMetric):
label = "ColorMSE"

def calculate(self, cloud_pair: CloudPair):
Expand All @@ -193,6 +226,27 @@ def calculate(self, cloud_pair: CloudPair):
diff = np.subtract(cloud_pair.reconst_cloud.colors, cloud_pair._reconst_neigh_cloud.colors)
self.value = np.mean(diff**2, axis=0)

class GeoHausdorffDistance(DirectionalMetric):
label = "GeoHausdorfDistance"

def calculate(self, cloud_pair: CloudPair):
if self.is_left:
self.value = np.max(cloud_pair._origin_neigh_dists, axis=0)
else:
self.value = np.max(cloud_pair._reconst_neigh_dists, axis=0)

class ColorHausdorffDistance(DirectionalMetric):
label = "ColorHausdorffDistance"

def calculate(self, cloud_pair: CloudPair):
diff = None
if self.is_left:
diff = np.subtract(cloud_pair.origin_cloud.colors, cloud_pair._origin_neigh_cloud.colors)
else:
diff = np.subtract(cloud_pair.reconst_cloud.colors, cloud_pair._reconst_neigh_cloud.colors)
rgb_scale = 255
self.value = np.max((rgb_scale * diff)**2, axis=0)

class SymmetricMetric(SecondaryMetric):
is_proportional: bool
target_label: str
Expand Down Expand Up @@ -267,6 +321,10 @@ def calculate_from_files(
GeoMSE(is_left=False),
ColorMSE(is_left=True),
ColorMSE(is_left=False),
GeoHausdorffDistance(is_left=True),
GeoHausdorffDistance(is_left=False),
ColorHausdorffDistance(is_left=True),
ColorHausdorffDistance(is_left=False),
]
secondary_metrics = [
MinSqrtDistance(),
Expand Down

0 comments on commit 7acae03

Please sign in to comment.