WASM AI 模型打包:从 PyTorch 到浏览器端推理,WebAssembly 的 AI 落地之路

一、浏览器端推理的工程动机:为什么要把模型跑在前端
传统 AI 推理架构是"前端发请求、后端跑模型",但这有三个痛点:延迟(网络往返 100-500ms)、成本(GPU 服务器按小时计费)、隐私(用户数据必须上传到服务端)。浏览器端推理可以同时解决这三个问题——模型在用户设备上运行,零网络延迟、零服务端成本、数据不出浏览器。
WebAssembly(WASM)让浏览器端推理成为可能:它以接近原生的速度执行编译后的二进制代码,支持 SIMD 和多线程,可以运行 C/Rust 编译的推理引擎。但 WASM 也有硬限制——无法直接访问 GPU(WebGPU 尚在普及)、内存上限通常 2-4GB、无法使用系统级库。这些限制决定了哪些 AI 模型适合 WASM 部署。
二、WASM AI 推理的技术架构:从模型到浏览器
flowchart TB
A[PyTorch/ONNX 模型] --> B[模型优化]
B --> B1[量化 INT8/FP16]
B --> B2[图优化 算子融合]
B --> B3[模型裁剪 剪枝]
B --> C[WASM 编译]
C --> C1[ONNX Runtime Web]
C --> C2[llama.cpp WASM]
C --> C3[自定义 Rust 推理引擎]
C1 --> D[浏览器加载]
C2 --> D
C3 --> D
D --> E{推理执行}
E -->|WASM SIMD| F[CPU 推理]
E -->|WebGPU| G[GPU 推理 实验性]
F --> H[推理结果]
G --> H
style B fill:#ff6b6b,color:#fff
style C fill:#ffd93d,color:#333
style F fill:#6bcb77,color:#fff
技术选型的关键约束:
- 模型大小:WASM 模块 + 模型权重需要通过 HTTP 下载到浏览器。超过 50MB 的模型加载时间过长(3G 网络 30 秒+),用户体验差。适合 WASM 的是轻量模型(<30MB),如 MobileBERT、SqueezeNet、小型 LLM 的量化版本。
- 内存限制:浏览器对 WASM 的内存上限通常为 2-4GB。模型权重 + 运行时内存 + KV Cache 必须在此范围内。INT4 量化的 1.5B 模型约需 1GB 内存,是 WASM 推理的上限。
- 计算性能:WASM SIMD 提供 128-bit 向量指令(相当于 SSE2),性能约为原生 CPU 的 70%-80%。WebGPU 可以利用 GPU 加速,但浏览器兼容性有限(Chrome 113+ 支持)。
三、WASM AI 推理实现
<!-- ONNX Runtime Web — 浏览器端推理 -->
<!DOCTYPE html>
<html>
<head>
<title>WASM AI 推理示例</title>
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
</head>
<body>
<h2>浏览器端文本分类</h2>
<textarea id="input" rows="4" cols="50">这条评论是正面的还是负面的?</textarea>
<br>
<button onclick="runInference()">推理</button>
<p id="result">等待推理...</p>
<script>
// 配置 ONNX Runtime Web 使用 WASM 后端
ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/';
ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;
ort.env.wasm.simd = true; // 启用 SIMD 加速
let session = null;
// 异步加载模型(首次加载较慢,建议缓存到 IndexedDB)
async function loadModel() {
const modelUrl = './models/text-classifier-quantized.onnx';
const options = {
executionProviders: ['wasm'], // 使用 WASM 后端
graphOptimizationLevel: 'all',
};
try {
session = await ort.InferenceSession.create(modelUrl, options);
console.log('模型加载成功');
} catch (e) {
console.error('模型加载失败:', e);
}
}
// 执行推理
async function runInference() {
if (!session) {
document.getElementById('result').textContent = '模型加载中...';
await loadModel();
}
const inputText = document.getElementById('input').value;
// 文本预处理(简化:实际需要 Tokenizer)
const inputIds = tokenize(inputText);
const attentionMask = new Array(inputIds.length).fill(1);
// 创建输入 Tensor
const inputIdsTensor = new ort.Tensor(
'int64', BigInt64Array.from(inputIds.map(BigInt)),
[1, inputIds.length]
);
const attentionMaskTensor = new ort.Tensor(
'int64', BigInt64Array.from(attentionMask.map(BigInt)),
[1, attentionMask.length]
);
const feeds = {
input_ids: inputIdsTensor,
attention_mask: attentionMaskTensor,
};
try {
const start = performance.now();
const results = await session.run(feeds);
const elapsed = performance.now() - start;
// 解析输出
const logits = results.logits.data;
const positiveProb = softmax(logits)[1];
document.getElementById('result').textContent =
`正面概率: ${(positiveProb * 100).toFixed(1)}% ` +
`(推理耗时: ${elapsed.toFixed(0)}ms)`;
} catch (e) {
document.getElementById('result').textContent = `推理失败: ${e}`;
}
}
// 简化 Tokenizer(生产环境需用真实 Tokenizer)
function tokenize(text) {
// 实际应使用与训练时一致的 Tokenizer
return Array.from(text).map(c => c.charCodeAt(0));
}
function softmax(logits) {
const max = Math.max(...logits);
const exps = logits.map(x => Math.exp(x - max));
const sum = exps.reduce((a, b) => a + b, 0);
return exps.map(x => x / sum);
}
// 页面加载时预加载模型
loadModel();
</script>
</body>
</html>
// Rust → WASM 自定义推理引擎
// Cargo.toml 需要配置 crate-type = ["cdylib"]
use wasm_bindgen::prelude::*;
use web_sys::console;
/// 简单的矩阵乘法推理引擎(演示 WASM SIMD 加速)
#[wasm_bindgen]
pub struct WasmInferenceEngine {
weights: Vec<f32>,
bias: Vec<f32>,
input_dim: usize,
output_dim: usize,
}
#[wasm_bindgen]
impl WasmInferenceEngine {
/// 从字节数组加载模型权重
#[wasm_bindgen(constructor)]
pub fn new(weights: &[u8], bias: &[u8],
input_dim: usize, output_dim: usize) -> Self {
// 将字节数组解析为 f32 数组
let weights = bytes_to_f32(weights);
let bias = bytes_to_f32(bias);
WasmInferenceEngine {
weights, bias, input_dim, output_dim,
}
}
/// 执行单层前向推理:output = input × weights + bias
pub fn forward(&self, input: &[f32]) -> Vec<f32> {
let mut output = vec![0.0f32; self.output_dim];
for j in 0..self.output_dim {
let mut sum = self.bias[j];
for i in 0..self.input_dim {
sum += input[i] *
self.weights[i * self.output_dim + j];
}
output[j] = sum;
}
output
}
/// 带激活函数的前向推理:ReLU(input × weights + bias)
pub fn forward_relu(&self, input: &[f32]) -> Vec<f32> {
let mut output = self.forward(input);
for val in output.iter_mut() {
*val = val.max(0.0); // ReLU
}
output
}
}
fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
bytes.chunks_exact(4)
.map(|chunk| {
f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
})
.collect()
}
# Rust → WASM 编译命令
# 需要安装 wasm-pack
wasm-pack build --target web --release
# 编译产物:
# - pkg/inference_engine.js (JS 绑定)
# - pkg/inference_engine_bg.wasm (WASM 二进制)
# - pkg/inference_engine.d.ts (TypeScript 类型)
四、WASM AI 推理的边界与限制
模型加载时间:30MB 的模型在 4G 网络下需要约 6 秒下载,加上 WASM 编译和初始化,首次加载可能需要 10 秒以上。解决方案是用 IndexedDB 缓存模型和 WASM 模块,二次访问时从本地加载。也可以按需加载——只加载用户当前需要的模型。
内存压力:浏览器标签页的内存上限通常为 2-4GB,WASM 线性内存 + JS 堆 + 模型权重共享这个上限。如果用户同时打开多个标签页,内存可能不足。解决方案是使用 SharedArrayBuffer 在标签页间共享 WASM 内存(需要 COOP/COEP 头),或在模型不使用时释放内存。
SIMD 的兼容性:WASM SIMD 在 Chrome 91+、Firefox 89+、Safari 16.4+ 支持,但部分旧浏览器不支持。需要提供非 SIMD 的回退版本,或在运行时检测 SIMD 支持并选择对应的 WASM 二进制。
WebGPU 的不确定性:WebGPU 可以在浏览器中利用 GPU 加速推理,但目前仅在 Chrome 113+ 稳定支持,Safari 和 Firefox 支持不完整。生产环境中 WebGPU 只能作为可选加速,WASM SIMD 仍然是基线方案。
五、总结
WASM AI 模型打包的核心原则:轻量模型优先、量化压缩体积、SIMD 加速基线、缓存优化加载。落地路径:
- 模型选择:选择 <30MB 的轻量模型(MobileBERT、SqueezeNet、量化后的小型 LLM),超过 50MB 的模型不适合浏览器端部署。
- 模型优化:INT8 量化减少模型体积 4 倍,算子融合减少计算量,剪枝移除冗余参数。
- 推理引擎:用 ONNX Runtime Web 快速集成,或用 Rust 编写自定义引擎编译为 WASM。
- 加载优化:IndexedDB 缓存模型和 WASM 模块,按需加载,首次加载时显示进度条。
WASM AI 推理不是要替代服务端推理,而是为特定场景(隐私敏感、低延迟、离线使用)提供补充方案。理解 WASM 的能力边界,在合适的场景中使用,才能发挥其最大价值。
686

被折叠的 条评论
为什么被折叠?



