Skip to content

Commit

Permalink
Bring back old hinge loss implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret committed Mar 12, 2024
1 parent 4cad94b commit 5f33f21
Showing 1 changed file with 96 additions and 0 deletions.
96 changes: 96 additions & 0 deletions src/gnn_tracking/metrics/losses/metric_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin
from torch import Tensor as T
from torch.linalg import norm
from torch.nn.functional import relu
from torch_cluster import radius_graph

from gnn_tracking.metrics.losses import MultiLossFct, MultiLossFctReturn
Expand Down Expand Up @@ -175,3 +176,98 @@ def forward(
weight_dct=weights,
extra_metrics=extra,
)


@torch.jit.script
def _old_hinge_loss_components(
*,
x: T,
edge_index: T,
particle_id: T,
pt: T,
r_emb_hinge: float,
pt_thld: float,
p_attr: float,
p_rep: float,
) -> tuple[T, T]:
true_edge = (particle_id[edge_index[0]] == particle_id[edge_index[1]]) & (
particle_id[edge_index[0]] > 0
)
true_high_pt_edge = true_edge & (pt[edge_index[0]] > pt_thld)
dists = norm(x[edge_index[0]] - x[edge_index[1]], dim=-1)
normalization = true_high_pt_edge.sum() + 1e-8
return torch.sum(
torch.pow(dists[true_high_pt_edge], p_attr)
) / normalization, torch.sum(
relu(r_emb_hinge - torch.pow(dists[~true_edge], p_rep)) / normalization
)


class OldGraphConstructionHingeEmbeddingLoss(MultiLossFct, HyperparametersMixin):
# noinspection PyUnusedLocal
def __init__(
self,
*,
r_emb=1,
max_num_neighbors: int = 256,
attr_pt_thld: float = 0.9,
p_attr: float = 1,
p_rep: float = 1,
lw_repulsive: float = 1.0,
):
"""Loss for graph construction using metric learning.
Args:
r_emb: Radius for edge construction
max_num_neighbors: Maximum number of neighbors in radius graph building.
See https://github.com/rusty1s/pytorch_cluster#radius-graph
p_attr: Power for the attraction term (default 1: linear loss)
p_rep: Power for the repulsion term (default 1: linear loss)
"""
super().__init__()
self.save_hyperparameters()

def _build_graph(self, x: T, batch: T, true_edge_index: T, pt: T) -> T:
true_edge_mask = pt[true_edge_index[0]] > self.hparams.attr_pt_thld
near_edges = radius_graph(
x,
r=self.hparams.r_emb,
batch=batch,
loop=False,
max_num_neighbors=self.hparams.max_num_neighbors,
)
return torch.unique(
torch.cat([true_edge_index[:, true_edge_mask], near_edges], dim=-1), dim=-1
)

# noinspection PyUnusedLocal
def forward(
self, *, x: T, particle_id: T, batch: T, true_edge_index: T, pt: T, **kwargs
) -> dict[str, T]:
edge_index = self._build_graph(
x=x, batch=batch, true_edge_index=true_edge_index, pt=pt
)
attr, rep = _old_hinge_loss_components(
x=x,
edge_index=edge_index,
particle_id=particle_id,
r_emb_hinge=self.hparams.r_emb,
pt=pt,
pt_thld=self.hparams.attr_pt_thld,
p_attr=self.hparams.p_attr,
p_rep=self.hparams.p_rep,
)
losses = {
"attractive": attr,
"repulsive": rep,
}
weights: dict[str, float] = {
"attractive": 1.0,
"repulsive": self.hparams.lw_repulsive,
}
extra = {}
return MultiLossFctReturn(
loss_dct=losses,
weight_dct=weights,
extra_metrics=extra,
)

0 comments on commit 5f33f21

Please sign in to comment.