pnnbao-ump committed on
Commit c99405f · verified · 1 Parent(s): 0436f07

Upload 2 files

Files changed (2)
  1. requirements.txt +9 -10
  2. vieneu_tts.py +347 -0
requirements.txt CHANGED
@@ -1,10 +1,9 @@
- torchaudio
- transformers
- librosa
- soundfile
- numpy
- phonemizer
- neucodec
- lmdeploy
- pyyaml
- torch==2.8.0
+ gradio
+ spaces
+ torchaudio
+ transformers
+ librosa
+ soundfile
+ numpy
+ phonemizer
+ neucodec

vieneu_tts.py ADDED
@@ -0,0 +1,347 @@
+ from pathlib import Path
+ from typing import Generator
+ import librosa
+ import numpy as np
+ import torch
+ from neucodec import NeuCodec, DistillNeuCodec
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from utils.phonemize_text import phonemize_text, phonemize_with_dict
+ import re
+
+ def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
+     # original impl --> https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py
+     assert len(frames)
+     dtype = frames[0].dtype
+     shape = frames[0].shape[:-1]
+
+     total_size = 0
+     for i, frame in enumerate(frames):
+         frame_end = stride * i + frame.shape[-1]
+         total_size = max(total_size, frame_end)
+
+     sum_weight = np.zeros(total_size, dtype=dtype)
+     out = np.zeros((*shape, total_size), dtype=dtype)  # shape must be passed as a tuple
+
+     offset: int = 0
+     for frame in frames:
+         frame_length = frame.shape[-1]
+         t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
+         weight = np.abs(0.5 - (t - 0.5))
+
+         out[..., offset : offset + frame_length] += weight * frame
+         sum_weight[offset : offset + frame_length] += weight
+         offset += stride
+     assert sum_weight.min() > 0
+     return out / sum_weight
+
+ class VieNeuTTS:
+     def __init__(
+         self,
+         backbone_repo="pnnbao-ump/VieNeu-TTS",
+         backbone_device="cpu",
+         codec_repo="neuphonic/neucodec",
+         codec_device="cpu",
+     ):
+
+         # Constants
+         self.sample_rate = 24_000
+         self.max_context = 2048
+         self.hop_length = 480
+         self.streaming_overlap_frames = 1
+         self.streaming_frames_per_chunk = 25
+         self.streaming_lookforward = 5
+         self.streaming_lookback = 50
+         self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
+
+         # ggml & onnx flags
+         self._is_quantized_model = False
+         self._is_onnx_codec = False
+
+         # HF tokenizer
+         self.tokenizer = None
+
+         # Load models
+         self._load_backbone(backbone_repo, backbone_device)
+         self._load_codec(codec_repo, codec_device)
+
+     def _load_backbone(self, backbone_repo, backbone_device):
+         print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")
+
+         # GGUF checkpoints are run through llama.cpp; everything else goes through transformers
+         if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower():
+             try:
+                 from llama_cpp import Llama
+             except ImportError as e:
+                 raise ImportError(
+                     "Failed to import `llama_cpp`. "
+                     "Please install it with:\n"
+                     " pip install llama-cpp-python"
+                 ) from e
+             self.backbone = Llama.from_pretrained(
+                 repo_id=backbone_repo,
+                 filename="*.gguf",
+                 verbose=False,
+                 n_gpu_layers=-1 if backbone_device == "gpu" else 0,
+                 n_ctx=self.max_context,
+                 mlock=True,
+                 flash_attn=True if backbone_device == "gpu" else False,
+             )
+             self._is_quantized_model = True
+
+         else:
+             self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
+             self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
+                 torch.device(backbone_device)
+             )
+
+     def _load_codec(self, codec_repo, codec_device):
+         print(f"Loading codec from: {codec_repo} on {codec_device} ...")
+         match codec_repo:
+             case "neuphonic/neucodec":
+                 self.codec = NeuCodec.from_pretrained(codec_repo)
+                 self.codec.eval().to(codec_device)
+             case "neuphonic/distill-neucodec":
+                 self.codec = DistillNeuCodec.from_pretrained(codec_repo)
+                 self.codec.eval().to(codec_device)
+             case "neuphonic/neucodec-onnx-decoder":
+                 if codec_device != "cpu":
+                     raise ValueError("The ONNX decoder currently only runs on CPU.")
+                 try:
+                     from neucodec import NeuCodecOnnxDecoder
+                 except ImportError as e:
+                     raise ImportError(
+                         "Failed to import the onnx decoder."
+                         " Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
+                     ) from e
+                 self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
+                 self._is_onnx_codec = True
+             case _:
+                 raise ValueError(f"Unsupported codec repository: {codec_repo}")
+
+     def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
+         """
+         Perform inference to generate speech from text using the TTS model and reference audio.
+
+         Args:
+             text (str): Input text to be converted to speech.
+             ref_codes (np.ndarray | torch.Tensor): Encoded reference audio codes.
+             ref_text (str): Transcript of the reference audio.
+         Returns:
+             np.ndarray: Generated speech waveform.
+         """
+
+         # Generate tokens
+         if self._is_quantized_model:
+             output_str = self._infer_ggml(ref_codes, ref_text, text)
+         else:
+             prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
+             output_str = self._infer_torch(prompt_ids)
+
+         # Decode
+         wav = self._decode(output_str)
+
+         return wav
+
+     def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
+         """
+         Perform streaming inference to generate speech from text using the TTS model and reference audio.
+
+         Args:
+             text (str): Input text to be converted to speech.
+             ref_codes (np.ndarray | torch.Tensor): Encoded reference audio codes.
+             ref_text (str): Transcript of the reference audio.
+         Yields:
+             np.ndarray: Generated speech waveform chunks.
+         """
+
+         if self._is_quantized_model:
+             return self._infer_stream_ggml(ref_codes, ref_text, text)
+         else:
+             raise NotImplementedError("Streaming is not implemented for the torch backend!")
+
+     def encode_reference(self, ref_audio_path: str | Path):
+         wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
+         wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # [1, 1, T]
+         with torch.no_grad():
+             ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
+         return ref_codes
+
+     def _decode(self, codes: str):
+         """Decode speech tokens to audio waveform."""
+         # Extract speech token IDs using regex
+         speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
+
+         if len(speech_ids) == 0:
+             raise ValueError(
+                 "No valid speech tokens found in the output. "
+                 "The model may not have generated proper speech tokens."
+             )
+
+         # Onnx decode
+         if self._is_onnx_codec:
+             codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
+             recon = self.codec.decode_code(codes)
+         # Torch decode
+         else:
+             with torch.no_grad():
+                 codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
+                     self.codec.device
+                 )
+                 recon = self.codec.decode_code(codes).cpu().numpy()
+
+         return recon[0, 0, :]
+
+     def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
+         input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)
+
+         speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
+         speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
+         text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
+         text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
+         text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
+
+         input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
+         chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
+         ids = self.tokenizer.encode(chat)
+
+         # Splice the phonemized text prompt in place of <|TEXT_REPLACE|>
+         text_replace_idx = ids.index(text_replace)
+         ids = (
+             ids[:text_replace_idx]
+             + [text_prompt_start]
+             + input_ids
+             + [text_prompt_end]
+             + ids[text_replace_idx + 1 :]  # noqa
+         )
+
+         # Replace <|SPEECH_REPLACE|> with the reference codes so generation continues from them
+         speech_replace_idx = ids.index(speech_replace)
+         codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
+         codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
+         ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
+
+         return ids
+
+     def _infer_torch(self, prompt_ids: list[int]) -> str:
+         prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
+         speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
+         with torch.no_grad():
+             output_tokens = self.backbone.generate(
+                 prompt_tensor,
+                 max_length=self.max_context,
+                 eos_token_id=speech_end_id,
+                 do_sample=True,
+                 temperature=1,
+                 top_k=50,
+                 use_cache=True,
+                 min_new_tokens=50,
+             )
+         input_length = prompt_tensor.shape[-1]
+         output_str = self.tokenizer.decode(
+             output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
+         )
+         return output_str
+
+     def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
+         ref_text = phonemize_with_dict(ref_text)
+         input_text = phonemize_with_dict(input_text)
+
+         codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+         prompt = (
+             f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
+             f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+         )
+         output = self.backbone(
+             prompt,
+             max_tokens=self.max_context,
+             temperature=1.0,
+             top_k=50,
+             stop=["<|SPEECH_GENERATION_END|>"],
+         )
+         output_str = output["choices"][0]["text"]
+         return output_str
+
+     def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
+         ref_text = phonemize_with_dict(ref_text)
+         input_text = phonemize_with_dict(input_text)
+
+         codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+         prompt = (
+             f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
+             f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+         )
+
+         audio_cache: list[np.ndarray] = []
+         token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
+         n_decoded_samples: int = 0
+         n_decoded_tokens: int = len(ref_codes)
+
+         for item in self.backbone(
+             prompt,
+             max_tokens=self.max_context,
+             temperature=0.2,
+             top_k=50,
+             stop=["<|SPEECH_GENERATION_END|>"],
+             stream=True,
+         ):
+             output_str = item["choices"][0]["text"]
+             token_cache.append(output_str)
+
+             if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
+
+                 # decode chunk
+                 tokens_start = max(
+                     n_decoded_tokens
+                     - self.streaming_lookback
+                     - self.streaming_overlap_frames,
+                     0,
+                 )
+                 tokens_end = (
+                     n_decoded_tokens
+                     + self.streaming_frames_per_chunk
+                     + self.streaming_lookforward
+                     + self.streaming_overlap_frames
+                 )
+                 sample_start = (
+                     n_decoded_tokens - tokens_start
+                 ) * self.hop_length
+                 sample_end = (
+                     sample_start
+                     + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
+                 )
+                 curr_codes = token_cache[tokens_start:tokens_end]
+                 recon = self._decode("".join(curr_codes))
+                 recon = recon[sample_start:sample_end]
+                 audio_cache.append(recon)
+
+                 # postprocess: cross-fade overlapping chunks and emit only the new samples
+                 processed_recon = _linear_overlap_add(
+                     audio_cache, stride=self.streaming_stride_samples
+                 )
+                 new_samples_end = len(audio_cache) * self.streaming_stride_samples
+                 processed_recon = processed_recon[
+                     n_decoded_samples:new_samples_end
+                 ]
+                 n_decoded_samples = new_samples_end
+                 n_decoded_tokens += self.streaming_frames_per_chunk
+                 yield processed_recon
+
+         # final decoding handled separately as non-constant chunk size
+         remaining_tokens = len(token_cache) - n_decoded_tokens
+         if len(token_cache) > n_decoded_tokens:
+             tokens_start = max(
+                 len(token_cache)
+                 - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
+                 0,
+             )
+             sample_start = (
+                 len(token_cache)
+                 - tokens_start
+                 - remaining_tokens
+                 - self.streaming_overlap_frames
+             ) * self.hop_length
+             curr_codes = token_cache[tokens_start:]
+             recon = self._decode("".join(curr_codes))
+             recon = recon[sample_start:]
+             audio_cache.append(recon)
+
+             processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
+             processed_recon = processed_recon[n_decoded_samples:]
+             yield processed_recon
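
For reference, a minimal usage sketch of the class added above (not part of this commit). The reference clip path, its transcript, and the input sentence are placeholders; soundfile (already listed in requirements.txt) is assumed for writing the output:

import soundfile as sf
from vieneu_tts import VieNeuTTS

# Load the HF-format backbone and the NeuCodec codec on CPU (the constructor defaults).
tts = VieNeuTTS(backbone_repo="pnnbao-ump/VieNeu-TTS", codec_repo="neuphonic/neucodec")

# Placeholder reference clip and its transcript, used for voice cloning.
ref_codes = tts.encode_reference("ref.wav")
ref_text = "Xin chào, đây là giọng tham chiếu."

# Synthesize and save at the model's 24 kHz sample rate.
wav = tts.infer("Xin chào thế giới!", ref_codes, ref_text)
sf.write("output.wav", wav, tts.sample_rate)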