AI 模型编译优化:从 PyTorch 到 ONNX 到 TensorRT 的推理加速全链路

AI 模型编译优化:从 PyTorch 到 ONNX 到 TensorRT 的推理加速全链路

cover

一、模型训练与推理之间的性能鸿沟

训练一个 AI 模型和部署一个 AI 模型是两个完全不同的工程问题。训练关注的是收敛速度和精度,推理关注的是延迟、吞吐和资源占用。一个在训练时表现良好的模型,直接部署到生产环境往往会遇到严重的性能问题。

核心痛点在于:PyTorch 的动态图机制虽然方便训练,但推理时存在大量开销——动态内存分配、Python 解释器开销、算子未融合导致的冗余内存访问。这些开销在单次推理中可能只有几毫秒,但在高并发场景下会快速放大。

实际数据:一个 BERT-base 模型在 PyTorch 直接推理的延迟约 50ms,经过 ONNX 优化后降到 20ms,再经过 TensorRT 编译优化后可以降到 5ms 以下。10 倍的性能差距,直接决定了服务能否满足 SLA。

二、模型编译优化的分层架构

graph TD
    A[PyTorch 模型<br>动态图 + Python] --> B[torch.export<br>导出计算图]
    B --> C[ONNX IR<br>中间表示]

    C --> D{目标平台?}
    D -->|NVIDIA GPU| E[TensorRT<br>GPU 专用编译器]
    D -->|通用 CPU| F[ONNX Runtime<br>CPU 推理引擎]
    D -->|边缘设备| G[ONNX → TFLite<br>移动端推理]
    D -->|浏览器| H[ONNX → WASM<br>Web 推理]

    E --> I[算子融合<br>Conv+BN+ReLU → 单算子]
    E --> J[精度校准<br>FP32 → INT8 量化]
    E --> K[内核自动调优<br>选择最优 CUDA Kernel]
    E --> L[内存优化<br>减少显存分配和拷贝]

    F --> M[图优化<br>常量折叠/死代码消除]
    F --> N[量化<br>动态/静态 INT8]
    F --> O[算子优化<br>MHA 融合/矩阵乘法优化]

    subgraph 编译优化通用技术
        P[算子融合: 减少内存访问次数]
        Q[常量折叠: 编译期计算常量表达式]
        R[死代码消除: 移除不影响输出的算子]
        S[内存规划: 复用中间张量内存]
        T[量化: 降低数值精度减少计算量]
    end

    I --> P
    I --> Q
    I --> R
    I --> S
    J --> T

编译优化的核心思路:

  1. 算子融合(Operator Fusion):将多个连续算子合并为一个,减少中间结果的内存读写。例如 Conv → BatchNorm → ReLU 融合为单个算子,从 3 次内存访问降为 1 次。

  2. 量化(Quantization):将 FP32 权重和激活值降为 INT8 或 FP16,减少内存占用和计算量。INT8 量化通常能带来 2-4 倍的推理加速,但需要校准(Calibration)来保证精度损失在可接受范围内。

  3. 内存规划(Memory Planning):分析计算图的生命周期,复用不再需要的中间张量内存,减少总内存分配次数。

  4. 内核自动调优(Kernel Auto-Tuning):针对目标硬件,尝试不同的内核实现,选择最快的版本。TensorRT 在构建引擎时会自动执行这一步。

三、生产级实践:PyTorch → ONNX → TensorRT 的完整编译链路

"""
AI 模型编译优化全链路
PyTorch → ONNX → TensorRT (NVIDIA GPU) / ONNX Runtime (CPU)
"""
from __future__ import annotations

import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


# ===== 模型定义 =====

class TextClassifier(nn.Module):
    """
    文本分类模型(示例)
    实际项目中替换为业务模型
    """

    def __init__(
        self,
        vocab_size: int = 30000,
        hidden_size: int = 256,
        num_classes: int = 10,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)
        # 双向 LSTM,输出维度 ×2
        self.classifier = nn.Linear(hidden_size * 2, num_classes)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Embedding
        x = self.embedding(input_ids)

        # LSTM 编码
        lstm_out, _ = self.lstm(x)

        # 取最后一个非 padding 位置的输出
        if attention_mask is not None:
            # 找到每个序列最后一个有效 token 的位置
            seq_lengths = attention_mask.sum(dim=1) - 1
            batch_size = lstm_out.size(0)
            last_outputs = lstm_out[
                torch.arange(batch_size, device=lstm_out.device),
                seq_lengths.long(),
            ]
        else:
            last_outputs = lstm_out[:, -1, :]

        # 分类
        last_outputs = self.dropout(last_outputs)
        logits = self.classifier(last_outputs)
        return logits


# ===== ONNX 导出 =====

@dataclass
class ExportConfig:
    """导出配置"""
    onnx_path: str = "model.onnx"
    opset_version: int = 17
    dynamic_batch: bool = True     # 是否支持动态 batch
    max_batch_size: int = 32
    seq_length: int = 128          # 固定序列长度


def export_to_onnx(
    model: nn.Module,
    config: ExportConfig,
) -> Path:
    """
    将 PyTorch 模型导出为 ONNX 格式

    关键步骤:
    1. 设置模型为评估模式(关闭 Dropout、BatchNorm 使用运行统计量)
    2. 构造示例输入
    3. 使用 torch.onnx.export 导出
    4. 验证导出的 ONNX 模型
    """
    model.eval()
    onnx_path = Path(config.onnx_path)
    onnx_path.parent.mkdir(parents=True, exist_ok=True)

    # 构造示例输入
    dummy_input_ids = torch.randint(
        0, 30000, (1, config.seq_length), dtype=torch.long
    )
    dummy_attention_mask = torch.ones(
        1, config.seq_length, dtype=torch.long
    )

    # 动态维度配置
    dynamic_axes = None
    if config.dynamic_batch:
        dynamic_axes = {
            "input_ids": {0: "batch_size"},
            "attention_mask": {0: "batch_size"},
            "logits": {0: "batch_size"},
        }

    logger.info("开始导出 ONNX 模型到 %s", onnx_path)

    torch.onnx.export(
        model,
        (dummy_input_ids, dummy_attention_mask),
        str(onnx_path),
        export_params=True,
        opset_version=config.opset_version,
        do_constant_folding=True,   # 启用常量折叠
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
    )

    # 验证 ONNX 模型
    import onnx
    onnx_model = onnx.load(str(onnx_path))
    onnx.checker.check_model(onnx_model)
    logger.info("ONNX 模型验证通过")

    # 打印模型信息
    logger.info(
        "ONNX 模型: opset=%d, 输入=%s, 输出=%s",
        onnx_model.opset_import[0].version,
        [inp.name for inp in onnx_model.graph.input],
        [out.name for out in onnx_model.graph.output],
    )

    return onnx_path


# ===== ONNX Runtime 推理 =====

class ONNXRuntimeInference:
    """ONNX Runtime 推理引擎"""

    def __init__(
        self,
        onnx_path: str,
        providers: Optional[list[str]] = None,
    ):
        import onnxruntime as ort

        self.providers = providers or ["CPUExecutionProvider"]
        self.session = ort.InferenceSession(
            onnx_path,
            providers=self.providers,
        )

        # 获取输入输出信息
        self.input_names = [
            inp.name for inp in self.session.get_inputs()
        ]
        self.output_names = [
            out.name for out in self.session.get_outputs()
        ]

        logger.info(
            "ONNX Runtime 会话创建: providers=%s, inputs=%s",
            self.providers, self.input_names,
        )

    def predict(
        self,
        input_ids: np.ndarray,
        attention_mask: np.ndarray,
    ) -> np.ndarray:
        """执行推理"""
        inputs = {
            "input_ids": input_ids.astype(np.int64),
            "attention_mask": attention_mask.astype(np.int64),
        }
        outputs = self.session.run(self.output_names, inputs)
        return outputs[0]

    def benchmark(
        self,
        input_ids: np.ndarray,
        attention_mask: np.ndarray,
        warmup: int = 10,
        runs: int = 100,
    ) -> dict:
        """性能基准测试"""
        # 预热
        for _ in range(warmup):
            self.predict(input_ids, attention_mask)

        # 正式测试
        latencies = []
        for _ in range(runs):
            start = time.perf_counter()
            self.predict(input_ids, attention_mask)
            latencies.append(
                (time.perf_counter() - start) * 1000
            )

        latencies = np.array(latencies)
        return {
            "mean_ms": float(latencies.mean()),
            "p50_ms": float(np.percentile(latencies, 50)),
            "p95_ms": float(np.percentile(latencies, 95)),
            "p99_ms": float(np.percentile(latencies, 99)),
        }


# ===== TensorRT 编译(NVIDIA GPU) =====

class TensorRTEngine:
    """
    TensorRT 编译引擎
    将 ONNX 模型编译为 TensorRT 引擎,实现 GPU 推理加速

    注意:需要在 NVIDIA GPU 环境下运行
    依赖:tensorrt, polygraphy(可选,用于精度校准)
    """

    def __init__(
        self,
        onnx_path: str,
        engine_path: str,
        precision: str = "fp16",    # fp32 / fp16 / int8
        max_batch_size: int = 32,
        calibration_data: Optional[np.ndarray] = None,
    ):
        self.onnx_path = onnx_path
        self.engine_path = engine_path
        self.precision = precision
        self.max_batch_size = max_batch_size
        self.calibration_data = calibration_data
        self.engine = None
        self.context = None

    def build(self) -> None:
        """构建 TensorRT 引擎"""
        try:
            import tensorrt as trt
        except ImportError:
            logger.error(
                "TensorRT 未安装,请参考 "
                "https://developer.nvidia.com/tensorrt"
            )
            raise

        logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(logger)
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, logger)

        # 解析 ONNX 模型
        with open(self.onnx_path, "rb") as f:
            if not parser.parse(f.read()):
                for i in range(parser.num_errors):
                    logger.log(
                        trt.Logger.ERROR,
                        parser.get_error(i).desc(),
                    )
                raise RuntimeError("ONNX 解析失败")

        # 配置构建器
        config = builder.create_builder_config()
        config.set_memory_pool_limit(
            trt.MemoryPoolType.WORKSPACE, 1 << 30  # 1GB 工作空间
        )

        # 精度设置
        if self.precision == "fp16":
            if builder.platform_has_fast_fp16:
                config.set_flag(trt.BuilderFlag.FP16)
                logger.info("启用 FP16 精度")
            else:
                logger.warning("平台不支持 FP16,使用 FP32")

        elif self.precision == "int8":
            if builder.platform_has_fast_int8:
                config.set_flag(trt.BuilderFlag.INT8)
                logger.info("启用 INT8 精度")
                # INT8 需要校准数据
                if self.calibration_data is not None:
                    calibrator = self._create_calibrator()
                    config.int8_calibrator = calibrator
                else:
                    logger.warning(
                        "INT8 模式未提供校准数据,"
                        "精度可能受影响"
                    )
            else:
                logger.warning("平台不支持 INT8,使用 FP32")

        # 构建引擎(这一步会自动进行内核调优,耗时较长)
        logger.info("开始构建 TensorRT 引擎(可能需要几分钟)...")
        plan = builder.build_serialized_network(network, config)
        if plan is None:
            raise RuntimeError("TensorRT 引擎构建失败")

        # 保存引擎
        with open(self.engine_path, "wb") as f:
            f.write(plan)
        logger.info("引擎已保存到 %s", self.engine_path)

        # 加载引擎
        runtime = trt.Runtime(logger)
        self.engine = runtime.deserialize_cuda_engine(plan)
        self.context = self.engine.create_execution_context()

    def _create_calibrator(self):
        """创建 INT8 校准器(简化实现)"""
        import tensorrt as trt

        class CalibrationDataLoader(trt.IInt8EntropyCalibrator2):
            def __init__(self, data):
                self.data = data
                self.current_idx = 0
                self.device_input = None

            def get_batch_size(self):
                return 1

            def get_batch(self, names):
                if self.current_idx >= len(self.data):
                    return None
                batch = self.data[self.current_idx]
                self.current_idx += 1
                import pycuda.driver as cuda
                import pycuda.autoinit
                self.device_input = cuda.mem_alloc(
                    batch.nbytes
                )
                cuda.memcpy_htod(
                    self.device_input, batch
                )
                return [int(self.device_input)]

            def read_calibration_cache(self):
                return None

            def write_calibration_cache(self, cache):
                pass

        return CalibrationDataLoader(self.calibration_data)


# ===== 性能对比 =====

def compare_performance(
    model: nn.Module,
    config: ExportConfig,
    batch_size: int = 1,
    seq_length: int = 128,
) -> None:
    """对比 PyTorch / ONNX Runtime / TensorRT 的推理性能"""

    # 准备测试数据
    input_ids = torch.randint(
        0, 30000, (batch_size, seq_length), dtype=torch.long
    )
    attention_mask = torch.ones(
        batch_size, seq_length, dtype=torch.long
    )

    # 1. PyTorch 基准
    model.eval()
    with torch.no_grad():
        # 预热
        for _ in range(10):
            model(input_ids, attention_mask)
        # 测试
        latencies = []
        for _ in range(100):
            start = time.perf_counter()
            model(input_ids, attention_mask)
            latencies.append(
                (time.perf_counter() - start) * 1000
            )
    pytorch_p50 = np.percentile(latencies, 50)
    print(f"PyTorch: P50 = {pytorch_p50:.2f} ms")

    # 2. ONNX Runtime 基准
    onnx_path = export_to_onnx(model, config)
    ort_engine = ONNXRuntimeInference(str(onnx_path))
    ort_result = ort_engine.benchmark(
        input_ids.numpy(),
        attention_mask.numpy(),
    )
    print(f"ONNX Runtime: P50 = {ort_result['p50_ms']:.2f} ms")

    # 3. 加速比
    speedup = pytorch_p50 / ort_result['p50_ms']
    print(f"ONNX Runtime 加速比: {speedup:.2f}x")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # 创建模型
    model = TextClassifier(
        vocab_size=30000,
        hidden_size=256,
        num_classes=10,
    )
    model.eval()

    # 导出 ONNX
    config = ExportConfig(
        onnx_path="./models/classifier.onnx",
        dynamic_batch=True,
    )
    onnx_path = export_to_onnx(model, config)

    # ONNX Runtime 推理测试
    ort_engine = ONNXRuntimeInference(str(onnx_path))
    test_input = np.random.randint(
        0, 30000, (4, 128), dtype=np.int64
    )
    test_mask = np.ones((4, 128), dtype=np.int64)
    output = ort_engine.predict(test_input, test_mask)
    print(f"推理输出形状: {output.shape}")

    # 性能对比
    compare_performance(model, config)

踩坑记录:torch.onnx.export 对动态形状的支持有限。如果模型中有条件分支(如 if 语句依赖输入值),导出时会失败或产生不正确的结果。解决方案是使用 torch.jit.trace(只记录执行路径)而非 torch.jit.script,或者重构模型避免条件分支。

INT8 量化的精度损失是一个需要仔细评估的问题。在文本分类任务中,INT8 量化的精度损失通常小于 1%,但在检测和分割任务中可能达到 3%-5%。建议在量化前后都跑一遍评估集,确保精度在可接受范围内。

TensorRT 的引擎是硬件绑定的——在 A100 上编译的引擎不能在 V100 上使用。这意味着每个目标硬件都需要单独编译。在 CI/CD 中,需要在目标硬件上执行编译步骤。

四、模型编译优化的代价与适用边界

编译耗时。 TensorRT 的引擎构建过程包括内核自动调优,通常需要几分钟到几十分钟。这意味着模型更新后不能立即部署,需要预留编译时间。

硬件绑定。 TensorRT 引擎与 GPU 架构强绑定,不同型号的 GPU 需要分别编译。这增加了部署和运维的复杂度。

调试困难。 编译后的模型是黑盒,无法像 PyTorch 那样逐步调试。精度问题需要通过对比工具(如 PolyGraphy)逐层排查。

适用场景:

  • 高吞吐低延迟的在线推理服务
  • NVIDIA GPU 部署场景
  • 模型固定、不频繁更新的生产环境
  • 对推理成本敏感的大规模部署

不适用场景:

  • 模型频繁迭代的研发阶段
  • 非 NVIDIA 硬件(用 ONNX Runtime 或 OpenVINO)
  • 需要灵活调试和可视化的开发环境
  • 小规模部署(编译优化的 ROI 不高)

五、总结

AI 模型编译优化的核心链路是 PyTorch → ONNX → TensorRT/ONNX Runtime,关键优化技术包括算子融合、量化、常量折叠和内存规划。ONNX 作为中间表示实现了框架解耦,TensorRT 提供 NVIDIA GPU 上的极致性能,ONNX Runtime 覆盖通用 CPU 场景。编译优化的代价是编译耗时、硬件绑定和调试困难,适用于模型固定、高吞吐低延迟的生产部署场景,不适用于频繁迭代和非 NVIDIA 硬件环境。INT8 量化需要校准和精度验证,TensorRT 引擎需要针对目标硬件单独编译。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值