Skip to content

Commit

Permalink
refactor(wavernn): remove duplicate Stretch2d
Browse files Browse the repository at this point in the history
I checked that the implementations are the same
  • Loading branch information
eginhard committed Nov 22, 2024
1 parent b13e9dc commit 3a803a2
Showing 1 changed file with 1 addition and 13 deletions.
14 changes: 1 addition & 13 deletions TTS/vocoder/models/wavernn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from TTS.utils.audio.numpy_transforms import mulaw_decode
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.layers.upsample import Stretch2d
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian

Expand Down Expand Up @@ -66,19 +67,6 @@ def forward(self, x):
return x


class Stretch2d(nn.Module):
def __init__(self, x_scale, y_scale):
super().__init__()
self.x_scale = x_scale
self.y_scale = y_scale

def forward(self, x):
b, c, h, w = x.size()
x = x.unsqueeze(-1).unsqueeze(3)
x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
return x.view(b, c, h * self.y_scale, w * self.x_scale)


class UpsampleNetwork(nn.Module):
def __init__(
self,
Expand Down

0 comments on commit 3a803a2

Please sign in to comment.