Commit

Update penalty.py
Vermeille authored Dec 16, 2024
1 parent 2c15a99 commit 43028e1
Showing 1 changed file with 49 additions and 0 deletions.
torchelie/loss/gan/penalty.py (49 additions, 0 deletions)
@@ -64,6 +64,8 @@ def R1(model,
R1 regularizer from Which Training Methods for GANs do actually Converge?
(https://arxiv.org/abs/1801.04406).
It pushes the gradient norm on real data towards 0.
Args:
model (function / nn.Module): the model
real (torch.Tensor): real images
@@ -78,6 +80,53 @@ def R1(model,
"""
return gradient_penalty(model, real, 0., amp_scaler)
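
For reference, the penalty being implemented here is, from the cited paper,

    R_1(\psi) = \frac{\gamma}{2}\,\mathbb{E}_{p_D(x)}\!\left[\lVert \nabla D_\psi(x) \rVert^2\right]

and R2 below is the same expectation taken over the generator's distribution
instead of the data distribution. Whether the gamma/2 weighting lives inside
gradient_penalty or is left to the caller is not visible in this diff.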

def R2(model,
real: torch.Tensor,
fake: torch.Tensor,
amp_scaler=None) -> Tuple[torch.Tensor, float]:
"""
R2 regularizer from Which Training Methods for GANs do actually Converge?
(https://arxiv.org/abs/1801.04406).
It pushes the gradient norm on fake data towards 0.
Args:
model (function / nn.Module): the model
real (torch.Tensor): unused. Here for interface consistency with other
penalties.
fake (torch.Tensor): fake images
amp_scaler (torch.cuda.amp.GradScaler): if specified, will be used
for computing the loss in fp16. Otherwise, use model's and
data's dtype.
Returns:
A tuple (loss, gradient norm)
"""
return gradient_penalty(model, fake, 0., amp_scaler)

def R3(model,
real: torch.Tensor,
fake: torch.Tensor,
amp_scaler=None) -> Tuple[torch.Tensor, float]:
"""
R3 regularizer (R1 + R2 combined) from The GAN is dead; long live the GAN!
A Modern Baseline GAN (https://openreview.net/pdf?id=OrtN9hPP7V).
It pushes the gradient norm towards 0 on both real and fake data.
Args:
model (function / nn.Module): the model
real (torch.Tensor): real images
fake (torch.Tensor): fake images
amp_scaler (torch.cuda.amp.GradScaler): if specified, will be used
for computing the loss in fp16. Otherwise, use model's and
data's dtype.
Returns:
A tuple (loss, gradient norm)
"""
return gradient_penalty(model, torch.cat([real, fake], dim=0), 0., amp_scaler)
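
Because R3 runs the penalty on the concatenation of the real and fake batches,
with equal batch sizes it averages what R1 and R2 would compute separately. A
minimal usage sketch, assuming R1 shares the (model, real, fake, amp_scaler)
signature shown for R2 and R3; the toy critic, data, and gamma weight are
stand-ins, not part of this commit:

    import torch
    import torch.nn as nn
    from torchelie.loss.gan.penalty import R1, R2, R3

    D = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))  # toy critic
    real = torch.randn(4, 3, 8, 8)
    fake = torch.randn(4, 3, 8, 8)

    r1_loss, r1_norm = R1(D, real, fake)  # penalizes grad norm on real only
    r2_loss, r2_norm = R2(D, real, fake)  # penalizes grad norm on fake only
    r3_loss, r3_norm = R3(D, real, fake)  # both, via one concatenated batch

    gamma = 10.0  # penalty weight; a placeholder to tune per task
    d_loss_term = gamma * r3_loss  # added to the adversarial loss in a real step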

def gradient_penalty(model,
data: torch.Tensor,
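
The body of gradient_penalty is collapsed in this diff. As a rough guide, here
is a minimal sketch of a zero-centered gradient penalty consistent with how
R1/R2/R3 call it above (model, data, target norm, optional GradScaler);
everything below is an assumption, not torchelie's actual implementation:

    import torch
    from typing import Tuple

    def gradient_penalty_sketch(model,
                                data: torch.Tensor,
                                objective_norm: float = 0.,
                                amp_scaler=None) -> Tuple[torch.Tensor, float]:
        data = data.detach().requires_grad_(True)
        out = model(data).sum()
        if amp_scaler is not None:
            out = amp_scaler.scale(out)  # keep fp16 gradients from underflowing
        g, = torch.autograd.grad(out, data, create_graph=True)
        if amp_scaler is not None:
            g = g / amp_scaler.get_scale()  # unscale before measuring the norm
        norms = g.flatten(1).norm(2, dim=1)  # per-sample gradient L2 norm
        loss = (norms - objective_norm).pow(2).mean()
        return loss, norms.mean().item()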
