From a6b4c0ff2c9427fcdda4a98827376855a24a363f Mon Sep 17 00:00:00 2001 From: bojunliu0818 Date: Thu, 20 Jun 2024 22:55:53 -0500 Subject: [PATCH] add descriptions of functions --- tsdart/loss.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tsdart/loss.py b/tsdart/loss.py index 498a3d2..050a52f 100644 --- a/tsdart/loss.py +++ b/tsdart/loss.py @@ -31,6 +31,18 @@ def __init__(self, epsilon=1e-6, mode='regularize', symmetrized=False): self._symmetrized = symmetrized def forward(self, data): + """ Compute VAMP2 loss. + + Parameters + ---------- + data : tuple + Softmax probabilities of batch of transition pairs. + + Returns + ------- + VAMP2 loss + """ + assert len(data) == 2 koopman = estimate_koopman_matrix(data[0], data[1], epsilon=self._epsilon, mode=self._mode, symmetrized=self._symmetrized) @@ -86,6 +98,21 @@ def __init__(self, feat_dim, n_states, device, proto_update_factor=0.5, scaling_ self.scaling_temperature = scaling_temperature def forward(self, features, labels): + """ Compute dispersion loss. + + Parameters + ---------- + features : torch.Tensor + Hyperspherical embeddings of a batch of data. + + labels : torch.Tensor + Metastable states of a batch of data. + + Returns + ------- + loss : torch.Tensor + Dispersion loss + """ prototypes = self.prototypes.to(device=self.device) for i in range(len(labels)): @@ -143,6 +170,22 @@ def __init__(self, n_states, device, scaling_temperature=0.1): self.scaling_temperature = scaling_temperature def forward(self, features, labels): + """ Compute dispersion loss. + + Parameters + ---------- + features : torch.Tensor + Hyperspherical embeddings of a batch of data. + + labels : torch.Tensor + Metastable states of a batch of data. + + Returns + ------- + prototypes : torch.Tensor + State center vectors of shape [n_states, feat_dim]. + """ + with torch.no_grad(): proxy_labels = torch.arange(0, self.n_states).to(device=self.device) labels = labels.contiguous().view(-1, 1)