-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathmodules.py
More file actions
142 lines (106 loc) · 4.47 KB
/
modules.py
File metadata and controls
142 lines (106 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d
from torch.nn.utils import weight_norm, spectral_norm
from utils import init_weights, get_padding
import numpy as np
from scipy import signal as sig
class CoMBD(torch.nn.Module):
    """Collaborative Multi-Band Discriminator block (Avocodo-style).

    Applies a stack of grouped, strided 1-D convolutions to a waveform and
    returns the flattened post-conv score together with the intermediate
    feature maps (the latter are typically consumed by a feature-matching
    loss).

    Args:
        filters: output channel count of each conv layer.
        kernels: kernel size of each conv layer.
        groups: group count of each conv layer.
        strides: stride of each conv layer.
        use_spectral_norm: if True use spectral norm, else weight norm.
    """

    def __init__(self, filters, kernels, groups, strides, use_spectral_norm=False):
        super(CoMBD, self).__init__()
        # Idiomatic boolean test (was `use_spectral_norm == False`).
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList()
        in_channels = 1  # raw mono waveform input
        for f, k, g, s in zip(filters, kernels, groups, strides):
            self.convs.append(
                norm_f(Conv1d(in_channels, f, k, s, padding=get_padding(k, 1), groups=g)))
            in_channels = f
        self.conv_post = norm_f(Conv1d(filters[-1], 1, 3, 1, padding=get_padding(3, 1)))

    def forward(self, x):
        """Return ``(score, fmap)``: flattened discriminator output and the
        leaky-ReLU feature map of every conv layer (post conv excluded)."""
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            fmap.append(x)
        x = self.conv_post(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap
class MDC(torch.nn.Module):
    """Multi-scale Dilated Convolution block.

    Runs several dilated 1-D convolutions in parallel over the same input,
    averages their outputs, then applies a strided output conv followed by
    leaky-ReLU.

    Args:
        in_channel: input channel count.
        channel: channel count of every dilated branch and of the output.
        kernel: kernel size of the dilated branches.
        stride: stride of the output conv (temporal downsampling factor).
        dilations: iterable of dilation rates, one branch per rate.
        use_spectral_norm: if True use spectral norm, else weight norm.
    """

    def __init__(self, in_channel, channel, kernel, stride, dilations, use_spectral_norm=False):
        super(MDC, self).__init__()
        # Idiomatic boolean test (was `use_spectral_norm == False`).
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = torch.nn.ModuleList()
        self.num_dilations = len(dilations)
        for d in dilations:
            # "same"-style padding per dilation so all branch outputs align.
            self.convs.append(norm_f(Conv1d(in_channel, channel, kernel, stride=1,
                                            padding=get_padding(kernel, d),
                                            dilation=d)))
        self.conv_out = norm_f(Conv1d(channel, channel, 3, stride=stride,
                                      padding=get_padding(3, 1)))

    def forward(self, x):
        """Average the parallel dilated branches, downsample, apply leaky-ReLU."""
        branch_sum = None
        for conv in self.convs:
            branch_sum = conv(x) if branch_sum is None else branch_sum + conv(x)
        out = self.conv_out(branch_sum / self.num_dilations)
        return F.leaky_relu(out, 0.1)
class SubBandDiscriminator(torch.nn.Module):
    """Sub-band discriminator: a stack of MDC blocks plus a 1-channel post conv.

    Args:
        init_channel: channel count of the input signal (e.g. PQMF sub-bands).
        channels: output channels of each successive MDC block.
        kernel: kernel size shared by all MDC blocks.
        strides: stride of each MDC block.
        dilations: dilation-rate list for each MDC block.
        use_spectral_norm: if True use spectral norm, else weight norm.
    """

    def __init__(self, init_channel, channels, kernel, strides, dilations, use_spectral_norm=False):
        super(SubBandDiscriminator, self).__init__()
        # Idiomatic boolean test (was `use_spectral_norm == False`).
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.mdcs = torch.nn.ModuleList()
        for c, s, d in zip(channels, strides, dilations):
            # BUGFIX: propagate use_spectral_norm to the MDC sub-blocks;
            # previously they always fell back to weight norm regardless
            # of the flag.
            self.mdcs.append(MDC(init_channel, c, kernel, s, d, use_spectral_norm))
            init_channel = c
        self.conv_post = norm_f(Conv1d(init_channel, 1, 3, padding=get_padding(3, 1)))

    def forward(self, x):
        """Return ``(score, fmap)``: flattened discriminator output and the
        output feature map of every MDC block (post conv excluded)."""
        fmap = []
        for mdc in self.mdcs:
            x = mdc(x)
            fmap.append(x)
        x = self.conv_post(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap
# adapted from
# https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan
class PQMF(torch.nn.Module):
    """Pseudo-QMF analysis/synthesis filter bank (near-perfect reconstruction).

    Splits a waveform into N sub-bands (``analysis``) and recombines them
    (``synthesis``) using a cosine-modulated Kaiser-windowed prototype
    low-pass filter.

    Args:
        N: number of sub-bands.
        taps: FIR prototype order (the filter has ``taps + 1`` coefficients).
        cutoff: normalized cutoff frequency of the prototype low-pass.
        beta: Kaiser window beta parameter.
    """

    def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0):
        super(PQMF, self).__init__()
        self.N = N
        self.taps = taps
        self.cutoff = cutoff
        self.beta = beta

        # Design the prototype low-pass, then cosine-modulate it into the
        # N analysis (H) and synthesis (G) filters.
        QMF = sig.firwin(taps + 1, cutoff, window=('kaiser', beta))
        H = np.zeros((N, len(QMF)))
        G = np.zeros((N, len(QMF)))
        for k in range(N):
            constant_factor = (2 * k + 1) * (np.pi / (2 * N)) * (
                np.arange(taps + 1) - ((taps - 1) / 2))  # TODO: (taps - 1) -> taps
            phase = (-1) ** k * np.pi / 4
            H[k] = 2 * QMF * np.cos(constant_factor + phase)
            G[k] = 2 * QMF * np.cos(constant_factor - phase)

        # Registered as buffers (not parameters): they follow .to()/.cuda()
        # and checkpointing but are never trained.
        self.register_buffer("H", torch.from_numpy(H[:, None, :]).float())
        self.register_buffer("G", torch.from_numpy(G[None, :, :]).float())

        # Identity up/down-sampling kernel for conv_transpose1d in synthesis.
        # (torch.zeros already yields float32; the old trailing .float() was
        # a no-op and has been dropped, as has a duplicate `self.N = N`.)
        updown_filter = torch.zeros((N, N, N))
        for k in range(N):
            updown_filter[k, k, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def forward(self, x):
        """Alias for :meth:`analysis`."""
        return self.analysis(x)

    def analysis(self, x):
        """Split a ``(B, 1, T)`` waveform into ``(B, N, T // N)`` sub-bands."""
        return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N)

    def synthesis(self, x):
        """Recombine ``(B, N, T)`` sub-bands into a ``(B, 1, T * N)`` waveform."""
        # Upsample each band by N (scaled by N to preserve amplitude), then
        # filter with the synthesis bank.
        x = F.conv_transpose1d(x,
                               self.updown_filter * self.N,
                               stride=self.N)
        return F.conv1d(x, self.G, padding=self.taps // 2)