diff --git a/tsdart/loss.py b/tsdart/loss.py index 498a3d2..050a52f 100644 --- a/tsdart/loss.py +++ b/tsdart/loss.py @@ -31,6 +31,18 @@ def __init__(self, epsilon=1e-6, mode='regularize', symmetrized=False): self._symmetrized = symmetrized def forward(self, data): + """ Compute VAMP2 loss. + + Parameters + ---------- + data : tuple + Softmax probabilities of batch of transition pairs. + + Returns + ------- + VAMP2 loss + """ + assert len(data) == 2 koopman = estimate_koopman_matrix(data[0], data[1], epsilon=self._epsilon, mode=self._mode, symmetrized=self._symmetrized) @@ -86,6 +98,21 @@ def __init__(self, feat_dim, n_states, device, proto_update_factor=0.5, scaling_ self.scaling_temperature = scaling_temperature def forward(self, features, labels): + """ Compute dispersion loss. + + Parameters + ---------- + features : torch.Tensor + Hyperspherical embeddings of a batch of data. + + labels : torch.Tensor + Metastable states of a batch of data. + + Returns + ------- + loss : torch.Tensor + Dispersion loss + """ prototypes = self.prototypes.to(device=self.device) for i in range(len(labels)): @@ -143,6 +170,22 @@ def __init__(self, n_states, device, scaling_temperature=0.1): self.scaling_temperature = scaling_temperature def forward(self, features, labels): + """ Compute dispersion loss. + + Parameters + ---------- + features : torch.Tensor + Hyperspherical embeddings of a batch of data. + + labels : torch.Tensor + Metastable states of a batch of data. + + Returns + ------- + prototypes : torch.Tensor + State center vectors of shape [n_states, feat_dim]. + """ + with torch.no_grad(): proxy_labels = torch.arange(0, self.n_states).to(device=self.device) labels = labels.contiguous().view(-1, 1)