Plachta commited on
Commit
a50ee15
β€’
1 Parent(s): a909977

Upload 5 files

Browse files
Files changed (4) hide show
  1. gradient_reversal.py +35 -0
  2. losses.py +309 -0
  3. meldataset.py +131 -0
  4. optimizers.py +108 -0
gradient_reversal.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch.autograd import Function
7
+ import torch
8
+ from torch import nn
9
+
10
+
11
+ class GradientReversal(Function):
12
+ @staticmethod
13
+ def forward(ctx, x, alpha):
14
+ ctx.save_for_backward(x, alpha)
15
+ return x
16
+
17
+ @staticmethod
18
+ def backward(ctx, grad_output):
19
+ grad_input = None
20
+ _, alpha = ctx.saved_tensors
21
+ if ctx.needs_input_grad[0]:
22
+ grad_input = -alpha * grad_output
23
+ return grad_input, None
24
+
25
+
26
+ revgrad = GradientReversal.apply
27
+
28
+
29
+ class GradientReversal(nn.Module):
30
+ def __init__(self, alpha):
31
+ super().__init__()
32
+ self.alpha = torch.tensor(alpha, requires_grad=False)
33
+
34
+ def forward(self, x):
35
+ return revgrad(x, self.alpha)
losses.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torchaudio.transforms import MelSpectrogram
4
+
5
+
6
+ def adversarial_g_loss(y_disc_gen):
7
+ """Hinge loss"""
8
+ loss = 0.0
9
+ for i in range(len(y_disc_gen)):
10
+ stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze()
11
+ loss += stft_loss
12
+ return loss / len(y_disc_gen)
13
+
14
+
15
+ def feature_loss(fmap_r, fmap_gen):
16
+ loss = 0.0
17
+ for i in range(len(fmap_r)):
18
+ for j in range(len(fmap_r[i])):
19
+ stft_loss = ((fmap_r[i][j] - fmap_gen[i][j]).abs() /
20
+ (fmap_r[i][j].abs().mean())).mean()
21
+ loss += stft_loss
22
+ return loss / (len(fmap_r) * len(fmap_r[0]))
23
+
24
+
25
+ def sim_loss(y_disc_r, y_disc_gen):
26
+ loss = 0.0
27
+ for i in range(len(y_disc_r)):
28
+ loss += F.mse_loss(y_disc_r[i], y_disc_gen[i])
29
+ return loss / len(y_disc_r)
30
+
31
+ # def sisnr_loss(x, s, eps=1e-8):
32
+ # """
33
+ # calculate training loss
34
+ # input:
35
+ # x: separated signal, N x S tensor, estimate value
36
+ # s: reference signal, N x S tensor, True value
37
+ # Return:
38
+ # sisnr: N tensor
39
+ # """
40
+ # if x.shape != s.shape:
41
+ # if x.shape[-1] > s.shape[-1]:
42
+ # x = x[:, :s.shape[-1]]
43
+ # else:
44
+ # s = s[:, :x.shape[-1]]
45
+ # def l2norm(mat, keepdim=False):
46
+ # return torch.norm(mat, dim=-1, keepdim=keepdim)
47
+ # if x.shape != s.shape:
48
+ # raise RuntimeError(
49
+ # "Dimention mismatch when calculate si-snr, {} vs {}".format(
50
+ # x.shape, s.shape))
51
+ # x_zm = x - torch.mean(x, dim=-1, keepdim=True)
52
+ # s_zm = s - torch.mean(s, dim=-1, keepdim=True)
53
+ # t = torch.sum(
54
+ # x_zm * s_zm, dim=-1,
55
+ # keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
56
+ # loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
57
+ # return torch.sum(loss) / x.shape[0]
58
+
59
+ LAMBDA_WAV = 100
60
+ LAMBDA_ADV = 1
61
+ LAMBDA_REC = 1
62
+ LAMBDA_COM = 1000
63
+ LAMBDA_FEAT = 1
64
+ discriminator_iter_start = 500
65
+ def reconstruction_loss(x, G_x, eps=1e-7):
66
+ # NOTE (lsx): hard-coded now
67
+ L = LAMBDA_WAV * F.mse_loss(x, G_x) # wav L1 loss
68
+ # loss_sisnr = sisnr_loss(G_x, x) #
69
+ # L += 0.01*loss_sisnr
70
+ # 2^6=64 -> 2^10=1024
71
+ # NOTE (lsx): add 2^11
72
+ for i in range(6, 12):
73
+ # for i in range(5, 12): # Encodec setting
74
+ s = 2**i
75
+ melspec = MelSpectrogram(
76
+ sample_rate=16000,
77
+ n_fft=max(s, 512),
78
+ win_length=s,
79
+ hop_length=s // 4,
80
+ n_mels=64,
81
+ wkwargs={"device": G_x.device}).to(G_x.device)
82
+ S_x = melspec(x)
83
+ S_G_x = melspec(G_x)
84
+ l1_loss = (S_x - S_G_x).abs().mean()
85
+ l2_loss = (((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps))**2).mean(dim=-2)**0.5).mean()
86
+
87
+ alpha = (s / 2) ** 0.5
88
+ L += (l1_loss + alpha * l2_loss)
89
+ return L
90
+
91
+
92
+ def criterion_d(y_disc_r, y_disc_gen, fmap_r_det, fmap_gen_det, y_df_hat_r,
93
+ y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g,
94
+ fmap_s_r, fmap_s_g):
95
+ """Hinge Loss"""
96
+ loss = 0.0
97
+ loss1 = 0.0
98
+ loss2 = 0.0
99
+ loss3 = 0.0
100
+ for i in range(len(y_disc_r)):
101
+ loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[
102
+ i]).mean()
103
+ for i in range(len(y_df_hat_r)):
104
+ loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[
105
+ i]).mean()
106
+ for i in range(len(y_ds_hat_r)):
107
+ loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[
108
+ i]).mean()
109
+
110
+ loss = (loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 /
111
+ len(y_ds_hat_r)) / 3.0
112
+
113
+ return loss
114
+
115
+
116
+ def criterion_g(commit_loss, x, G_x, fmap_r, fmap_gen, y_disc_r, y_disc_gen,
117
+ y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r,
118
+ y_ds_hat_g, fmap_s_r, fmap_s_g, args):
119
+ adv_g_loss = adversarial_g_loss(y_disc_gen)
120
+ feat_loss = (feature_loss(fmap_r, fmap_gen) + sim_loss(
121
+ y_disc_r, y_disc_gen) + feature_loss(fmap_f_r, fmap_f_g) + sim_loss(
122
+ y_df_hat_r, y_df_hat_g) + feature_loss(fmap_s_r, fmap_s_g) +
123
+ sim_loss(y_ds_hat_r, y_ds_hat_g)) / 3.0
124
+ rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)
125
+ total_loss = args.LAMBDA_COM * commit_loss + args.LAMBDA_ADV * adv_g_loss + args.LAMBDA_FEAT * feat_loss + args.LAMBDA_REC * rec_loss
126
+ return total_loss, adv_g_loss, feat_loss, rec_loss
127
+
128
+
129
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
130
+ if global_step < threshold:
131
+ weight = value
132
+ return weight
133
+
134
+
135
+ def adopt_dis_weight(weight, global_step, threshold=0, value=0.):
136
+ # 0,3,6,9,13....θΏ™δΊ›ζ—Άι—΄ζ­₯οΌŒδΈζ›΄ζ–°dis
137
+ if global_step % 3 == 0:
138
+ weight = value
139
+ return weight
140
+
141
+
142
+ def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
143
+ if last_layer is not None:
144
+ nll_grads = torch.autograd.grad(
145
+ nll_loss, last_layer, retain_graph=True)[0]
146
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
147
+ else:
148
+ print('last_layer cannot be none')
149
+ assert 1 == 2
150
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
151
+ d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
152
+ d_weight = d_weight * args.LAMBDA_ADV
153
+ return d_weight
154
+
155
+ def loss_g(codebook_loss,
156
+ inputs,
157
+ reconstructions,
158
+ fmap_r,
159
+ fmap_gen,
160
+ y_disc_r,
161
+ y_disc_gen,
162
+ global_step,
163
+ y_df_hat_r,
164
+ y_df_hat_g,
165
+ y_ds_hat_r,
166
+ y_ds_hat_g,
167
+ fmap_f_r,
168
+ fmap_f_g,
169
+ fmap_s_r,
170
+ fmap_s_g,
171
+ last_layer=None,
172
+ is_training=True,
173
+ args=None):
174
+ """
175
+ args:
176
+ codebook_loss: commit loss.
177
+ inputs: ground-truth wav.
178
+ reconstructions: reconstructed wav.
179
+ fmap_r: real stft-D feature map.
180
+ fmap_gen: fake stft-D feature map.
181
+ y_disc_r: real stft-D logits.
182
+ y_disc_gen: fake stft-D logits.
183
+ global_step: global training step.
184
+ y_df_hat_r: real MPD logits.
185
+ y_df_hat_g: fake MPD logits.
186
+ y_ds_hat_r: real MSD logits.
187
+ y_ds_hat_g: fake MSD logits.
188
+ fmap_f_r: real MPD feature map.
189
+ fmap_f_g: fake MPD feature map.
190
+ fmap_s_r: real MSD feature map.
191
+ fmap_s_g: fake MSD feature map.
192
+ """
193
+ rec_loss = reconstruction_loss(inputs.contiguous(),
194
+ reconstructions.contiguous())
195
+ adv_g_loss = adversarial_g_loss(y_disc_gen)
196
+ adv_mpd_loss = adversarial_g_loss(y_df_hat_g)
197
+ adv_msd_loss = adversarial_g_loss(y_ds_hat_g)
198
+ adv_loss = (adv_g_loss + adv_mpd_loss + adv_msd_loss
199
+ ) / 3.0 # NOTE(lsx): need to divide by 3?
200
+ feat_loss = feature_loss(
201
+ fmap_r,
202
+ fmap_gen) #+ sim_loss(y_disc_r, y_disc_gen) # NOTE(lsx): need logits?
203
+ feat_loss_mpd = feature_loss(fmap_f_r,
204
+ fmap_f_g) #+ sim_loss(y_df_hat_r, y_df_hat_g)
205
+ feat_loss_msd = feature_loss(fmap_s_r,
206
+ fmap_s_g) #+ sim_loss(y_ds_hat_r, y_ds_hat_g)
207
+ feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
208
+ d_weight = torch.tensor(1.0)
209
+ # try:
210
+ # d_weight = calculate_adaptive_weight(rec_loss, adv_g_loss, last_layer, args) # εŠ¨ζ€θ°ƒζ•΄ι‡ζž„ζŸε€±ε’Œε―ΉζŠ—ζŸε€±
211
+ # except RuntimeError:
212
+ # assert not is_training
213
+ # d_weight = torch.tensor(0.0)
214
+ disc_factor = adopt_weight(
215
+ LAMBDA_ADV, global_step, threshold=discriminator_iter_start)
216
+ if disc_factor == 0.:
217
+ fm_loss_wt = 0
218
+ else:
219
+ fm_loss_wt = LAMBDA_FEAT
220
+ #feat_factor = adopt_weight(args.LAMBDA_FEAT, global_step, threshold=args.discriminator_iter_start)
221
+ loss = rec_loss + d_weight * disc_factor * adv_loss + \
222
+ fm_loss_wt * feat_loss_tot + LAMBDA_COM * codebook_loss.mean()
223
+ return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
224
+
225
+
226
+ def loss_dis(y_disc_r_det, y_disc_gen_det, fmap_r_det, fmap_gen_det, y_df_hat_r,
227
+ y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, fmap_s_r,
228
+ fmap_s_g, global_step):
229
+ disc_factor = adopt_weight(
230
+ LAMBDA_ADV, global_step, threshold=discriminator_iter_start)
231
+ d_loss = disc_factor * criterion_d(y_disc_r_det, y_disc_gen_det, fmap_r_det,
232
+ fmap_gen_det, y_df_hat_r, y_df_hat_g,
233
+ fmap_f_r, fmap_f_g, y_ds_hat_r,
234
+ y_ds_hat_g, fmap_s_r, fmap_s_g)
235
+ return d_loss
236
+
237
+ class AttentionCTCLoss(torch.nn.Module):
238
+ def __init__(self, blank_logprob=-1):
239
+ super(AttentionCTCLoss, self).__init__()
240
+ self.log_softmax = torch.nn.LogSoftmax(dim=3)
241
+ self.blank_logprob = blank_logprob
242
+ self.CTCLoss = torch.nn.CTCLoss(zero_infinity=True)
243
+
244
+ def forward(self, attn_logprob, in_lens, out_lens):
245
+ key_lens = in_lens
246
+ query_lens = out_lens
247
+ attn_logprob_padded = F.pad(
248
+ input=attn_logprob, pad=(1, 0, 0, 0, 0, 0, 0, 0),
249
+ value=self.blank_logprob)
250
+ cost_total = 0.0
251
+ for bid in range(attn_logprob.shape[0]):
252
+ target_seq = torch.arange(1, key_lens[bid]+1).unsqueeze(0)
253
+ curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
254
+ :query_lens[bid], :, :key_lens[bid]+1]
255
+ curr_logprob = self.log_softmax(curr_logprob[None])[0]
256
+ ctc_cost = self.CTCLoss(curr_logprob, target_seq,
257
+ input_lengths=query_lens[bid:bid+1],
258
+ target_lengths=key_lens[bid:bid+1])
259
+ cost_total += ctc_cost
260
+ cost = cost_total/attn_logprob.shape[0]
261
+ return cost
262
+
263
+
264
+ class FocalLoss(torch.nn.Module):
265
+
266
+ def __init__(self, gamma=0, eps=1e-7):
267
+ super(FocalLoss, self).__init__()
268
+ self.gamma = gamma
269
+ self.eps = eps
270
+ self.ce = torch.nn.CrossEntropyLoss()
271
+
272
+ def forward(self, input, target):
273
+ logp = self.ce(input, target)
274
+ p = torch.exp(-logp)
275
+ loss = (1 - p) ** self.gamma * logp
276
+ return loss.mean()
277
+
278
+ def feature_loss(fmap_r, fmap_g):
279
+ loss = 0
280
+ for dr, dg in zip(fmap_r, fmap_g):
281
+ for rl, gl in zip(dr, dg):
282
+ loss += torch.mean(torch.abs(rl - gl))
283
+
284
+ return loss * 2
285
+
286
+
287
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
288
+ loss = 0
289
+ r_losses = []
290
+ g_losses = []
291
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
292
+ r_loss = torch.mean((1 - dr) ** 2)
293
+ g_loss = torch.mean(dg ** 2)
294
+ loss += (r_loss + g_loss)
295
+ r_losses.append(r_loss.item())
296
+ g_losses.append(g_loss.item())
297
+
298
+ return loss, r_losses, g_losses
299
+
300
+
301
+ def generator_loss(disc_outputs):
302
+ loss = 0
303
+ gen_losses = []
304
+ for dg in disc_outputs:
305
+ l = torch.mean((1 - dg) ** 2)
306
+ gen_losses.append(l)
307
+ loss += l
308
+
309
+ return loss, gen_losses
meldataset.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ import random
6
+ import numpy as np
7
+ import random
8
+ import soundfile as sf
9
+ import librosa
10
+
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+ import torchaudio
15
+ from torch.utils.data import DataLoader
16
+
17
+ import math
18
+
19
+ import logging
20
+
21
+ logger = logging.getLogger(__name__)
22
+ logger.setLevel(logging.DEBUG)
23
+ from torch.utils.data.distributed import DistributedSampler
24
+
25
+
26
+ np.random.seed(114514)
27
+ random.seed(114514)
28
+ SPECT_PARAMS = {
29
+ "n_fft": 2048,
30
+ "win_length": 1200,
31
+ "hop_length": 300,
32
+ }
33
+ MEL_PARAMS = {
34
+ "n_mels": 80,
35
+ }
36
+
37
+ to_mel = torchaudio.transforms.MelSpectrogram(
38
+ n_mels=MEL_PARAMS['n_mels'], **SPECT_PARAMS)
39
+ mean, std = -4, 4
40
+
41
+
42
+ def preprocess(wave):
43
+ # wave = wave.unsqueeze(0)
44
+ wave_tensor = torch.from_numpy(wave).float()
45
+ mel_tensor = to_mel(wave_tensor)
46
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
47
+ return mel_tensor
48
+
49
+
50
+ class PseudoDataset(torch.utils.data.Dataset):
51
+ def __init__(self,
52
+ list_path,
53
+ sr=24000,
54
+ range=(1, 30), # length of the audio duration in seconds
55
+ ):
56
+
57
+ self.data_list = [] # read your list path here
58
+ self.sr = sr
59
+ self.duration_range = range
60
+
61
+ def __len__(self):
62
+ # return len(self.data_list)
63
+ return 100 # return a fixed number for testing
64
+
65
+ def __getitem__(self, idx):
66
+ # replace this with your own data loading
67
+ # wave, sr = librosa.load(self.data_list[idx], sr=self.sr)
68
+ wave = np.random.randn(self.sr * random.randint(*self.duration_range)).clamp(-1, 1)
69
+ mel = preprocess(wave)
70
+ return wave, mel
71
+
72
+
73
+ def collate(batch):
74
+ # batch[0] = wave, mel, text, f0, speakerid
75
+ batch_size = len(batch)
76
+
77
+ # sort by mel length
78
+ lengths = [b[1].shape[1] for b in batch]
79
+ batch_indexes = np.argsort(lengths)[::-1]
80
+ batch = [batch[bid] for bid in batch_indexes]
81
+
82
+ nmels = batch[0][1].size(0)
83
+ max_mel_length = max([b[1].shape[1] for b in batch])
84
+ max_wave_length = max([b[0].size(0) for b in batch])
85
+
86
+ mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
87
+ waves = torch.zeros((batch_size, max_wave_length)).float()
88
+
89
+ mel_lengths = torch.zeros(batch_size).long()
90
+ wave_lengths = torch.zeros(batch_size).long()
91
+
92
+ for bid, (wave, mel) in enumerate(batch):
93
+ mel_size = mel.size(1)
94
+ mels[bid, :, :mel_size] = mel
95
+ waves[bid, : wave.size(0)] = wave
96
+ mel_lengths[bid] = mel_size
97
+ wave_lengths[bid] = wave.size(0)
98
+
99
+ return waves, mels, wave_lengths, mel_lengths
100
+
101
+
102
+ def build_dataloader(
103
+ rank=0,
104
+ world_size=1,
105
+ batch_size=32,
106
+ num_workers=0,
107
+ prefetch_factor=16,
108
+ ):
109
+ dataset = PseudoDataset() # replace this with your own dataset
110
+ collate_fn = collate
111
+ sampler = torch.utils.data.distributed.DistributedSampler(
112
+ dataset,
113
+ num_replicas=world_size,
114
+ rank=rank,
115
+ shuffle=True,
116
+ seed=114514,
117
+ )
118
+ data_loader = DataLoader(
119
+ dataset,
120
+ batch_size=batch_size,
121
+ sampler=sampler,
122
+ num_workers=num_workers,
123
+ drop_last=True,
124
+ collate_fn=collate_fn,
125
+ pin_memory=True,
126
+ prefetch_factor=prefetch_factor,
127
+ # shuffle=True,
128
+ )
129
+
130
+ return data_loader
131
+
optimizers.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #coding:utf-8
2
+ import os, sys
3
+ import os.path as osp
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ from torch.optim import Optimizer
8
+ from functools import reduce
9
+ from torch.optim import AdamW
10
+
11
+ class MultiOptimizer:
12
+ def __init__(self, optimizers={}, schedulers={}):
13
+ self.optimizers = optimizers
14
+ self.schedulers = schedulers
15
+ self.keys = list(optimizers.keys())
16
+ self.param_groups = reduce(lambda x,y: x+y, [v.param_groups for v in self.optimizers.values()])
17
+
18
+ def state_dict(self):
19
+ state_dicts = [(key, self.optimizers[key].state_dict())\
20
+ for key in self.keys]
21
+ return state_dicts
22
+
23
+ def scheduler_state_dict(self):
24
+ state_dicts = [(key, self.schedulers[key].state_dict())\
25
+ for key in self.keys]
26
+ return state_dicts
27
+
28
+ def load_state_dict(self, state_dict):
29
+ for key, val in state_dict:
30
+ try:
31
+ self.optimizers[key].load_state_dict(val)
32
+ except:
33
+ print("Unloaded %s" % key)
34
+
35
+ def load_scheduler_state_dict(self, state_dict):
36
+ for key, val in state_dict:
37
+ try:
38
+ self.schedulers[key].load_state_dict(val)
39
+ except:
40
+ print("Unloaded %s" % key)
41
+
42
+ def step(self, key=None, scaler=None):
43
+ keys = [key] if key is not None else self.keys
44
+ _ = [self._step(key, scaler) for key in keys]
45
+
46
+ def _step(self, key, scaler=None):
47
+ if scaler is not None:
48
+ scaler.step(self.optimizers[key])
49
+ scaler.update()
50
+ else:
51
+ self.optimizers[key].step()
52
+
53
+ def zero_grad(self, key=None):
54
+ if key is not None:
55
+ self.optimizers[key].zero_grad()
56
+ else:
57
+ _ = [self.optimizers[key].zero_grad() for key in self.keys]
58
+
59
+ def scheduler(self, *args, key=None):
60
+ if key is not None:
61
+ self.schedulers[key].step(*args)
62
+ else:
63
+ _ = [self.schedulers[key].step_batch(*args) for key in self.keys]
64
+
65
+ def define_scheduler(optimizer, params):
66
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params['gamma'])
67
+
68
+ return scheduler
69
+
70
+ from transformer_modules.optim import Eden, ScaledAdam
71
+
72
+ def build_optimizer(model_dict, scheduler_params_dict, lr, type='AdamW'):
73
+ optim = {}
74
+ for key, model in model_dict.items():
75
+ model_parameters = model.parameters()
76
+ parameters_names = []
77
+ parameters_names.append(
78
+ [
79
+ name_param_pair[0]
80
+ for name_param_pair in model.named_parameters()
81
+ ]
82
+ )
83
+ if type == 'ScaledAdam':
84
+ optim[key] = ScaledAdam(
85
+ model_parameters,
86
+ lr=lr,
87
+ betas=(0.9, 0.95),
88
+ clipping_scale=2.0,
89
+ parameters_names=parameters_names,
90
+ show_dominant_parameters=False,
91
+ clipping_update_period=1000,
92
+ )
93
+ elif type == 'AdamW':
94
+ optim[key] = AdamW(
95
+ model_parameters,
96
+ lr=lr,
97
+ betas=(0.9, 0.98),
98
+ eps=1e-9,
99
+ weight_decay=0.1,
100
+ )
101
+ else:
102
+ raise ValueError('Unknown optimizer type: %s' % type)
103
+
104
+ schedulers = dict([(key, torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999996))
105
+ for key, opt in optim.items()])
106
+
107
+ multi_optim = MultiOptimizer(optim, schedulers)
108
+ return multi_optim