@@ -21,6 +21,7 @@
 import torch
 import os, sys
 import logging
+from fuxictr.pytorch.layers import FeatureEmbeddingDict
 from fuxictr.metrics import evaluate_metrics
 from fuxictr.pytorch.torch_utils import get_device, get_optimizer, get_loss, get_regularizer
 from fuxictr.utils import Monitor, not_in_whitelist
@@ -65,23 +66,24 @@ def compile(self, optimizer, loss, lr):
         self.loss_fn = get_loss(loss)
 
     def regularization_loss(self):
-        reg_loss = 0
+        reg_term = 0
         if self._embedding_regularizer or self._net_regularizer:
             emb_reg = get_regularizer(self._embedding_regularizer)
             net_reg = get_regularizer(self._net_regularizer)
-            for _, module in self.named_modules():
-                for p_name, param in module.named_parameters():
-                    if param.requires_grad:
-                        if p_name in ["weight", "bias"]:
-                            if type(module) == nn.Embedding:
-                                if self._embedding_regularizer:
-                                    for emb_p, emb_lambda in emb_reg:
-                                        reg_loss += (emb_lambda / emb_p) * torch.norm(param, emb_p) ** emb_p
-                            else:
-                                if self._net_regularizer:
-                                    for net_p, net_lambda in net_reg:
-                                        reg_loss += (net_lambda / net_p) * torch.norm(param, net_p) ** net_p
-        return reg_loss
+            emb_params = set()
+            for m_name, module in self.named_modules():
+                if type(module) == FeatureEmbeddingDict:
+                    for p_name, param in module.named_parameters():
+                        if param.requires_grad:
+                            emb_params.add(".".join([m_name, p_name]))
+                            for emb_p, emb_lambda in emb_reg:
+                                reg_term += (emb_lambda / emb_p) * torch.norm(param, emb_p) ** emb_p
+            for name, param in self.named_parameters():
+                if param.requires_grad:
+                    if name not in emb_params:
+                        for net_p, net_lambda in net_reg:
+                            reg_term += (net_lambda / net_p) * torch.norm(param, net_p) ** net_p
+        return reg_term
 
     def add_loss(self, return_dict, y_true):
         loss = self.loss_fn(return_dict["y_pred"], y_true, reduction='mean')
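To sanity-check the new logic outside the class: each term adds a weighted L_p penalty `(lambda / p) * torch.norm(param, p) ** p`, so for p=2 the gradient contribution is simply `lambda * param`. Below is a minimal, self-contained sketch of the two-pass partition the diff introduces; `ToyEmbeddingDict` and `ToyModel` are hypothetical stand-ins for `FeatureEmbeddingDict` and a real model, and the `[(p, lambda)]` pair lists mirror the format the loops above expect `get_regularizer` to return.

```python
import torch
import torch.nn as nn

class ToyEmbeddingDict(nn.Module):
    """Hypothetical stand-in for fuxictr's FeatureEmbeddingDict."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(10, 4)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_layer = ToyEmbeddingDict()  # hit by emb_reg
        self.fc = nn.Linear(4, 1)                  # hit by net_reg

def regularization_loss(model, emb_reg, net_reg):
    # emb_reg / net_reg: lists of (p, lambda) pairs, e.g. [(2, 1e-5)],
    # mirroring what the loops in the diff assume get_regularizer returns.
    reg_term = 0
    emb_params = set()
    # Pass 1: penalize parameters owned by the embedding module and
    # remember their fully qualified names.
    for m_name, module in model.named_modules():
        if isinstance(module, ToyEmbeddingDict):
            for p_name, param in module.named_parameters():
                if param.requires_grad:
                    emb_params.add(".".join([m_name, p_name]))
                    for p, lam in emb_reg:
                        reg_term += (lam / p) * torch.norm(param, p) ** p
    # Pass 2: penalize every remaining trainable parameter.
    for name, param in model.named_parameters():
        if param.requires_grad and name not in emb_params:
            for p, lam in net_reg:
                reg_term += (lam / p) * torch.norm(param, p) ** p
    return reg_term

model = ToyModel()
print(regularization_loss(model, emb_reg=[(2, 1e-5)], net_reg=[(2, 1e-4)]))
```

Compared with the old loop, which only matched parameters literally named "weight" or "bias" and treated any `nn.Embedding` as an embedding table, keying on the owning `FeatureEmbeddingDict` module means every trainable parameter now falls under exactly one of the two regularizers.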