wetdog committed on
Commit
4876346
1 Parent(s): c17cc3e

Add mosanet gradio demo

Files changed (4)
  1. Dockerfile +48 -0
  2. app.py +103 -0
  3. modules.py +152 -0
  4. requirements.txt +5 -0
Dockerfile ADDED
@@ -0,0 +1,48 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.10.12-slim
+
+ # Install build tools and ffmpeg (ffmpeg is required for audio decoding)
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     autoconf \
+     automake \
+     libtool \
+     pkg-config \
+     git \
+     cmake \
+     ffmpeg \
+     && rm -rf /var/lib/apt/lists/*
+
+
+ RUN pip install --upgrade pip
+
+ RUN mkdir -p cache && chmod 777 cache
+
+ RUN useradd -m -u 1000 user
+
+ USER user
+
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+ # Install the Python dependencies
+
+ COPY --chown=user requirements.txt $HOME/app/
+
+ RUN pip install -r requirements.txt
+
+
+ COPY --chown=user . $HOME/app/
+
+ # Fix ownership issues
+ USER root
+ RUN chown -R user:user $HOME/app
+ USER user
+
+ EXPOSE 7860
+
+ CMD ["python3", "-u", "app.py"]
+
+
app.py ADDED
@@ -0,0 +1,103 @@
+ import os
+ import torch
+ from transformers import AutoFeatureExtractor, WhisperModel, AutoModelForSpeechSeq2Seq
+ import numpy as np
+ import torchaudio
+ import librosa
+
+ import gradio as gr
+ from modules import load_audio, MosPredictor, denorm
+
+
+ mos_checkpoint = "ckpt_mosa_net_plus"
+
+ print('Loading MOSA-Net+ checkpoint...')
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+ model = MosPredictor().to(device)
+ model.load_state_dict(torch.load(mos_checkpoint, map_location=device))
+ model.eval()
+
+ print('Loading Whisper checkpoint...')
+ feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3")
+ # model_asli = WhisperModel.from_pretrained("openai/whisper-large-v3")
+ model_asli = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v3", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="sdpa")
+ model_asli = model_asli.to(device)
+
+
+ def predict_mos(wavefile: str):
+
+     print('Starting prediction...')
+     # STFT magnitude features
+     wav = torchaudio.load(wavefile)[0]
+     lps = torch.from_numpy(np.expand_dims(np.abs(librosa.stft(wav[0].detach().numpy(), n_fft=512, hop_length=256, win_length=512)).T, axis=0))
+     lps = lps.unsqueeze(1)
+
+     # Whisper features
+     audio = load_audio(wavefile)
+     inputs = feature_extractor(audio, return_tensors="pt")
+     input_features = inputs.input_features
+     input_features = input_features.to(device, torch_dtype)  # match the Whisper model's dtype
+
+     with torch.no_grad():
+         decoder_input_ids = torch.tensor([[1, 1]]) * model_asli.config.decoder_start_token_id
+         decoder_input_ids = decoder_input_ids.to(device)
+         last_hidden_state = model_asli(input_features, decoder_input_ids=decoder_input_ids).encoder_last_hidden_state
+         whisper_feat = last_hidden_state.float()  # MosPredictor runs in float32
+
+     print('Model feature shapes...')
+     print(whisper_feat.shape)
+     print(wav.shape)
+     print(lps.shape)
+
+     # Prediction
+     wav = wav.to(device)
+     lps = lps.to(device)
+     Quality_1, Intell_1, frame1, frame2 = model(wav, lps, whisper_feat)
+     quality_pred = Quality_1.cpu().detach().numpy()[0]
+     intell_pred = Intell_1.cpu().detach().numpy()[0]
+
+     print("predictions")
+     qa_text = f"Quality: {denorm(quality_pred)[0]:.2f} Intelligibility: {intell_pred[0]:.2f}"
+     print(qa_text)
+     return qa_text
+
+
+ title = """
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
+ <div
+ style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
+ > <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
+ MOSA-Net Whisper features
+ </h1> </div>
+ </div>
+ """
+
+ description = """
+ This is a demo of [MOSA-Net+](https://github.com/dhimasryan/MOSA-Net-Cross-Domain/tree/main/MOSA_Net%2B),
+ an enhanced version of the multi-objective speech assessment model MOSA-Net that leverages acoustic features from Whisper, a large-scale weakly supervised model.
+ MOSA-Net+ was tested in the noisy-and-enhanced track of the VoiceMOS Challenge 2023, where it obtained the top-ranked performance among nine systems ([full paper](https://arxiv.org/abs/2309.12766)).
+ """
+
+ article = """
+ If the model contributes to your research, please cite the following works:
+
+ R. E. Zezario, S.-W. Fu, F. Chen, C.-S. Fuh, H.-M. Wang and Y. Tsao, "Deep Learning-Based Non-Intrusive Multi-Objective Speech Assessment Model With Cross-Domain Features," IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 54-70, 2023, doi: 10.1109/TASLP.2022.3205757.
+
+ R. E. Zezario, Y.-W. Chen, S.-W. Fu, Y. Tsao, H.-M. Wang, C.-S. Fuh, "A Study on Incorporating Whisper for Robust Speech Assessment," IEEE ICME 2024, July 2024 (top performance on Track 3 of the VoiceMOS Challenge 2023).
+
+ Demo contributed by [@wetdog](https://github.com/wetdog)
+ """
+ demo = gr.Blocks()
+ with demo:
+     gr.Markdown(title)
+     gr.Markdown(description)
+     gr.Interface(
+         fn=predict_mos,
+         inputs=gr.Audio(type='filepath'),
+         outputs="text",
+         allow_flagging="never")
+     gr.Markdown(article)
+
+ demo.queue(max_size=10)
+ demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)
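
Note on the Whisper step in predict_mos: only the encoder output is used, and the dummy decoder_input_ids are just a way to reach encoder_last_hidden_state through the seq2seq model. A roughly equivalent sketch that calls the encoder directly is shown below; it assumes a hypothetical local file sample.wav and loads the full-precision weights, unlike the fp16 path used on GPU above.

import torch
import librosa
from transformers import AutoFeatureExtractor, WhisperModel

# Sketch only: obtain the 1280-dim Whisper encoder features without a decoder pass.
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3")
encoder = WhisperModel.from_pretrained("openai/whisper-large-v3").get_encoder()

audio, _ = librosa.load("sample.wav", sr=16000)  # hypothetical input file
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    whisper_feat = encoder(inputs.input_features).last_hidden_state  # [1, 1500, 1280]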
modules.py ADDED
@@ -0,0 +1,150 @@
+ import os
+ import torch
+ import argparse
+ import numpy as np
+ from transformers import AutoFeatureExtractor, WhisperModel
+
+ import torchaudio
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import speechbrain
+ import librosa
+
+ from subprocess import CalledProcessError, run
+
+ # OpenAI Whisper's load_audio helper
+ SAMPLE_RATE = 16000
+ def denorm(input_x):
+     input_x = input_x * (5 - 0) + 0  # rescale scores from [0, 1] back to the [0, 5] MOS range
+     return input_x
+
+ def load_audio(file: str, sr: int = SAMPLE_RATE):
+     """
+     Open an audio file and read as mono waveform, resampling as necessary
+
+     Parameters
+     ----------
+     file: str
+         The audio file to open
+
+     sr: int
+         The sample rate to resample the audio if necessary
+
+     Returns
+     -------
+     A NumPy array containing the audio waveform, in float32 dtype.
+     """
+
+     # This launches a subprocess to decode audio while down-mixing
+     # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+     # fmt: off
+     cmd = [
+         "ffmpeg",
+         "-nostdin",
+         "-threads", "0",
+         "-i", file,
+         "-f", "s16le",
+         "-ac", "1",
+         "-acodec", "pcm_s16le",
+         "-ar", str(sr),
+         "-"
+     ]
+     # fmt: on
+     try:
+         out = run(cmd, capture_output=True, check=True).stdout
+     except CalledProcessError as e:
+         raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+
+     return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+
+ class MosPredictor(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+
+         self.mean_net_conv = nn.Sequential(
+             nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=(1, 1)),
+             nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=(1, 1)),
+             nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
+             nn.Dropout(0.3),
+             nn.BatchNorm2d(16),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=(1, 1)),
+             nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)),
+             nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
+             nn.Dropout(0.3),
+             nn.BatchNorm2d(32),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=(1, 1)),
+             nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=(1, 1)),
+             nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
+             nn.Dropout(0.3),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), padding=(1, 1)),
+             nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1)),
+             nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
+             nn.Dropout(0.3),
+             nn.BatchNorm2d(128),
+             nn.ReLU())
+
+         self.relu_ = nn.ReLU()
+         self.sigmoid_ = nn.Sigmoid()
+
+         self.ssl_features = 1280
+         self.dim_layer = nn.Linear(self.ssl_features, 512)
+
+         self.mean_net_rnn = nn.LSTM(input_size=512, hidden_size=128, num_layers=1, batch_first=True, bidirectional=True)
+         self.mean_net_dnn = nn.Sequential(
+             nn.Linear(256, 128),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+         )
+
+         self.sinc = speechbrain.nnet.CNN.SincConv(in_channels=1, out_channels=257, kernel_size=251, stride=256, sample_rate=16000)
+         self.att_output_layer_quality = nn.MultiheadAttention(128, num_heads=8)
+         self.output_layer_quality = nn.Linear(128, 1)
+         self.qualaverage_score = nn.AdaptiveAvgPool1d(1)
+
+         self.att_output_layer_intell = nn.MultiheadAttention(128, num_heads=8)
+         self.output_layer_intell = nn.Linear(128, 1)
+         self.intellaverage_score = nn.AdaptiveAvgPool1d(1)
+
+         self.att_output_layer_stoi = nn.MultiheadAttention(128, num_heads=8)
+         self.output_layer_stoi = nn.Linear(128, 1)
+         self.stoiaverage_score = nn.AdaptiveAvgPool1d(1)
+
+
+     def forward(self, wav, lps, whisper):
+         # SSL (Whisper) features
+         wav_ = wav.squeeze(1)  # [batch, audio_len]
+         ssl_feat_red = self.dim_layer(whisper.squeeze(1))
+         ssl_feat_red = self.relu_(ssl_feat_red)
+
+         # PS (power-spectral) features
+         sinc_feat = self.sinc(wav_)
+         unsq_sinc = torch.unsqueeze(sinc_feat, dim=1)
+         concat_lps_sinc = torch.cat((lps, unsq_sinc), dim=2)
+         cnn_out = self.mean_net_conv(concat_lps_sinc)
+         batch = concat_lps_sinc.shape[0]
+         time = concat_lps_sinc.shape[2]
+         re_cnn = cnn_out.view((batch, time, 512))
+
+         concat_feat = torch.cat((re_cnn, ssl_feat_red), dim=1)
+         out_lstm, (h, c) = self.mean_net_rnn(concat_feat)
+         out_dense = self.mean_net_dnn(out_lstm)  # (batch, seq, 128)
+
+         quality_att, _ = self.att_output_layer_quality(out_dense, out_dense, out_dense)
+         frame_quality = self.output_layer_quality(quality_att)
+         frame_quality = self.sigmoid_(frame_quality)
+         quality_utt = self.qualaverage_score(frame_quality.permute(0, 2, 1))
+
+         int_att, _ = self.att_output_layer_intell(out_dense, out_dense, out_dense)
+         frame_int = self.output_layer_intell(int_att)
+         frame_int = self.sigmoid_(frame_int)
+         int_utt = self.intellaverage_score(frame_int.permute(0, 2, 1))
+
+
+         return quality_utt.squeeze(1), int_utt.squeeze(1), frame_quality.squeeze(2), frame_int.squeeze(2)
+
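
MosPredictor.forward expects a batched waveform, an STFT magnitude tensor, and Whisper encoder features. A rough shape sanity-check with random weights and illustrative sizes (one second of 16 kHz audio, a 1500-frame Whisper feature map) might look like the sketch below; it is an assumption-laden illustration, not part of the committed demo.

import numpy as np
import torch
import librosa
from modules import MosPredictor

model = MosPredictor().eval()  # untrained weights, shapes only

wav = torch.randn(1, 1, 16000)  # [batch, 1, samples] at 16 kHz (illustrative length)
spec = np.abs(librosa.stft(wav[0, 0].numpy(), n_fft=512, hop_length=256, win_length=512)).T
lps = torch.from_numpy(spec).unsqueeze(0).unsqueeze(1)  # [batch, 1, frames, 257]
whisper_feat = torch.randn(1, 1500, 1280)  # Whisper-large encoder output size

with torch.no_grad():
    quality, intell, frame_q, frame_i = model(wav, lps, whisper_feat)
print(quality.shape, intell.shape)  # torch.Size([1, 1]) for each utterance-level score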
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ transformers
+ speechbrain
+ librosa
+ gradio
+ accelerate