AI 模型编译优化:从 PyTorch 到 ONNX 到 TensorRT 的推理加速全链路

一、模型训练与推理之间的性能鸿沟
训练一个 AI 模型和部署一个 AI 模型是两个完全不同的工程问题。训练关注的是收敛速度和精度,推理关注的是延迟、吞吐和资源占用。一个在训练时表现良好的模型,直接部署到生产环境往往会遇到严重的性能问题。
核心痛点在于:PyTorch 的动态图机制虽然方便训练,但推理时存在大量开销——动态内存分配、Python 解释器开销、算子未融合导致的冗余内存访问。这些开销在单次推理中可能只有几毫秒,但在高并发场景下会快速放大。
实际数据:一个 BERT-base 模型在 PyTorch 直接推理的延迟约 50ms,经过 ONNX 优化后降到 20ms,再经过 TensorRT 编译优化后可以降到 5ms 以下。10 倍的性能差距,直接决定了服务能否满足 SLA。
二、模型编译优化的分层架构
graph TD
A[PyTorch 模型<br>动态图 + Python] --> B[torch.export<br>导出计算图]
B --> C[ONNX IR<br>中间表示]
C --> D{目标平台?}
D -->|NVIDIA GPU| E[TensorRT<br>GPU 专用编译器]
D -->|通用 CPU| F[ONNX Runtime<br>CPU 推理引擎]
D -->|边缘设备| G[ONNX → TFLite<br>移动端推理]
D -->|浏览器| H[ONNX → WASM<br>Web 推理]
E --> I[算子融合<br>Conv+BN+ReLU → 单算子]
E --> J[精度校准<br>FP32 → INT8 量化]
E --> K[内核自动调优<br>选择最优 CUDA Kernel]
E --> L[内存优化<br>减少显存分配和拷贝]
F --> M[图优化<br>常量折叠/死代码消除]
F --> N[量化<br>动态/静态 INT8]
F --> O[算子优化<br>MHA 融合/矩阵乘法优化]
subgraph 编译优化通用技术
P[算子融合: 减少内存访问次数]
Q[常量折叠: 编译期计算常量表达式]
R[死代码消除: 移除不影响输出的算子]
S[内存规划: 复用中间张量内存]
T[量化: 降低数值精度减少计算量]
end
I --> P
I --> Q
I --> R
I --> S
J --> T
编译优化的核心思路:
算子融合(Operator Fusion):将多个连续算子合并为一个,减少中间结果的内存读写。例如 Conv → BatchNorm → ReLU 融合为单个算子,从 3 次内存访问降为 1 次。
量化(Quantization):将 FP32 权重和激活值降为 INT8 或 FP16,减少内存占用和计算量。INT8 量化通常能带来 2-4 倍的推理加速,但需要校准(Calibration)来保证精度损失在可接受范围内。
内存规划(Memory Planning):分析计算图的生命周期,复用不再需要的中间张量内存,减少总内存分配次数。
内核自动调优(Kernel Auto-Tuning):针对目标硬件,尝试不同的内核实现,选择最快的版本。TensorRT 在构建引擎时会自动执行这一步。
三、生产级实践:PyTorch → ONNX → TensorRT 的完整编译链路
"""
AI 模型编译优化全链路
PyTorch → ONNX → TensorRT (NVIDIA GPU) / ONNX Runtime (CPU)
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
# ===== 模型定义 =====
class TextClassifier(nn.Module):
"""
文本分类模型(示例)
实际项目中替换为业务模型
"""
def __init__(
self,
vocab_size: int = 30000,
hidden_size: int = 256,
num_classes: int = 10,
num_layers: int = 2,
dropout: float = 0.1,
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_size)
self.lstm = nn.LSTM(
hidden_size,
hidden_size,
num_layers=num_layers,
batch_first=True,
bidirectional=True,
)
self.dropout = nn.Dropout(dropout)
# 双向 LSTM,输出维度 ×2
self.classifier = nn.Linear(hidden_size * 2, num_classes)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Embedding
x = self.embedding(input_ids)
# LSTM 编码
lstm_out, _ = self.lstm(x)
# 取最后一个非 padding 位置的输出
if attention_mask is not None:
# 找到每个序列最后一个有效 token 的位置
seq_lengths = attention_mask.sum(dim=1) - 1
batch_size = lstm_out.size(0)
last_outputs = lstm_out[
torch.arange(batch_size, device=lstm_out.device),
seq_lengths.long(),
]
else:
last_outputs = lstm_out[:, -1, :]
# 分类
last_outputs = self.dropout(last_outputs)
logits = self.classifier(last_outputs)
return logits
# ===== ONNX 导出 =====
@dataclass
class ExportConfig:
"""导出配置"""
onnx_path: str = "model.onnx"
opset_version: int = 17
dynamic_batch: bool = True # 是否支持动态 batch
max_batch_size: int = 32
seq_length: int = 128 # 固定序列长度
def export_to_onnx(
model: nn.Module,
config: ExportConfig,
) -> Path:
"""
将 PyTorch 模型导出为 ONNX 格式
关键步骤:
1. 设置模型为评估模式(关闭 Dropout、BatchNorm 使用运行统计量)
2. 构造示例输入
3. 使用 torch.onnx.export 导出
4. 验证导出的 ONNX 模型
"""
model.eval()
onnx_path = Path(config.onnx_path)
onnx_path.parent.mkdir(parents=True, exist_ok=True)
# 构造示例输入
dummy_input_ids = torch.randint(
0, 30000, (1, config.seq_length), dtype=torch.long
)
dummy_attention_mask = torch.ones(
1, config.seq_length, dtype=torch.long
)
# 动态维度配置
dynamic_axes = None
if config.dynamic_batch:
dynamic_axes = {
"input_ids": {0: "batch_size"},
"attention_mask": {0: "batch_size"},
"logits": {0: "batch_size"},
}
logger.info("开始导出 ONNX 模型到 %s", onnx_path)
torch.onnx.export(
model,
(dummy_input_ids, dummy_attention_mask),
str(onnx_path),
export_params=True,
opset_version=config.opset_version,
do_constant_folding=True, # 启用常量折叠
input_names=["input_ids", "attention_mask"],
output_names=["logits"],
dynamic_axes=dynamic_axes,
)
# 验证 ONNX 模型
import onnx
onnx_model = onnx.load(str(onnx_path))
onnx.checker.check_model(onnx_model)
logger.info("ONNX 模型验证通过")
# 打印模型信息
logger.info(
"ONNX 模型: opset=%d, 输入=%s, 输出=%s",
onnx_model.opset_import[0].version,
[inp.name for inp in onnx_model.graph.input],
[out.name for out in onnx_model.graph.output],
)
return onnx_path
# ===== ONNX Runtime 推理 =====
class ONNXRuntimeInference:
"""ONNX Runtime 推理引擎"""
def __init__(
self,
onnx_path: str,
providers: Optional[list[str]] = None,
):
import onnxruntime as ort
self.providers = providers or ["CPUExecutionProvider"]
self.session = ort.InferenceSession(
onnx_path,
providers=self.providers,
)
# 获取输入输出信息
self.input_names = [
inp.name for inp in self.session.get_inputs()
]
self.output_names = [
out.name for out in self.session.get_outputs()
]
logger.info(
"ONNX Runtime 会话创建: providers=%s, inputs=%s",
self.providers, self.input_names,
)
def predict(
self,
input_ids: np.ndarray,
attention_mask: np.ndarray,
) -> np.ndarray:
"""执行推理"""
inputs = {
"input_ids": input_ids.astype(np.int64),
"attention_mask": attention_mask.astype(np.int64),
}
outputs = self.session.run(self.output_names, inputs)
return outputs[0]
def benchmark(
self,
input_ids: np.ndarray,
attention_mask: np.ndarray,
warmup: int = 10,
runs: int = 100,
) -> dict:
"""性能基准测试"""
# 预热
for _ in range(warmup):
self.predict(input_ids, attention_mask)
# 正式测试
latencies = []
for _ in range(runs):
start = time.perf_counter()
self.predict(input_ids, attention_mask)
latencies.append(
(time.perf_counter() - start) * 1000
)
latencies = np.array(latencies)
return {
"mean_ms": float(latencies.mean()),
"p50_ms": float(np.percentile(latencies, 50)),
"p95_ms": float(np.percentile(latencies, 95)),
"p99_ms": float(np.percentile(latencies, 99)),
}
# ===== TensorRT 编译(NVIDIA GPU) =====
class TensorRTEngine:
"""
TensorRT 编译引擎
将 ONNX 模型编译为 TensorRT 引擎,实现 GPU 推理加速
注意:需要在 NVIDIA GPU 环境下运行
依赖:tensorrt, polygraphy(可选,用于精度校准)
"""
def __init__(
self,
onnx_path: str,
engine_path: str,
precision: str = "fp16", # fp32 / fp16 / int8
max_batch_size: int = 32,
calibration_data: Optional[np.ndarray] = None,
):
self.onnx_path = onnx_path
self.engine_path = engine_path
self.precision = precision
self.max_batch_size = max_batch_size
self.calibration_data = calibration_data
self.engine = None
self.context = None
def build(self) -> None:
"""构建 TensorRT 引擎"""
try:
import tensorrt as trt
except ImportError:
logger.error(
"TensorRT 未安装,请参考 "
"https://developer.nvidia.com/tensorrt"
)
raise
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
# 解析 ONNX 模型
with open(self.onnx_path, "rb") as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
logger.log(
trt.Logger.ERROR,
parser.get_error(i).desc(),
)
raise RuntimeError("ONNX 解析失败")
# 配置构建器
config = builder.create_builder_config()
config.set_memory_pool_limit(
trt.MemoryPoolType.WORKSPACE, 1 << 30 # 1GB 工作空间
)
# 精度设置
if self.precision == "fp16":
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
logger.info("启用 FP16 精度")
else:
logger.warning("平台不支持 FP16,使用 FP32")
elif self.precision == "int8":
if builder.platform_has_fast_int8:
config.set_flag(trt.BuilderFlag.INT8)
logger.info("启用 INT8 精度")
# INT8 需要校准数据
if self.calibration_data is not None:
calibrator = self._create_calibrator()
config.int8_calibrator = calibrator
else:
logger.warning(
"INT8 模式未提供校准数据,"
"精度可能受影响"
)
else:
logger.warning("平台不支持 INT8,使用 FP32")
# 构建引擎(这一步会自动进行内核调优,耗时较长)
logger.info("开始构建 TensorRT 引擎(可能需要几分钟)...")
plan = builder.build_serialized_network(network, config)
if plan is None:
raise RuntimeError("TensorRT 引擎构建失败")
# 保存引擎
with open(self.engine_path, "wb") as f:
f.write(plan)
logger.info("引擎已保存到 %s", self.engine_path)
# 加载引擎
runtime = trt.Runtime(logger)
self.engine = runtime.deserialize_cuda_engine(plan)
self.context = self.engine.create_execution_context()
def _create_calibrator(self):
"""创建 INT8 校准器(简化实现)"""
import tensorrt as trt
class CalibrationDataLoader(trt.IInt8EntropyCalibrator2):
def __init__(self, data):
self.data = data
self.current_idx = 0
self.device_input = None
def get_batch_size(self):
return 1
def get_batch(self, names):
if self.current_idx >= len(self.data):
return None
batch = self.data[self.current_idx]
self.current_idx += 1
import pycuda.driver as cuda
import pycuda.autoinit
self.device_input = cuda.mem_alloc(
batch.nbytes
)
cuda.memcpy_htod(
self.device_input, batch
)
return [int(self.device_input)]
def read_calibration_cache(self):
return None
def write_calibration_cache(self, cache):
pass
return CalibrationDataLoader(self.calibration_data)
# ===== 性能对比 =====
def compare_performance(
model: nn.Module,
config: ExportConfig,
batch_size: int = 1,
seq_length: int = 128,
) -> None:
"""对比 PyTorch / ONNX Runtime / TensorRT 的推理性能"""
# 准备测试数据
input_ids = torch.randint(
0, 30000, (batch_size, seq_length), dtype=torch.long
)
attention_mask = torch.ones(
batch_size, seq_length, dtype=torch.long
)
# 1. PyTorch 基准
model.eval()
with torch.no_grad():
# 预热
for _ in range(10):
model(input_ids, attention_mask)
# 测试
latencies = []
for _ in range(100):
start = time.perf_counter()
model(input_ids, attention_mask)
latencies.append(
(time.perf_counter() - start) * 1000
)
pytorch_p50 = np.percentile(latencies, 50)
print(f"PyTorch: P50 = {pytorch_p50:.2f} ms")
# 2. ONNX Runtime 基准
onnx_path = export_to_onnx(model, config)
ort_engine = ONNXRuntimeInference(str(onnx_path))
ort_result = ort_engine.benchmark(
input_ids.numpy(),
attention_mask.numpy(),
)
print(f"ONNX Runtime: P50 = {ort_result['p50_ms']:.2f} ms")
# 3. 加速比
speedup = pytorch_p50 / ort_result['p50_ms']
print(f"ONNX Runtime 加速比: {speedup:.2f}x")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
# 创建模型
model = TextClassifier(
vocab_size=30000,
hidden_size=256,
num_classes=10,
)
model.eval()
# 导出 ONNX
config = ExportConfig(
onnx_path="./models/classifier.onnx",
dynamic_batch=True,
)
onnx_path = export_to_onnx(model, config)
# ONNX Runtime 推理测试
ort_engine = ONNXRuntimeInference(str(onnx_path))
test_input = np.random.randint(
0, 30000, (4, 128), dtype=np.int64
)
test_mask = np.ones((4, 128), dtype=np.int64)
output = ort_engine.predict(test_input, test_mask)
print(f"推理输出形状: {output.shape}")
# 性能对比
compare_performance(model, config)
踩坑记录:torch.onnx.export 对动态形状的支持有限。如果模型中有条件分支(如 if 语句依赖输入值),导出时会失败或产生不正确的结果。解决方案是使用 torch.jit.trace(只记录执行路径)而非 torch.jit.script,或者重构模型避免条件分支。
INT8 量化的精度损失是一个需要仔细评估的问题。在文本分类任务中,INT8 量化的精度损失通常小于 1%,但在检测和分割任务中可能达到 3%-5%。建议在量化前后都跑一遍评估集,确保精度在可接受范围内。
TensorRT 的引擎是硬件绑定的——在 A100 上编译的引擎不能在 V100 上使用。这意味着每个目标硬件都需要单独编译。在 CI/CD 中,需要在目标硬件上执行编译步骤。
四、模型编译优化的代价与适用边界
编译耗时。 TensorRT 的引擎构建过程包括内核自动调优,通常需要几分钟到几十分钟。这意味着模型更新后不能立即部署,需要预留编译时间。
硬件绑定。 TensorRT 引擎与 GPU 架构强绑定,不同型号的 GPU 需要分别编译。这增加了部署和运维的复杂度。
调试困难。 编译后的模型是黑盒,无法像 PyTorch 那样逐步调试。精度问题需要通过对比工具(如 PolyGraphy)逐层排查。
适用场景:
- 高吞吐低延迟的在线推理服务
- NVIDIA GPU 部署场景
- 模型固定、不频繁更新的生产环境
- 对推理成本敏感的大规模部署
不适用场景:
- 模型频繁迭代的研发阶段
- 非 NVIDIA 硬件(用 ONNX Runtime 或 OpenVINO)
- 需要灵活调试和可视化的开发环境
- 小规模部署(编译优化的 ROI 不高)
五、总结
AI 模型编译优化的核心链路是 PyTorch → ONNX → TensorRT/ONNX Runtime,关键优化技术包括算子融合、量化、常量折叠和内存规划。ONNX 作为中间表示实现了框架解耦,TensorRT 提供 NVIDIA GPU 上的极致性能,ONNX Runtime 覆盖通用 CPU 场景。编译优化的代价是编译耗时、硬件绑定和调试困难,适用于模型固定、高吞吐低延迟的生产部署场景,不适用于频繁迭代和非 NVIDIA 硬件环境。INT8 量化需要校准和精度验证,TensorRT 引擎需要针对目标硬件单独编译。
794

被折叠的 条评论
为什么被折叠?



