MeloTTS-ONNX中英混合模型(支持CPU快速推理)

跟随虾哥项目实践,硬件选小智就对了

xiaozhi 开源方案官方适配,二次开发文档齐全

MeloTTS-ONNX 项目详解

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

1. 项目概述

MeloTTS-ONNXMeloTTS 的 ONNX 推理版本,专门针对 CPU 实时推理进行了优化。项目支持:

  • ✅ 中英文混合 TTS
  • ✅ 多种语言:中文、英文、日文、韩文、西班牙语、法语等
  • ✅ ONNX Runtime 推理,推理速度快

仓库地址
  • Git仓库: https://gitee.com/jackroing/melo-tts-onnx.git
  • ModelScope模型仓库: https://www.modelscope.cn/models/KeanuX/MeloTTS-ZH-MIXED-EN-ONNX
# 克隆仓库
git clone https://gitee.com/jackroing/melo-tts-onnx.git

# 获取模型
modelscope download --model KeanuX/MeloTTS-ZH-MIXED-EN-ONNX --local_dir ./

2. 项目架构

melo-tts-onnx/
├── melo/                      # 原始 PyTorch 训练代码
│   ├── api.py                 # TTS API 接口
│   ├── models.py              # 模型定义 (SynthesizerTrn)
│   ├── modules.py             # 模型模块
│   ├── text/                  # 文本处理(分词、phoneme转换)
│   │   ├── chinese.py         # 中文文本处理
│   │   ├── english.py         # 英文文本处理
│   │   └── ...
│   ├── train.py / train.sh    # 训练脚本
│   └── ...
│
├── melo_extra/                # 推理时依赖的文本处理模块
│   ├── melo_tts.py            # ONNX 导出封装
│   └── inference/
│       ├── text/              # 推理用文本处理(与melo/text对应)
│       ├── commons.py         # 通用工具
│       └── utils.py           # 参数配置
│
├── models/melotts/            # ONNX 模型目录
│   ├── melotts_14.onnx        # 导出的 ONNX 模型
│   ├── config.json            # 模型配置
│   └── bert-base-multilingual-uncased/  # BERT 模型
│
├── run_onnx.py                # ⭐ 核心推理脚本
├── export_melo.py             # ONNX 导出脚本
├── export_model_info.py       # 模型信息导出工具
└── README.md

3. 核心脚本使用

3.1 推理脚本:run_onnx.py

这是最常用的脚本,用于将文本转换为语音:

from run_onnx import MeloTTS

# 初始化模型
model_path = "./models/melotts/"
melo_tts = MeloTTS(model_path, device="cpu")

# 生成音频
audio, sr = melo_tts.generate_audio(
    text="你好,我是中英混合模型。Hello I am a mixed language model.",
    language="ZH_MIX_EN",      # 语言: ZH_MIX_EN, EN, JP, KR 等
    sdp_ratio=0.2,            # SDP 比率
    noise_scale=0.667,        # 噪声尺度
    noise_scale_w=0.8,        # 噪声权重
    speed=1.0                 # 语速
)

主要参数说明:

参数说明默认值
text输入文本必填
language语言代码"ZH_MIX_EN"
sdp_ratioSDP 比率 (0-1)0.2
noise_scale噪声尺度0.667
noise_scale_w噪声权重0.8
speed语速1.0

支持的语言代码:

  • ZH_MIX_EN - 中文(支持中英文混合)
  • EN - 英文

3.2 ONNX 导出脚本:export_melo.py

用于将 PyTorch 模型导出为 ONNX 格式:

python export_melo.py \
    -m /path/to/ckpt \
    -c /path/to/config.json \
    -o /path/to/save_dir \
    --opset 14
    ...

主要参数:

  • --ckpt_path - 模型检查点路径
  • --cfg_path - 配置文件路径
  • --output_path - 输出 ONNX 文件路径
  • --opset - ONNX opset 版本(默认14)

3.3 模型信息导出:export_model_info.py

用于导出 ONNX 模型的详细信息(输入输出形状、参数数量等):

python export_model_info.py -m ./models/melotts/melotts_14.onnx -o ./infos/melotts_14.info

输出示例:

============================================================
ONNX模型基本信息
============================================================
模型文件路径: ./models/melotts/melotts/melotts_14.onnx
ONNX版本: 7
生产者信息: pytorch 2.8.0
模型版本: 0
描述: 

============================================================
模型输入信息 (共 11 个输入)
============================================================
Input 1: x_tst
  数据类型: int32
  形状: [0, 0]

Input 2: x_tst_lengths
  数据类型: int32
  形状: [0]

Input 3: speakers
  数据类型: int32
  形状: [0]

Input 4: tones
  数据类型: int32
  形状: [0, 0]

Input 5: lang_ids
  数据类型: int32
  形状: [0, 0]

Input 6: bert
  数据类型: float32
  形状: [0, 1024, 0]

Input 7: ja_bert
  数据类型: float32
  形状: [0, 768, 0]

Input 8: sdp_ratio
  数据类型: float32
  形状: [0]

Input 9: noise_scale
  数据类型: float32
  形状: [0]

Input 10: noise_scale_w
  数据类型: float32
  形状: [0]

Input 11: speed
  数据类型: float32
  形状: [0]

============================================================
模型输出信息 (共 1 个输出)
============================================================
Output 1: audio_data
  数据类型: float32
  形状: [1, 0]

4. 工作原理

文本输入
   ↓
┌─────────────────────────────────────────┐
│  文本预处理 (clean_text)                 │
│  - 分词                                 │
│  - 转换为 phoneme                       │
│  - 获取 tone                            │
│  - BERT 特征提取                        │
└─────────────────────────────────────────┘
   ↓
┌─────────────────────────────────────────┐
│  ONNX 模型推理                          │
│  - Glow-TTS (文本→mel频谱)             │
│  - HiFi-GAN (mel频谱→音频)             │
└─────────────────────────────────────────┘
   ↓
音频输出 (44.1kHz)

ONNX 模型输入 (11个):

  1. x_tst - 文本 token IDs
  2. x_tst_lengths - 文本长度
  3. speakers - 发音人 ID
  4. tones - 音调 IDs
  5. lang_ids - 语言 IDs
  6. bert - BERT 特征 (1024维)
  7. ja_bert - 日文 BERT 特征 (768维)
  8. sdp_ratio - SDP 比率
  9. noise_scale - 噪声尺度
  10. noise_scale_w - 噪声权重
  11. speed - 语速

ONNX 模型输出 (1个):

  • audio_data - 生成的音频数据

5. 快速使用示例

import onnxruntime as ort
import numpy as np
import os
import sys
import soundfile as sf
from typing import Tuple
from melo_extra.inference.utils import HParams, get_hparams_from_file
from melo_extra.inference.text.cleaner import clean_text
from melo_extra.inference.text import cleaned_text_to_sequence, get_bert, get_zh_mix_en_bert
from melo_extra.inference import commons
import logging

logger = logging.getLogger(__name__)
file_handler = logging.FileHandler("./logs/run_melo_onnx.log", mode="w", encoding="utf-8")
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)
logger.setLevel(logging.INFO)

class MeloTTS:
    def __init__(self, model_root:str, device:str="cpu", provider_options:list[dict]=None) -> None:
        self.model_path = os.path.join(model_root, "melotts_14.onnx")
        self.cfg_path = os.path.join(model_root, "config.json")
        self.bert_model_path = os.path.join(model_root, "bert-base-multilingual-uncased")
        
        self.cfg = get_hparams_from_file(self.cfg_path)
        
        if device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers():
            self.providers = ["CUDAExecutionProvider"]
        elif device == "cpu":
            self.providers = ["CPUExecutionProvider"]
        elif device == "qnn" and "QNNExecutionProvider" in ort.get_available_providers() and provider_options != None:
            self.providers = ["QNNExecutionProvider"]
        else:
            logger.info(f"device {device} not supported, use cpu instead")
            self.providers = ["CPUExecutionProvider"]
        
        self.session = ort.InferenceSession(self.model_path, providers=self.providers, provider_options=provider_options)
        
        self.input_names = [input.name for input in self.session.get_inputs()]
        self.output_names = [output.name for output in self.session.get_outputs()]
        logger.info(f"model input names: {self.input_names}")
        logger.info(f"model output names: {self.output_names}")
    
    def __preprocess(self, text:str, language:str):
        norm_text, phone, tone, word2ph = clean_text(text, language)
        symbol_to_id = {s: i for i, s in enumerate(self.cfg.symbols)}
        phone, tone, language = cleaned_text_to_sequence(phone, tone, language, symbol_to_id)
        
        if self.cfg.data.add_blank:
            phone = commons.intersperse(phone, 0)
            tone = commons.intersperse(tone, 0)
            language = commons.intersperse(language, 0)
            for i in range(len(word2ph)):
                word2ph[i] = word2ph[i] * 2
            word2ph[0] += 1
        
        if getattr(self.cfg.data, "disable_bert", True):
            bert = np.zeros((1024, len(phone)), dtype=np.float32)
            ja_bert = np.zeros((768, len(phone)), dtype=np.float32)
        else:
            bert = get_zh_mix_en_bert(self.bert_model_path, text, word2ph, "cpu")
            del word2ph
            assert bert.shape[-1] == len(phone), phone

            if language == "ZH":
                bert = bert
                ja_bert = np.zeros(768, len(phone))
            elif language in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU']:
                ja_bert = bert
                bert = np.zeros(1024, len(phone))
            else:
                raise NotImplementedError()
        
        assert bert.shape[-1] == len(
            phone
        ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

        phone = np.array(phone, dtype=np.int32)
        tone = np.array(tone, dtype=np.int32)
        language = np.array(language, dtype=np.int32)
        
        
        x_tst = np.expand_dims(phone, axis=0)
        x_tst_lengths = np.array([phone.size], dtype=np.int32)
        tones = np.expand_dims(tone, axis=0)
        lang_ids = np.expand_dims(language, axis=0)
        
        bert = np.expand_dims(bert, axis=0)
        ja_bert = np.expand_dims(ja_bert, axis=0)
        
        speaker_id = np.array([1], dtype=np.int32)
        
        return x_tst, x_tst_lengths, speaker_id, tones, lang_ids, bert, ja_bert
    
    def generate_audio(self,
                       text:str,
                       language:str="ZH_MIX_EN",
                       sdp_ratio:float=0.2,
                       noise_scale:float=0.667,
                       noise_scale_w:float=0.8,
                       speed:float=1.0) -> Tuple[np.ndarray, int]:
        """_summary_

        Args:
            text (str): User input text
            language (str, optional): Language of the text. Defaults to "ZH_MIX_EN".
            sdp_ratio (float, optional): Ratio of SDP. Defaults to 0.2.
            noise_scale (float, optional): Scale of noise. Defaults to 0.667.
            noise_scale_w (float, optional): Weight of noise scale. Defaults to 0.8.
            speed (float, optional): Speed of the audio. Defaults to 1.0.

        Returns:
            Tuple[np.ndarray, int]: Audio data and sample rate
        """
        x_tst, x_tst_lengths, speaker_id, tones, lang_ids, bert, ja_bert = self.__preprocess(text, language)
        
        np_sdp_ratio = np.array([sdp_ratio], dtype=np.float32)
        np_noise_scale = np.array([noise_scale], dtype=np.float32)
        np_noise_scale_w = np.array([noise_scale_w], dtype=np.float32)
        np_speed = np.array([speed], dtype=np.float32)
        
        input_spec = {
            self.input_names[0]: x_tst,
            self.input_names[1]: x_tst_lengths,
            self.input_names[2]: speaker_id,
            self.input_names[3]: tones,
            self.input_names[4]: lang_ids,
            self.input_names[5]: bert,
            self.input_names[6]: ja_bert,
            self.input_names[7]: np_sdp_ratio,
            self.input_names[8]: np_noise_scale,
            self.input_names[9]: np_noise_scale_w,
            self.input_names[10]: np_speed,
        }
        
        output_spec = self.session.run(self.output_names, input_spec)[0]
        
        audio_data = np.squeeze(output_spec, axis=0)
        
        return audio_data, 44100

model_path = "./models/melotts/"

if __name__ == "__main__":
    melo_tts = MeloTTS(model_path)
    
    audio, sr = melo_tts.generate_audio("你好,我是中英混合模型。Hello I am a mixed language model.我支持数字123")
    
    sf.write("test.wav", audio, sr)
    

致谢:

联系方式:

  • 微信公众号:“CrazyNET”

跟随虾哥项目实践,硬件选小智就对了

xiaozhi 开源方案官方适配,二次开发文档齐全

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值