MeloTTS-ONNX 项目详解





1. 项目概述
MeloTTS-ONNX 是 MeloTTS 的 ONNX 推理版本,专门针对 CPU 实时推理进行了优化。项目支持:
- ✅ 中英文混合 TTS
- ✅ 多种语言:中文、英文、日文、韩文、西班牙语、法语等
- ✅ ONNX Runtime 推理,推理速度快
仓库地址
- Git仓库: https://gitee.com/jackroing/melo-tts-onnx.git
- ModelScope模型仓库: https://www.modelscope.cn/models/KeanuX/MeloTTS-ZH-MIXED-EN-ONNX
# 克隆仓库
git clone https://gitee.com/jackroing/melo-tts-onnx.git
# 获取模型
modelscope download --model KeanuX/MeloTTS-ZH-MIXED-EN-ONNX --local_dir ./
2. 项目架构
melo-tts-onnx/
├── melo/ # 原始 PyTorch 训练代码
│ ├── api.py # TTS API 接口
│ ├── models.py # 模型定义 (SynthesizerTrn)
│ ├── modules.py # 模型模块
│ ├── text/ # 文本处理(分词、phoneme转换)
│ │ ├── chinese.py # 中文文本处理
│ │ ├── english.py # 英文文本处理
│ │ └── ...
│ ├── train.py / train.sh # 训练脚本
│ └── ...
│
├── melo_extra/ # 推理时依赖的文本处理模块
│ ├── melo_tts.py # ONNX 导出封装
│ └── inference/
│ ├── text/ # 推理用文本处理(与melo/text对应)
│ ├── commons.py # 通用工具
│ └── utils.py # 参数配置
│
├── models/melotts/ # ONNX 模型目录
│ ├── melotts_14.onnx # 导出的 ONNX 模型
│ ├── config.json # 模型配置
│ └── bert-base-multilingual-uncased/ # BERT 模型
│
├── run_onnx.py # ⭐ 核心推理脚本
├── export_melo.py # ONNX 导出脚本
├── export_model_info.py # 模型信息导出工具
└── README.md
3. 核心脚本使用
3.1 推理脚本:run_onnx.py
这是最常用的脚本,用于将文本转换为语音:
from run_onnx import MeloTTS
# 初始化模型
model_path = "./models/melotts/"
melo_tts = MeloTTS(model_path, device="cpu")
# 生成音频
audio, sr = melo_tts.generate_audio(
text="你好,我是中英混合模型。Hello I am a mixed language model.",
language="ZH_MIX_EN", # 语言: ZH_MIX_EN, EN, JP, KR 等
sdp_ratio=0.2, # SDP 比率
noise_scale=0.667, # 噪声尺度
noise_scale_w=0.8, # 噪声权重
speed=1.0 # 语速
)
主要参数说明:
| 参数 | 说明 | 默认值 |
|---|---|---|
text | 输入文本 | 必填 |
language | 语言代码 | "ZH_MIX_EN" |
sdp_ratio | SDP 比率 (0-1) | 0.2 |
noise_scale | 噪声尺度 | 0.667 |
noise_scale_w | 噪声权重 | 0.8 |
speed | 语速 | 1.0 |
支持的语言代码:
ZH_MIX_EN- 中文(支持中英文混合)EN- 英文- 等
3.2 ONNX 导出脚本:export_melo.py
用于将 PyTorch 模型导出为 ONNX 格式:
python export_melo.py \
-m /path/to/ckpt \
-c /path/to/config.json \
-o /path/to/save_dir \
--opset 14
...
主要参数:
--ckpt_path- 模型检查点路径--cfg_path- 配置文件路径--output_path- 输出 ONNX 文件路径--opset- ONNX opset 版本(默认14)
3.3 模型信息导出:export_model_info.py
用于导出 ONNX 模型的详细信息(输入输出形状、参数数量等):
python export_model_info.py -m ./models/melotts/melotts_14.onnx -o ./infos/melotts_14.info
输出示例:
============================================================
ONNX模型基本信息
============================================================
模型文件路径: ./models/melotts/melotts/melotts_14.onnx
ONNX版本: 7
生产者信息: pytorch 2.8.0
模型版本: 0
描述:
============================================================
模型输入信息 (共 11 个输入)
============================================================
Input 1: x_tst
数据类型: int32
形状: [0, 0]
Input 2: x_tst_lengths
数据类型: int32
形状: [0]
Input 3: speakers
数据类型: int32
形状: [0]
Input 4: tones
数据类型: int32
形状: [0, 0]
Input 5: lang_ids
数据类型: int32
形状: [0, 0]
Input 6: bert
数据类型: float32
形状: [0, 1024, 0]
Input 7: ja_bert
数据类型: float32
形状: [0, 768, 0]
Input 8: sdp_ratio
数据类型: float32
形状: [0]
Input 9: noise_scale
数据类型: float32
形状: [0]
Input 10: noise_scale_w
数据类型: float32
形状: [0]
Input 11: speed
数据类型: float32
形状: [0]
============================================================
模型输出信息 (共 1 个输出)
============================================================
Output 1: audio_data
数据类型: float32
形状: [1, 0]
4. 工作原理
文本输入
↓
┌─────────────────────────────────────────┐
│ 文本预处理 (clean_text) │
│ - 分词 │
│ - 转换为 phoneme │
│ - 获取 tone │
│ - BERT 特征提取 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ ONNX 模型推理 │
│ - Glow-TTS (文本→mel频谱) │
│ - HiFi-GAN (mel频谱→音频) │
└─────────────────────────────────────────┘
↓
音频输出 (44.1kHz)
ONNX 模型输入 (11个):
x_tst- 文本 token IDsx_tst_lengths- 文本长度speakers- 发音人 IDtones- 音调 IDslang_ids- 语言 IDsbert- BERT 特征 (1024维)ja_bert- 日文 BERT 特征 (768维)sdp_ratio- SDP 比率noise_scale- 噪声尺度noise_scale_w- 噪声权重speed- 语速
ONNX 模型输出 (1个):
audio_data- 生成的音频数据
5. 快速使用示例
import onnxruntime as ort
import numpy as np
import os
import sys
import soundfile as sf
from typing import Tuple
from melo_extra.inference.utils import HParams, get_hparams_from_file
from melo_extra.inference.text.cleaner import clean_text
from melo_extra.inference.text import cleaned_text_to_sequence, get_bert, get_zh_mix_en_bert
from melo_extra.inference import commons
import logging
logger = logging.getLogger(__name__)
file_handler = logging.FileHandler("./logs/run_melo_onnx.log", mode="w", encoding="utf-8")
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)
logger.setLevel(logging.INFO)
class MeloTTS:
def __init__(self, model_root:str, device:str="cpu", provider_options:list[dict]=None) -> None:
self.model_path = os.path.join(model_root, "melotts_14.onnx")
self.cfg_path = os.path.join(model_root, "config.json")
self.bert_model_path = os.path.join(model_root, "bert-base-multilingual-uncased")
self.cfg = get_hparams_from_file(self.cfg_path)
if device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers():
self.providers = ["CUDAExecutionProvider"]
elif device == "cpu":
self.providers = ["CPUExecutionProvider"]
elif device == "qnn" and "QNNExecutionProvider" in ort.get_available_providers() and provider_options != None:
self.providers = ["QNNExecutionProvider"]
else:
logger.info(f"device {device} not supported, use cpu instead")
self.providers = ["CPUExecutionProvider"]
self.session = ort.InferenceSession(self.model_path, providers=self.providers, provider_options=provider_options)
self.input_names = [input.name for input in self.session.get_inputs()]
self.output_names = [output.name for output in self.session.get_outputs()]
logger.info(f"model input names: {self.input_names}")
logger.info(f"model output names: {self.output_names}")
def __preprocess(self, text:str, language:str):
norm_text, phone, tone, word2ph = clean_text(text, language)
symbol_to_id = {s: i for i, s in enumerate(self.cfg.symbols)}
phone, tone, language = cleaned_text_to_sequence(phone, tone, language, symbol_to_id)
if self.cfg.data.add_blank:
phone = commons.intersperse(phone, 0)
tone = commons.intersperse(tone, 0)
language = commons.intersperse(language, 0)
for i in range(len(word2ph)):
word2ph[i] = word2ph[i] * 2
word2ph[0] += 1
if getattr(self.cfg.data, "disable_bert", True):
bert = np.zeros((1024, len(phone)), dtype=np.float32)
ja_bert = np.zeros((768, len(phone)), dtype=np.float32)
else:
bert = get_zh_mix_en_bert(self.bert_model_path, text, word2ph, "cpu")
del word2ph
assert bert.shape[-1] == len(phone), phone
if language == "ZH":
bert = bert
ja_bert = np.zeros(768, len(phone))
elif language in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU']:
ja_bert = bert
bert = np.zeros(1024, len(phone))
else:
raise NotImplementedError()
assert bert.shape[-1] == len(
phone
), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
phone = np.array(phone, dtype=np.int32)
tone = np.array(tone, dtype=np.int32)
language = np.array(language, dtype=np.int32)
x_tst = np.expand_dims(phone, axis=0)
x_tst_lengths = np.array([phone.size], dtype=np.int32)
tones = np.expand_dims(tone, axis=0)
lang_ids = np.expand_dims(language, axis=0)
bert = np.expand_dims(bert, axis=0)
ja_bert = np.expand_dims(ja_bert, axis=0)
speaker_id = np.array([1], dtype=np.int32)
return x_tst, x_tst_lengths, speaker_id, tones, lang_ids, bert, ja_bert
def generate_audio(self,
text:str,
language:str="ZH_MIX_EN",
sdp_ratio:float=0.2,
noise_scale:float=0.667,
noise_scale_w:float=0.8,
speed:float=1.0) -> Tuple[np.ndarray, int]:
"""_summary_
Args:
text (str): User input text
language (str, optional): Language of the text. Defaults to "ZH_MIX_EN".
sdp_ratio (float, optional): Ratio of SDP. Defaults to 0.2.
noise_scale (float, optional): Scale of noise. Defaults to 0.667.
noise_scale_w (float, optional): Weight of noise scale. Defaults to 0.8.
speed (float, optional): Speed of the audio. Defaults to 1.0.
Returns:
Tuple[np.ndarray, int]: Audio data and sample rate
"""
x_tst, x_tst_lengths, speaker_id, tones, lang_ids, bert, ja_bert = self.__preprocess(text, language)
np_sdp_ratio = np.array([sdp_ratio], dtype=np.float32)
np_noise_scale = np.array([noise_scale], dtype=np.float32)
np_noise_scale_w = np.array([noise_scale_w], dtype=np.float32)
np_speed = np.array([speed], dtype=np.float32)
input_spec = {
self.input_names[0]: x_tst,
self.input_names[1]: x_tst_lengths,
self.input_names[2]: speaker_id,
self.input_names[3]: tones,
self.input_names[4]: lang_ids,
self.input_names[5]: bert,
self.input_names[6]: ja_bert,
self.input_names[7]: np_sdp_ratio,
self.input_names[8]: np_noise_scale,
self.input_names[9]: np_noise_scale_w,
self.input_names[10]: np_speed,
}
output_spec = self.session.run(self.output_names, input_spec)[0]
audio_data = np.squeeze(output_spec, axis=0)
return audio_data, 44100
model_path = "./models/melotts/"
if __name__ == "__main__":
melo_tts = MeloTTS(model_path)
audio, sr = melo_tts.generate_audio("你好,我是中英混合模型。Hello I am a mixed language model.我支持数字123")
sf.write("test.wav", audio, sr)
致谢:
- 本项目基于 MeloTTS 项目实现。
- 模型转换使用 Onnx 框架。
- 推理使用 Onnx Runtime 框架。
联系方式:
- 微信公众号:“CrazyNET”
1204

被折叠的 条评论
为什么被折叠?



