Web AI 与端侧推理
问题
前端如何在浏览器中直接运行 AI 模型?WebGPU、WebNN、TensorFlow.js、ONNX Runtime Web、Transformers.js、WebLLM、MediaPipe 等技术的原理、适用场景和最佳实践是什么?如何优化模型加载、推理性能和用户体验?
答案
Web AI(端侧推理)是在浏览器中直接运行 AI 模型,不需要将数据发送到服务器。这带来了隐私保护、离线可用、低延迟等核心优势。随着 WebGPU、WebNN 等底层加速 API 的成熟,以及 Transformers.js、ONNX Runtime Web 等高层框架的发展,浏览器的 AI 能力正在快速提升。
一、Web AI 技术栈全景
| 技术 | 说明 | 加速后端 | 状态 |
|---|---|---|---|
| WebGPU | 现代 GPU 加速 API,原生支持 Compute Shader | GPU | Chrome 113+、Edge 113+、Firefox Nightly |
| WebNN | 直接访问 NPU/AI 加速器的标准 API | NPU/GPU/CPU | Chrome 开发中(Origin Trial) |
| WebAssembly | 接近原生的 CPU 执行性能 | CPU | 所有现代浏览器 |
| TensorFlow.js | Google 的 Web ML 框架,支持训练和推理 | WebGPU/WebGL/WASM | 成熟,生态丰富 |
| ONNX Runtime Web | 微软跨平台推理引擎,支持多种模型格式 | WebGPU/WebNN/WASM | 生产就绪 |
| Transformers.js | 直接运行 Hugging Face 上的模型 | WebGPU/WASM | v3 活跃发展中 |
| WebLLM | 浏览器运行 LLM(Llama、Mistral、Phi) | WebGPU | 实验性,快速迭代 |
| MediaPipe | Google 实时视觉/手势/人脸检测方案 | WebGPU/WASM/GPU delegate | 生产就绪 |
二、WebGPU 深入 — Compute Shader 与 AI 加速
WebGPU 是 WebGL 的继任者,基于 Vulkan/Metal/D3D12 设计。与 WebGL 最大的区别在于 WebGPU 原生支持 Compute Shader(计算着色器),可以直接进行通用计算(GPGPU),而不需要把计算映射为图形渲染操作。
Compute Shader 基础 — 矩阵乘法示例
矩阵乘法是神经网络中最核心的运算。以下展示如何用 WebGPU Compute Shader 实现矩阵乘法:
// WebGPU 矩阵乘法:C = A × B
async function gpuMatMul(
a: Float32Array,
b: Float32Array,
M: number, // A 的行数
K: number, // A 的列数 = B 的行数
N: number // B 的列数
): Promise<Float32Array> {
// 1. 获取 GPU 设备
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) throw new Error('WebGPU 不可用');
const device = await adapter.requestDevice();
// 2. 创建 GPU Buffer
const bufferA = device.createBuffer({
size: a.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
const bufferB = device.createBuffer({
size: b.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
const bufferC = device.createBuffer({
size: M * N * 4,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});
const bufferDims = device.createBuffer({
size: 3 * 4,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});
// 写入数据
device.queue.writeBuffer(bufferA, 0, a);
device.queue.writeBuffer(bufferB, 0, b);
device.queue.writeBuffer(bufferDims, 0, new Uint32Array([M, K, N]));
// 3. 编写 WGSL Compute Shader
const shaderModule = device.createShaderModule({
code: `
struct Dims {
M: u32,
K: u32,
N: u32,
}
@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@group(0) @binding(3) var<uniform> dims: Dims;
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let row = id.x;
let col = id.y;
if (row >= dims.M || col >= dims.N) { return; }
var sum: f32 = 0.0;
for (var k: u32 = 0u; k < dims.K; k = k + 1u) {
sum = sum + A[row * dims.K + k] * B[k * dims.N + col];
}
C[row * dims.N + col] = sum;
}
`,
});
// 4. 创建计算管线
const pipeline = device.createComputePipeline({
layout: 'auto',
compute: { module: shaderModule, entryPoint: 'main' },
});
const bindGroup = device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: bufferA } },
{ binding: 1, resource: { buffer: bufferB } },
{ binding: 2, resource: { buffer: bufferC } },
{ binding: 3, resource: { buffer: bufferDims } },
],
});
// 5. 提交计算命令
const commandEncoder = device.createCommandEncoder();
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(pipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(Math.ceil(M / 16), Math.ceil(N / 16));
passEncoder.end();
// 读回结果
const readBuffer = device.createBuffer({
size: M * N * 4,
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
});
commandEncoder.copyBufferToBuffer(bufferC, 0, readBuffer, 0, M * N * 4);
device.queue.submit([commandEncoder.finish()]);
await readBuffer.mapAsync(GPUMapMode.READ);
const result = new Float32Array(readBuffer.getMappedRange().slice(0));
readBuffer.unmap();
return result;
}
WebGPU vs WebGL vs WASM 性能基准
| 基准测试(矩阵 1024x1024) | WebGPU | WebGL | WASM (SIMD) | CPU (JS) |
|---|---|---|---|---|
| 矩阵乘法 | ~8ms | ~30ms | ~120ms | ~800ms |
| ResNet-50 推理 | ~15ms | ~60ms | ~200ms | ~1500ms |
| BERT-base 推理 | ~25ms | ~100ms | ~350ms | N/A |
| 加速比 (vs CPU) | ~50-100x | ~10-25x | ~5-8x | 1x |
以上数据来自 Chrome 团队和 ONNX Runtime 的公开基准测试,实际性能因硬件和模型差异会有变化。WebGPU 在大矩阵运算和 Transformer 类模型上优势尤其显著。
WebGPU 浏览器支持情况
| 浏览器 | 支持状态 | 备注 |
|---|---|---|
| Chrome / Edge | 113+ 正式支持 | 桌面端完整支持,Android Chrome 121+ |
| Firefox | Nightly 实验性 | 需手动启用 dom.webgpu.enabled |
| Safari | 预览版实验性 | WebKit 正在实现中 |
| iOS Safari | 暂不支持 | 受限于 WebKit 进度 |
WebGPU 目前在移动端支持仍然有限。生产环境需做好降级方案,检测 WebGPU 不可用时回退到 WASM 后端。
// 运行时检测 WebGPU 支持
async function checkWebGPUSupport(): Promise<{
supported: boolean;
adapterInfo?: GPUAdapterInfo;
limits?: GPUSupportedLimits;
}> {
if (!('gpu' in navigator)) {
return { supported: false };
}
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
return { supported: false };
}
const info = await adapter.requestAdapterInfo();
return {
supported: true,
adapterInfo: info,
limits: adapter.limits,
};
}
// 根据支持情况选择后端
async function selectBackend(): Promise<'webgpu' | 'wasm' | 'webgl'> {
const gpuSupport = await checkWebGPUSupport();
if (gpuSupport.supported) return 'webgpu';
// 检查 WebGL 2
const canvas = document.createElement('canvas');
const gl = canvas.getContext('webgl2');
if (gl) return 'webgl';
return 'wasm';
}
三、WebNN — 原生 AI 加速 API
WebNN(Web Neural Network API)是 W3C 正在制定的标准,允许 Web 应用直接访问设备上的 NPU(神经处理单元)、GPU 和 CPU,实现高效的神经网络推理。
WebNN 架构
WebNN 核心概念
// WebNN API 基础用法
async function webnnInference() {
// 1. 获取 ML Context(指定设备偏好)
const context = await navigator.ml.createContext({
deviceType: 'npu', // 'cpu' | 'gpu' | 'npu'
powerPreference: 'low-power', // 'default' | 'high-performance' | 'low-power'
});
// 2. 创建 GraphBuilder
const builder = new MLGraphBuilder(context);
// 3. 定义计算图(以简单的全连接层为例)
const inputDesc: MLOperandDescriptor = {
dataType: 'float32',
shape: [1, 784], // 28x28 图片展平
};
const input = builder.input('input', inputDesc);
// 权重和偏置(实际场景从模型文件加载)
const weightsData = new Float32Array(784 * 128); // 假设已填充
const weights = builder.constant(
{ dataType: 'float32', shape: [784, 128] },
weightsData
);
const biasData = new Float32Array(128);
const bias = builder.constant(
{ dataType: 'float32', shape: [128] },
biasData
);
// 4. 构建操作
const matmul = builder.matmul(input, weights);
const add = builder.add(matmul, bias);
const output = builder.relu(add);
// 5. 编译计算图
const graph = await builder.build({ output });
// 6. 执行推理
const inputBuffer = new Float32Array(784);
const outputBuffer = new Float32Array(128);
const results = await context.compute(graph, {
input: inputBuffer,
}, {
output: outputBuffer,
});
return results.output;
}
WebNN 支持的操作
| 操作类别 | 包含操作 |
|---|---|
| 元素级 | add、sub、mul、div、pow、abs、ceil、floor、exp、log、sigmoid、relu、tanh、softmax |
| 矩阵 | matmul、gemm |
| 卷积 | conv2d、convTranspose2d |
| 池化 | averagePool2d、maxPool2d、l2Pool2d |
| 归一化 | batchNormalization、layerNormalization、instanceNormalization |
| 变形 | reshape、transpose、concat、split、slice、pad、expand |
| 注意力 | 暂无原生 attention 操作,通过 matmul + softmax 组合 |
Chrome 正在通过 Origin Trial 方式逐步开放 WebNN API。在 Windows 上通过 DirectML 后端支持 NPU/GPU 加速,macOS 上通过 Core ML 后端支持 Apple Neural Engine。开发者可在 chrome://flags/#enable-web-machine-learning-neural-network-api 手动启用。
四、ONNX Runtime Web 深入
ONNX Runtime Web 是微软开发的跨平台推理引擎,支持 ONNX 格式模型。它是目前浏览器端最成熟的通用推理方案之一。
模型转换:从 PyTorch/TensorFlow 到 ONNX
// Node.js 脚本:PyTorch 模型导出为 ONNX(通常在 Python 中完成)
// 以下是 Python 伪代码的 TypeScript 注释说明
/*
# PyTorch → ONNX
import torch
import torch.onnx
model = MyModel()
model.load_state_dict(torch.load('model.pth'))
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
'model.onnx',
opset_version=17, # 推荐 opset 17+
input_names=['input'],
output_names=['output'],
dynamic_axes={ # 支持动态 batch
'input': {0: 'batch'},
'output': {0: 'batch'},
},
)
# TensorFlow → ONNX(使用 tf2onnx)
# python -m tf2onnx.convert --saved-model ./saved_model --output model.onnx --opset 17
*/
// 在浏览器中使用 ONNX Runtime Web
import * as ort from 'onnxruntime-web';
// 配置执行提供者
ort.env.wasm.wasmPaths = '/wasm/';
async function createSession(modelPath: string): Promise<ort.InferenceSession> {
const options: ort.InferenceSession.SessionOptions = {
executionProviders: [
'webgpu', // 首选 WebGPU
'wasm', // 降级到 WASM
],
graphOptimizationLevel: 'all',
enableCpuMemArena: true,
};
return ort.InferenceSession.create(modelPath, options);
}
// 完整的推理流程
async function runInference(
session: ort.InferenceSession,
inputData: Float32Array,
inputShape: number[]
): Promise<Float32Array> {
// 创建输入 Tensor
const inputTensor = new ort.Tensor('float32', inputData, inputShape);
// 获取输入输出名称
const inputName = session.inputNames[0];
const outputName = session.outputNames[0];
// 执行推理
const results = await session.run({ [inputName]: inputTensor });
return results[outputName].data as Float32Array;
}
ONNX 模型量化
量化是将模型权重从 FP32 转换为更低精度(INT8/INT4/FP16)的过程,可以显著减小模型体积和提升推理速度:
/*
模型量化通常在 Python 中完成:
# INT8 动态量化(推荐,最简单)
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
model_input='model.onnx',
model_output='model_int8.onnx',
weight_type=QuantType.QInt8,
)
# INT4 量化(更小模型,适合 LLM 权重)
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
model_input='model.onnx',
model_output='model_int4.onnx',
weight_type=QuantType.QInt4,
)
# FP16 量化(保留更多精度)
from onnxruntime.transformers import float16
import onnx
model = onnx.load('model.onnx')
model_fp16 = float16.convert_float_to_float16(model)
onnx.save(model_fp16, 'model_fp16.onnx')
*/
// 量化效果对比
interface QuantizationComparison {
type: string;
sizeReduction: string;
speedup: string;
accuracyLoss: string;
useCase: string;
}
const QUANTIZATION_COMPARISON: QuantizationComparison[] = [
{
type: 'FP32(原始)',
sizeReduction: '基准',
speedup: '基准',
accuracyLoss: '无',
useCase: '训练、高精度需求',
},
{
type: 'FP16',
sizeReduction: '~50%',
speedup: '~1.5-2x',
accuracyLoss: '极小(<0.1%)',
useCase: 'GPU 推理、通用场景',
},
{
type: 'INT8',
sizeReduction: '~75%',
speedup: '~2-4x',
accuracyLoss: '小(<1%)',
useCase: '浏览器端推理(推荐)',
},
{
type: 'INT4',
sizeReduction: '~87.5%',
speedup: '~3-6x',
accuracyLoss: '中(1-3%)',
useCase: 'LLM 权重量化、极小模型',
},
];
| 量化类型 | 体积缩减 | 速度提升 | 精度损失 | 适用场景 |
|---|---|---|---|---|
| FP32(原始) | 基准 | 基准 | 无 | 训练、高精度需求 |
| FP16 | ~50% | ~1.5-2x | 极小(<0.1%) | GPU 推理、通用场景 |
| INT8 | ~75% | ~2-4x | 小(<1%) | 浏览器端推理(推荐) |
| INT4 | ~87.5% | ~3-6x | 中(1-3%) | LLM 权重量化、极小模型 |
ONNX Runtime Web 执行提供者(Execution Providers)
import * as ort from 'onnxruntime-web';
// 不同执行提供者的特点和选择策略
interface ExecutionProvider {
name: string;
backend: string;
pros: string[];
cons: string[];
}
const PROVIDERS: ExecutionProvider[] = [
{
name: 'webgpu',
backend: 'GPU (WebGPU API)',
pros: ['最快的 GPU 加速', '支持 Compute Shader', '内存带宽大'],
cons: ['需要 Chrome 113+', '移动端支持有限'],
},
{
name: 'webnn',
backend: 'NPU/GPU/CPU (WebNN API)',
pros: ['可利用 NPU 加速', '低功耗', '未来标准方向'],
cons: ['浏览器支持极早期', '操作覆盖不完整'],
},
{
name: 'wasm',
backend: 'CPU (WebAssembly)',
pros: ['兼容性最好', '所有浏览器支持', 'SIMD 加速'],
cons: ['性能不如 GPU', '大模型推理较慢'],
},
{
name: 'webgl',
backend: 'GPU (WebGL 2)',
pros: ['广泛的 GPU 支持', '兼容旧浏览器'],
cons: ['计算需映射为纹理操作', '比 WebGPU 慢 3-5x'],
},
];
// 智能选择执行提供者
async function getOptimalProviders(): Promise<string[]> {
const providers: string[] = [];
// 优先 WebGPU
if ('gpu' in navigator) {
const adapter = await navigator.gpu.requestAdapter();
if (adapter) providers.push('webgpu');
}
// 其次 WebNN
if ('ml' in navigator) {
providers.push('webnn');
}
// 兜底 WASM(始终可用)
providers.push('wasm');
return providers;
}
五、Transformers.js 深入
Transformers.js 是 Hugging Face 官方的 JavaScript 版本,让你可以直接在浏览器中运行 Hugging Face 上的数千个模型。v3 版本支持 WebGPU 加速。
Pipeline API 全任务支持
import { pipeline, env, Pipeline } from '@huggingface/transformers';
// 配置(v3 使用 @huggingface/transformers 包名)
env.allowLocalModels = false;
// ===== 自然语言处理任务 =====
// 1. 文本分类(情感分析)
async function textClassification(text: string) {
const classifier = await pipeline(
'text-classification',
'Xenova/distilbert-base-uncased-finetuned-sst-2-english'
);
return classifier(text);
// [{ label: 'POSITIVE', score: 0.9998 }]
}
// 2. 命名实体识别(NER)
async function tokenClassification(text: string) {
const ner = await pipeline(
'token-classification',
'Xenova/bert-base-NER'
);
return ner(text);
// [{ entity: 'B-PER', score: 0.99, word: 'John', ... }]
}
// 3. 问答
async function questionAnswering(question: string, context: string) {
const qa = await pipeline(
'question-answering',
'Xenova/distilbert-base-cased-distilled-squad'
);
return qa({ question, context });
// { answer: 'Paris', score: 0.98, start: 12, end: 17 }
}
// 4. 文本摘要
async function summarization(text: string) {
const summarizer = await pipeline(
'summarization',
'Xenova/distilbart-cnn-6-6'
);
return summarizer(text, { max_length: 100, min_length: 30 });
// [{ summary_text: '摘要内容...' }]
}
// 5. 翻译
async function translation(text: string) {
const translator = await pipeline(
'translation',
'Xenova/opus-mt-zh-en' // 中文到英文
);
return translator(text);
// [{ translation_text: 'translated text...' }]
}
// 6. 文本生成
async function textGeneration(prompt: string) {
const generator = await pipeline(
'text2text-generation',
'Xenova/flan-t5-small'
);
return generator(prompt, { max_new_tokens: 100 });
}
// 7. 零样本分类(不需要训练的分类)
async function zeroShotClassification(text: string, labels: string[]) {
const classifier = await pipeline(
'zero-shot-classification',
'Xenova/nli-deberta-v3-xsmall'
);
return classifier(text, labels);
// { labels: ['技术', '体育', '娱乐'], scores: [0.92, 0.05, 0.03] }
}
// 8. 特征提取(文本嵌入向量)
async function featureExtraction(text: string) {
const embedder = await pipeline(
'feature-extraction',
'Xenova/all-MiniLM-L6-v2'
);
const output = await embedder(text, { pooling: 'mean', normalize: true });
return Array.from(output.data); // 384 维向量
}
// ===== 计算机视觉任务 =====
// 9. 图片分类
async function imageClassification(imageUrl: string) {
const classifier = await pipeline(
'image-classification',
'Xenova/vit-base-patch16-224'
);
return classifier(imageUrl);
// [{ label: 'golden retriever', score: 0.89 }]
}
// 10. 目标检测
async function objectDetection(imageUrl: string) {
const detector = await pipeline(
'object-detection',
'Xenova/detr-resnet-50'
);
return detector(imageUrl);
// [{ label: 'cat', score: 0.98, box: { xmin, ymin, xmax, ymax } }]
}
// 11. 图像分割
async function imageSegmentation(imageUrl: string) {
const segmenter = await pipeline(
'image-segmentation',
'Xenova/detr-resnet-50-panoptic'
);
return segmenter(imageUrl);
// [{ label: 'cat', score: 0.99, mask: RawImage }]
}
模型缓存机制
Transformers.js 默认使用浏览器的 Cache API 缓存已下载的模型文件。理解缓存机制对优化加载体验至关重要:
import { pipeline, env } from '@huggingface/transformers';
// 缓存配置
env.useBrowserCache = true; // 默认开启 Cache API 缓存
env.allowLocalModels = false; // 是否从本地加载
// 自定义缓存路径
env.cacheDir = '/models'; // 自定义缓存目录名
// 手动管理模型缓存
class ModelCacheManager {
private cacheName = 'transformers-cache';
// 查看已缓存的模型
async listCachedModels(): Promise<string[]> {
const cache = await caches.open(this.cacheName);
const keys = await cache.keys();
return keys.map(req => req.url);
}
// 获取缓存大小
async getCacheSize(): Promise<number> {
const cache = await caches.open(this.cacheName);
const keys = await cache.keys();
let totalSize = 0;
for (const key of keys) {
const response = await cache.match(key);
if (response) {
const blob = await response.blob();
totalSize += blob.size;
}
}
return totalSize;
}
// 清理指定模型缓存
async clearModelCache(modelName: string): Promise<void> {
const cache = await caches.open(this.cacheName);
const keys = await cache.keys();
for (const key of keys) {
if (key.url.includes(modelName)) {
await cache.delete(key);
}
}
}
// 预加载模型(提前下载不立即使用)
async preloadModel(
task: string,
modelId: string,
onProgress?: (progress: number) => void
): Promise<void> {
await pipeline(task as any, modelId, {
progress_callback: (data: { status: string; progress?: number }) => {
if (data.status === 'progress' && data.progress !== undefined) {
onProgress?.(data.progress);
}
},
});
}
}
六、WebLLM 深入 — 浏览器中运行 LLM
WebLLM 基于 Apache TVM 的 MLC(Machine Learning Compilation)技术,将 LLM 编译为高效的 WebGPU 代码,在浏览器中运行。
支持的模型及资源需求
| 模型 | 量化 | 模型大小 | VRAM 需求 | 大约速度 |
|---|---|---|---|---|
| Llama-3.1-8B | q4f16_1 | ~4.5 GB | ~6 GB | 25-45 tok/s |
| Llama-3.2-3B | q4f16_1 | ~1.8 GB | ~3 GB | 40-70 tok/s |
| Mistral-7B | q4f16_1 | ~4.0 GB | ~5.5 GB | 25-45 tok/s |
| Phi-3.5-mini-3.8B | q4f16_1 | ~2.2 GB | ~3.5 GB | 35-60 tok/s |
| Qwen2.5-1.5B | q4f16_1 | ~1.0 GB | ~2 GB | 50-80 tok/s |
| SmolLM2-1.7B | q4f16_1 | ~1.0 GB | ~2 GB | 50-80 tok/s |
| Gemma-2-2B | q4f16_1 | ~1.5 GB | ~2.5 GB | 40-65 tok/s |
以上速度为在搭载 NVIDIA RTX 3060 (12GB) 或 Apple M1 Pro 等中高端 GPU 上的大致范围。实际速度取决于 GPU 型号、显存带宽、浏览器版本。低端 GPU 速度可能减半。
完整的 WebLLM 集成
import * as webllm from '@mlc-ai/web-llm';
interface WebLLMConfig {
modelId: string;
onProgress?: (progress: { text: string; progress: number }) => void;
onReady?: () => void;
}
class WebLLMEngine {
private engine: webllm.MLCEngine | null = null;
private isReady = false;
async init(config: WebLLMConfig): Promise<void> {
this.engine = await webllm.CreateMLCEngine(config.modelId, {
initProgressCallback: (progress) => {
config.onProgress?.({
text: progress.text,
progress: progress.progress,
});
},
});
this.isReady = true;
config.onReady?.();
}
// OpenAI 兼容的聊天接口
async chat(
messages: Array<{ role: string; content: string }>,
options?: { temperature?: number; max_tokens?: number }
): Promise<string> {
if (!this.engine) throw new Error('引擎未初始化');
const response = await this.engine.chat.completions.create({
messages: messages as webllm.ChatCompletionMessageParam[],
temperature: options?.temperature ?? 0.7,
max_tokens: options?.max_tokens ?? 1024,
});
return response.choices[0]?.message?.content ?? '';
}
// 流式聊天
async *chatStream(
messages: Array<{ role: string; content: string }>,
options?: { temperature?: number; max_tokens?: number }
): AsyncGenerator<string> {
if (!this.engine) throw new Error('引擎未初始化');
const response = await this.engine.chat.completions.create({
messages: messages as webllm.ChatCompletionMessageParam[],
temperature: options?.temperature ?? 0.7,
max_tokens: options?.max_tokens ?? 1024,
stream: true,
});
for await (const chunk of response) {
const content = chunk.choices[0]?.delta?.content;
if (content) yield content;
}
}
// 获取性能统计
async getStats(): Promise<{ tokensPerSecond: number; totalTokens: number }> {
if (!this.engine) throw new Error('引擎未初始化');
const stats = await this.engine.runtimeStatsText();
// 解析统计文本
return { tokensPerSecond: 0, totalTokens: 0 }; // 简化示例
}
dispose(): void {
this.engine = null;
this.isReady = false;
}
}
import { useState, useRef, useCallback } from 'react';
interface Message {
role: 'user' | 'assistant';
content: string;
}
export function LocalLLMChat() {
const [messages, setMessages] = useState<Message[]>([]);
const [input, setInput] = useState('');
const [loading, setLoading] = useState(false);
const [progress, setProgress] = useState<{ text: string; percent: number } | null>(null);
const engineRef = useRef<WebLLMEngine | null>(null);
const initEngine = useCallback(async () => {
const engine = new WebLLMEngine();
await engine.init({
modelId: 'Llama-3.2-3B-Instruct-q4f16_1-MLC',
onProgress: (p) => setProgress({ text: p.text, percent: p.progress * 100 }),
onReady: () => setProgress(null),
});
engineRef.current = engine;
}, []);
const sendMessage = useCallback(async () => {
if (!engineRef.current || !input.trim()) return;
const userMsg: Message = { role: 'user', content: input };
const newMessages = [...messages, userMsg];
setMessages(newMessages);
setInput('');
setLoading(true);
// 流式生成回复
let assistantContent = '';
setMessages([...newMessages, { role: 'assistant', content: '' }]);
for await (const chunk of engineRef.current.chatStream(newMessages)) {
assistantContent += chunk;
setMessages([
...newMessages,
{ role: 'assistant', content: assistantContent },
]);
}
setLoading(false);
}, [input, messages]);
return (
<div className="flex flex-col h-screen">
{/* 加载进度 */}
{progress && (
<div className="p-4 bg-blue-50">
<p className="text-sm">{progress.text}</p>
<div className="w-full bg-gray-200 rounded h-2 mt-1">
<div
className="bg-blue-500 h-2 rounded transition-all"
style={{ width: `${progress.percent}%` }}
/>
</div>
</div>
)}
{/* 消息列表 */}
<div className="flex-1 overflow-auto p-4">
{messages.map((msg, i) => (
<div key={i} className={`mb-4 ${msg.role === 'user' ? 'text-right' : ''}`}>
<div className={`inline-block p-3 rounded-lg ${
msg.role === 'user' ? 'bg-blue-500 text-white' : 'bg-gray-100'
}`}>
{msg.content}
</div>
</div>
))}
</div>
{/* 输入区 */}
<div className="p-4 border-t flex gap-2">
{!engineRef.current ? (
<button onClick={initEngine} className="btn-primary">
加载模型(~1.8GB)
</button>
) : (
<>
<input
value={input}
onChange={(e) => setInput(e.target.value)}
onKeyDown={(e) => e.key === 'Enter' && sendMessage()}
className="flex-1 border rounded px-3"
placeholder="输入消息..."
/>
<button onClick={sendMessage} disabled={loading}>
{loading ? '生成中...' : '发送'}
</button>
</>
)}
</div>
</div>
);
}
WebLLM 需要 WebGPU 支持且模型下载较大(1-4GB)。首次加载需要较长时间用于下载和编译模型。适合对隐私要求极高或需离线使用的场景。建议在用户明确触发后才开始加载模型,并配合进度条提供反馈。
七、MediaPipe 实时视觉处理
MediaPipe 是 Google 开发的跨平台机器学习框架,专注于实时视觉处理任务,包括手势识别、人脸网格、姿态检测、目标检测等。其模型轻量级,推理速度极快,适合在浏览器中运行。
MediaPipe 支持的视觉任务
| 任务 | 模型 | 输出 | 典型帧率 |
|---|---|---|---|
| 手势识别 | Hand Landmarker | 21 个手部关键点 + 手势分类 | 30+ FPS |
| 人脸网格 | Face Landmarker | 478 个面部关键点 + 表情分类 | 30+ FPS |
| 姿态检测 | Pose Landmarker | 33 个身体关键点 | 30+ FPS |
| 目标检测 | Object Detector | 边界框 + 类别 + 置信度 | 25+ FPS |
| 图像分类 | Image Classifier | 类别 + 置信度 | 30+ FPS |
| 图像分割 | Image Segmenter | 像素级分割掩码 | 20+ FPS |
| 文字识别 | Text Recognizer | OCR 文字 | 20+ FPS |
手势识别 React 组件
import { useEffect, useRef, useState, useCallback } from 'react';
import {
GestureRecognizer,
FilesetResolver,
GestureRecognizerResult,
} from '@mediapipe/tasks-vision';
interface DetectedGesture {
gesture: string;
confidence: number;
landmarks: Array<{ x: number; y: number; z: number }>;
}
export function HandGestureDetector() {
const videoRef = useRef<HTMLVideoElement>(null);
const canvasRef = useRef<HTMLCanvasElement>(null);
const [recognizer, setRecognizer] = useState<GestureRecognizer | null>(null);
const [gestures, setGestures] = useState<DetectedGesture[]>([]);
const [fps, setFps] = useState(0);
const frameCountRef = useRef(0);
const lastTimeRef = useRef(performance.now());
// 初始化手势识别器
useEffect(() => {
async function init() {
const vision = await FilesetResolver.forVisionTasks(
'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm'
);
const gestureRecognizer = await GestureRecognizer.createFromOptions(vision, {
baseOptions: {
modelAssetPath:
'https://storage.googleapis.com/mediapipe-models/gesture_recognizer/gesture_recognizer/float16/1/gesture_recognizer.task',
delegate: 'GPU', // 使用 GPU 加速
},
runningMode: 'VIDEO',
numHands: 2,
});
setRecognizer(gestureRecognizer);
}
init();
}, []);
// 开启摄像头
useEffect(() => {
if (!videoRef.current) return;
navigator.mediaDevices
.getUserMedia({ video: { width: 640, height: 480, facingMode: 'user' } })
.then((stream) => {
videoRef.current!.srcObject = stream;
});
}, []);
// 实时检测循环
const detect = useCallback(() => {
if (!recognizer || !videoRef.current || !canvasRef.current) return;
const video = videoRef.current;
const canvas = canvasRef.current;
const ctx = canvas.getContext('2d')!;
const processFrame = () => {
if (video.readyState < 2) {
requestAnimationFrame(processFrame);
return;
}
// 执行手势识别
const result = recognizer.recognizeForVideo(video, performance.now());
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.drawImage(video, 0, 0);
// 绘制手部关键点和连线
drawHandLandmarks(ctx, result);
// 更新识别结果
const detected: DetectedGesture[] = [];
if (result.gestures.length > 0) {
result.gestures.forEach((gestureList, handIndex) => {
const topGesture = gestureList[0];
detected.push({
gesture: topGesture.categoryName,
confidence: topGesture.score,
landmarks: result.landmarks[handIndex],
});
});
}
setGestures(detected);
// 计算 FPS
frameCountRef.current++;
const now = performance.now();
if (now - lastTimeRef.current >= 1000) {
setFps(frameCountRef.current);
frameCountRef.current = 0;
lastTimeRef.current = now;
}
requestAnimationFrame(processFrame);
};
processFrame();
}, [recognizer]);
useEffect(() => {
detect();
}, [detect]);
return (
<div className="relative">
<video ref={videoRef} autoPlay muted playsInline className="hidden" />
<canvas ref={canvasRef} width={640} height={480} />
{/* 识别结果 */}
<div className="absolute top-2 left-2 bg-black/70 text-white p-2 rounded">
<p>FPS: {fps}</p>
{gestures.map((g, i) => (
<p key={i}>
手 {i + 1}: {g.gesture}({(g.confidence * 100).toFixed(1)}%)
</p>
))}
</div>
</div>
);
}
// 绘制手部关键点
function drawHandLandmarks(
ctx: CanvasRenderingContext2D,
result: GestureRecognizerResult
): void {
const { landmarks } = result;
if (!landmarks.length) return;
// 手部连线定义(简化)
const connections = [
[0, 1], [1, 2], [2, 3], [3, 4], // 拇指
[0, 5], [5, 6], [6, 7], [7, 8], // 食指
[0, 9], [9, 10], [10, 11], [11, 12], // 中指
[0, 13], [13, 14], [14, 15], [15, 16], // 无名指
[0, 17], [17, 18], [18, 19], [19, 20], // 小指
[5, 9], [9, 13], [13, 17], // 手掌
];
for (const handLandmarks of landmarks) {
// 画连线
ctx.strokeStyle = '#00FF00';
ctx.lineWidth = 2;
for (const [start, end] of connections) {
const p1 = handLandmarks[start];
const p2 = handLandmarks[end];
ctx.beginPath();
ctx.moveTo(p1.x * ctx.canvas.width, p1.y * ctx.canvas.height);
ctx.lineTo(p2.x * ctx.canvas.width, p2.y * ctx.canvas.height);
ctx.stroke();
}
// 画关键点
ctx.fillStyle = '#FF0000';
for (const point of handLandmarks) {
ctx.beginPath();
ctx.arc(
point.x * ctx.canvas.width,
point.y * ctx.canvas.height,
4, 0, 2 * Math.PI
);
ctx.fill();
}
}
}
人脸网格检测
import {
FaceLandmarker,
FilesetResolver,
FaceLandmarkerResult,
} from '@mediapipe/tasks-vision';
async function createFaceDetector(): Promise<FaceLandmarker> {
const vision = await FilesetResolver.forVisionTasks(
'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm'
);
return FaceLandmarker.createFromOptions(vision, {
baseOptions: {
modelAssetPath:
'https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task',
delegate: 'GPU',
},
runningMode: 'VIDEO',
numFaces: 1,
outputFaceBlendshapes: true, // 输出表情混合形状(52 个表情参数)
outputFacialTransformationMatrixes: true, // 输出面部变换矩阵
});
}
// 表情分类(基于 blendshapes)
function classifyExpression(
blendshapes: Array<{ categoryName: string; score: number }>
): string {
const shapes: Record<string, number> = {};
for (const bs of blendshapes) {
shapes[bs.categoryName] = bs.score;
}
// 基于关键 blendshapes 判断表情
if (shapes['mouthSmileLeft'] > 0.5 && shapes['mouthSmileRight'] > 0.5) {
return '微笑';
}
if (shapes['browDownLeft'] > 0.5 && shapes['browDownRight'] > 0.5) {
return '皱眉';
}
if (shapes['jawOpen'] > 0.5) {
return '张嘴';
}
if (shapes['eyeBlinkLeft'] > 0.5 && shapes['eyeBlinkRight'] > 0.5) {
return '眨眼';
}
return '正常';
}
八、模型优化策略
在浏览器中运行 AI 模型,模型体积和推理速度是关键瓶颈。以下是主要的优化策略:
1. 量化(Quantization)
| 量化方式 | 说明 | 体积变化 | 精度影响 |
|---|---|---|---|
| 动态量化 | 权重 INT8 + 激活运行时量化 | 减少 ~75% | 极小 |
| 静态量化 | 权重和激活都预先量化 | 减少 ~75% | 小,需校准数据 |
| 量化感知训练 (QAT) | 训练过程中模拟量化 | 减少 ~75% | 最小 |
| GPTQ/AWQ | LLM 专用权重量化 | 减少 ~75-87.5% | 中等 |
2. 模型剪枝(Pruning)
// 模型剪枝的概念说明(剪枝通常在 Python 训练端完成)
/*
模型剪枝分为:
1. 非结构化剪枝(Unstructured Pruning)
- 将小于阈值的权重置零
- 需要稀疏矩阵计算支持
- 压缩率高但硬件加速困难
2. 结构化剪枝(Structured Pruning)
- 移除整个神经元/通道/注意力头
- 直接减少计算量
- 对硬件友好
*/
// 浏览器端可以加载已剪枝的模型
// 剪枝后的模型体积更小、推理更快
interface ModelOptimizationPlan {
strategy: string;
originalSize: string;
optimizedSize: string;
speedup: string;
accuracyLoss: string;
}
const OPTIMIZATION_STRATEGIES: ModelOptimizationPlan[] = [
{
strategy: '仅量化(INT8)',
originalSize: '100 MB',
optimizedSize: '25 MB',
speedup: '2-4x',
accuracyLoss: '<1%',
},
{
strategy: '量化(INT4)+ 剪枝(50%)',
originalSize: '100 MB',
optimizedSize: '6 MB',
speedup: '4-8x',
accuracyLoss: '2-5%',
},
{
strategy: '知识蒸馏(小模型)',
originalSize: '400 MB (teacher)',
optimizedSize: '25 MB (student)',
speedup: '10-20x',
accuracyLoss: '3-8%',
},
{
strategy: '全套优化(蒸馏+量化+剪枝)',
originalSize: '400 MB',
optimizedSize: '3 MB',
speedup: '50-100x',
accuracyLoss: '5-15%',
},
];
3. 知识蒸馏(Knowledge Distillation)
- 文本分类/情感分析:DistilBERT(66MB)足够,无需 BERT-base
- 目标检测:MobileNet-SSD(~10MB)> YOLO-v8n(~12MB)> DETR(~160MB)
- 文本嵌入:all-MiniLM-L6-v2(~23MB)是体积和质量的最佳平衡
- LLM:Qwen2.5-1.5B-q4(~1GB)是当前浏览器中运行 LLM 的实用下限
九、渐进式模型加载
模型文件通常较大(10MB - 4GB),需要精心设计加载策略以提升用户体验。
// 渐进式模型加载器 — 带进度、缓存、预热
class ProgressiveModelLoader {
private cacheStoreName = 'ai-models';
// 带进度的模型下载
async downloadWithProgress(
url: string,
onProgress: (loaded: number, total: number) => void
): Promise<ArrayBuffer> {
const response = await fetch(url);
const contentLength = Number(response.headers.get('content-length')) || 0;
const reader = response.body!.getReader();
const chunks: Uint8Array[] = [];
let loaded = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
chunks.push(value);
loaded += value.length;
onProgress(loaded, contentLength);
}
// 合并所有 chunks
const buffer = new Uint8Array(loaded);
let offset = 0;
for (const chunk of chunks) {
buffer.set(chunk, offset);
offset += chunk.length;
}
return buffer.buffer;
}
// 使用 Cache API 缓存模型
async loadWithCache(
modelUrl: string,
onProgress: (stage: string, progress: number) => void
): Promise<ArrayBuffer> {
const cache = await caches.open(this.cacheStoreName);
const cached = await cache.match(modelUrl);
if (cached) {
onProgress('从缓存加载', 1);
return cached.arrayBuffer();
}
onProgress('下载模型', 0);
const buffer = await this.downloadWithProgress(modelUrl, (loaded, total) => {
onProgress('下载模型', total > 0 ? loaded / total : 0);
});
// 存入缓存
await cache.put(modelUrl, new Response(buffer.slice(0)));
onProgress('缓存完成', 1);
return buffer;
}
// 使用 IndexedDB 缓存(适合大模型,不受 Cache API 大小限制)
async loadWithIndexedDB(
modelUrl: string,
modelId: string,
onProgress: (stage: string, progress: number) => void
): Promise<ArrayBuffer> {
// 先检查 IndexedDB
const cachedBuffer = await this.getFromIDB(modelId);
if (cachedBuffer) {
onProgress('从 IndexedDB 加载', 1);
return cachedBuffer;
}
onProgress('下载模型', 0);
const buffer = await this.downloadWithProgress(modelUrl, (loaded, total) => {
onProgress('下载模型', total > 0 ? loaded / total : 0);
});
// 存入 IndexedDB
await this.saveToIDB(modelId, buffer);
onProgress('缓存完成', 1);
return buffer;
}
private getFromIDB(key: string): Promise<ArrayBuffer | null> {
return new Promise((resolve, reject) => {
const request = indexedDB.open('ModelCache', 1);
request.onupgradeneeded = () => {
request.result.createObjectStore('models');
};
request.onsuccess = () => {
const tx = request.result.transaction('models', 'readonly');
const store = tx.objectStore('models');
const getReq = store.get(key);
getReq.onsuccess = () => resolve(getReq.result ?? null);
getReq.onerror = () => reject(getReq.error);
};
});
}
private saveToIDB(key: string, data: ArrayBuffer): Promise<void> {
return new Promise((resolve, reject) => {
const request = indexedDB.open('ModelCache', 1);
request.onupgradeneeded = () => {
request.result.createObjectStore('models');
};
request.onsuccess = () => {
const tx = request.result.transaction('models', 'readwrite');
const store = tx.objectStore('models');
store.put(data, key);
tx.oncomplete = () => resolve();
tx.onerror = () => reject(tx.error);
};
});
}
}
加载进度 UI 组件
import { useState, useCallback } from 'react';
interface LoadingState {
stage: string;
progress: number;
error?: string;
}
export function ModelLoadingUI({
onLoaded,
}: {
onLoaded: (buffer: ArrayBuffer) => void;
}) {
const [state, setState] = useState<LoadingState | null>(null);
const loader = new ProgressiveModelLoader();
const startLoading = useCallback(async (modelUrl: string) => {
try {
const buffer = await loader.loadWithCache(modelUrl, (stage, progress) => {
setState({ stage, progress });
});
// 模型预热(首次推理通常较慢,先运行一次空推理)
setState({ stage: '模型预热中...', progress: 1 });
onLoaded(buffer);
} catch (error) {
setState({
stage: '加载失败',
progress: 0,
error: (error as Error).message,
});
}
}, [onLoaded]);
if (!state) {
return (
<button onClick={() => startLoading('/models/model.onnx')}>
加载 AI 模型
</button>
);
}
return (
<div className="w-full max-w-md mx-auto p-4">
<p className="text-sm text-gray-600 mb-2">{state.stage}</p>
<div className="w-full bg-gray-200 rounded-full h-3">
<div
className="bg-blue-500 h-3 rounded-full transition-all duration-300"
style={{ width: `${state.progress * 100}%` }}
/>
</div>
<p className="text-xs text-gray-400 mt-1">
{(state.progress * 100).toFixed(1)}%
</p>
{state.error && (
<p className="text-red-500 text-sm mt-2">{state.error}</p>
)}
</div>
);
}
模型预热策略
// 模型预热:首次推理通常比后续慢 2-5 倍(JIT 编译、GPU 管线初始化等)
// 在后台预先运行一次推理可以消除这个延迟
async function warmupONNXModel(session: ort.InferenceSession): Promise<void> {
const inputName = session.inputNames[0];
const inputShape = session.inputNames.length > 0
? (session as any)._model?.graph?.input?.[0]?.type?.tensorType?.shape?.dim?.map(
(d: any) => d.dimValue || 1
) ?? [1, 3, 224, 224]
: [1, 3, 224, 224];
// 创建全零的 dummy 输入
const dummyData = new Float32Array(
inputShape.reduce((a: number, b: number) => a * b, 1)
);
const dummyTensor = new ort.Tensor('float32', dummyData, inputShape);
// 运行一次推理(丢弃结果)
await session.run({ [inputName]: dummyTensor });
}
十、混合推理架构
混合推理(Hybrid Inference)将端侧推理和云端推理结合,利用各自优势:
import { pipeline } from '@huggingface/transformers';
interface InferenceResult {
source: 'local' | 'cloud';
result: unknown;
latency: number;
confidence?: number;
}
class HybridInferenceEngine {
private localClassifier: Awaited<ReturnType<typeof pipeline>> | null = null;
private localEmbedder: Awaited<ReturnType<typeof pipeline>> | null = null;
private confidenceThreshold = 0.85;
async init(): Promise<void> {
// 并行加载本地模型
const [classifier, embedder] = await Promise.all([
pipeline('text-classification', 'Xenova/distilbert-base-uncased-finetuned-sst-2-english'),
pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2'),
]);
this.localClassifier = classifier;
this.localEmbedder = embedder;
}
// 策略1:置信度路由 — 本地推理置信度低则转云端
async classifyWithFallback(text: string): Promise<InferenceResult> {
const start = performance.now();
// 先尝试本地推理
if (this.localClassifier) {
const localResult = await this.localClassifier(text) as any[];
const topResult = localResult[0];
if (topResult.score >= this.confidenceThreshold) {
// 置信度足够,直接返回本地结果
return {
source: 'local',
result: topResult,
latency: performance.now() - start,
confidence: topResult.score,
};
}
}
// 本地置信度不足或模型未加载,回退到云端
const cloudResult = await this.callCloudAPI(text);
return {
source: 'cloud',
result: cloudResult,
latency: performance.now() - start,
};
}
// 策略2:本地预处理 + 云端精处理
async analyzeWithPreprocessing(text: string): Promise<InferenceResult> {
const start = performance.now();
// 本地生成文本嵌入(用于 RAG 检索)
let embedding: number[] | null = null;
if (this.localEmbedder) {
const output = await this.localEmbedder(text, {
pooling: 'mean',
normalize: true,
});
embedding = Array.from(output.data);
}
// 将嵌入向量发送给云端(用于相似度搜索/RAG)
const cloudResult = await this.callCloudAPIWithEmbedding(text, embedding);
return {
source: 'cloud',
result: cloudResult,
latency: performance.now() - start,
};
}
// 策略3:任务复杂度路由
async smartRoute(task: {
type: 'classify' | 'generate' | 'summarize' | 'embed';
input: string;
}): Promise<InferenceResult> {
const start = performance.now();
switch (task.type) {
case 'classify':
case 'embed':
// 分类和嵌入在本地完成
return this.classifyWithFallback(task.input);
case 'generate':
case 'summarize':
// 生成和摘要发送到云端
const result = await this.callCloudAPI(task.input);
return {
source: 'cloud',
result,
latency: performance.now() - start,
};
default:
throw new Error(`未知任务类型: ${task.type}`);
}
}
private async callCloudAPI(text: string): Promise<unknown> {
const response = await fetch('/api/ai/inference', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text }),
});
return response.json();
}
private async callCloudAPIWithEmbedding(
text: string,
embedding: number[] | null
): Promise<unknown> {
const response = await fetch('/api/ai/inference', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text, embedding }),
});
return response.json();
}
}
- 表单校验/文本分类:本地推理,低延迟,无成本
- 搜索建议:本地嵌入 + 本地向量搜索
- RAG 应用:本地嵌入生成 + 云端 LLM 生成回答
- 图片预筛:本地目标检测/分类 + 云端 Vision 精分析
- 敏感数据:本地脱敏/分类后再发送云端
十一、Web Worker 深度集成
AI 推理计算量大,必须放在 Web Worker 中避免阻塞主线程。以下是使用 SharedArrayBuffer 和 Comlink 的进阶方案。
使用 Comlink 简化 Worker 通信
Comlink 是 Google Chrome 团队开发的库,将 Worker 通信封装为简单的函数调用:
// Worker 端
import * as Comlink from 'comlink';
import { pipeline } from '@huggingface/transformers';
class AIWorkerAPI {
private classifier: Awaited<ReturnType<typeof pipeline>> | null = null;
private embedder: Awaited<ReturnType<typeof pipeline>> | null = null;
async initClassifier(
onProgress: (progress: number) => void
): Promise<void> {
this.classifier = await pipeline(
'text-classification',
'Xenova/distilbert-base-uncased-finetuned-sst-2-english',
{
progress_callback: (data: { status: string; progress?: number }) => {
if (data.progress !== undefined) onProgress(data.progress);
},
}
);
}
async initEmbedder(): Promise<void> {
this.embedder = await pipeline(
'feature-extraction',
'Xenova/all-MiniLM-L6-v2'
);
}
async classify(text: string): Promise<Array<{ label: string; score: number }>> {
if (!this.classifier) throw new Error('分类器未初始化');
return this.classifier(text) as any;
}
async embed(text: string): Promise<Float32Array> {
if (!this.embedder) throw new Error('嵌入模型未初始化');
const output = await this.embedder(text, { pooling: 'mean', normalize: true });
return output.data as Float32Array;
}
// 批量处理(利用 SharedArrayBuffer 减少数据拷贝)
async batchEmbed(
texts: string[],
resultBuffer: SharedArrayBuffer // 共享内存
): Promise<void> {
if (!this.embedder) throw new Error('嵌入模型未初始化');
const resultView = new Float32Array(resultBuffer);
const embeddingDim = 384; // all-MiniLM-L6-v2 的维度
for (let i = 0; i < texts.length; i++) {
const output = await this.embedder(texts[i], {
pooling: 'mean',
normalize: true,
});
const embedding = output.data as Float32Array;
resultView.set(embedding, i * embeddingDim);
}
}
}
Comlink.expose(new AIWorkerAPI());
import { useEffect, useRef, useState, useCallback } from 'react';
import * as Comlink from 'comlink';
interface AIWorkerAPI {
initClassifier(onProgress: (progress: number) => void): Promise<void>;
initEmbedder(): Promise<void>;
classify(text: string): Promise<Array<{ label: string; score: number }>>;
embed(text: string): Promise<Float32Array>;
batchEmbed(texts: string[], resultBuffer: SharedArrayBuffer): Promise<void>;
}
export function useAIWorker() {
const workerRef = useRef<Worker | null>(null);
const apiRef = useRef<Comlink.Remote<AIWorkerAPI> | null>(null);
const [isReady, setIsReady] = useState(false);
const [loadProgress, setLoadProgress] = useState(0);
useEffect(() => {
const worker = new Worker(
new URL('../workers/ai-worker-comlink.ts', import.meta.url),
{ type: 'module' }
);
workerRef.current = worker;
apiRef.current = Comlink.wrap<AIWorkerAPI>(worker);
// 初始化模型
apiRef.current
.initClassifier(Comlink.proxy((progress: number) => {
setLoadProgress(progress);
}))
.then(() => setIsReady(true));
return () => worker.terminate();
}, []);
const classify = useCallback(async (text: string) => {
if (!apiRef.current) throw new Error('Worker 未就绪');
return apiRef.current.classify(text); // 像调用本地函数一样简单
}, []);
const batchEmbed = useCallback(async (texts: string[]) => {
if (!apiRef.current) throw new Error('Worker 未就绪');
const embeddingDim = 384;
// 使用 SharedArrayBuffer 避免大数据拷贝
const buffer = new SharedArrayBuffer(texts.length * embeddingDim * 4);
await apiRef.current.batchEmbed(texts, buffer);
return new Float32Array(buffer);
}, []);
return { isReady, loadProgress, classify, batchEmbed };
}
使用 SharedArrayBuffer 需要页面启用 COOP/COEP 安全头:
Cross-Origin-Opener-Policy: same-origin
Cross-Origin-Embedder-Policy: require-corp
这些头可能导致部分第三方资源加载失败,需要在 CDN 资源上配置 Cross-Origin-Resource-Policy: cross-origin。
十二、端侧 vs 云端对比
| 维度 | 端侧推理 | 云端推理 |
|---|---|---|
| 隐私 | 数据不离开设备 | 数据需发送到服务器 |
| 延迟 | 低(无网络延迟) | 高(网络 + 排队) |
| 离线 | 支持 | 需要网络 |
| 模型能力 | 小模型(<10B 参数) | 大模型(100B+ 参数) |
| 费用 | 免费(使用用户算力) | 按 token 付费 |
| 首次加载 | 慢(需下载模型 10MB-4GB) | 快(仅发送请求) |
| 兼容性 | WebGPU 要求较新浏览器 | 全平台 |
| 维护 | 模型更新需用户重新下载 | 服务端透明更新 |
| 适用场景 | 实时检测、隐私数据、离线、高频低复杂度任务 | 复杂推理、长文本生成、RAG |
十三、TensorFlow.js
TensorFlow.js 是 Google 开发的最成熟的 Web ML 框架,拥有丰富的预训练模型生态:
import * as tf from '@tensorflow/tfjs';
// 1. 加载预训练模型
async function loadModel() {
const model = await tf.loadLayersModel('/models/sentiment/model.json');
return model;
}
// 2. 图片分类
import * as mobilenet from '@tensorflow-models/mobilenet';
async function classifyImage(imageElement: HTMLImageElement) {
const model = await mobilenet.load({ version: 2, alpha: 1.0 });
const predictions = await model.classify(imageElement);
return predictions.map(p => ({
className: p.className,
probability: (p.probability * 100).toFixed(1) + '%',
}));
// [{ className: '金毛猎犬', probability: '89.2%' }]
}
// 3. 姿态检测
import * as poseDetection from '@tensorflow-models/pose-detection';
async function detectPose(video: HTMLVideoElement) {
const detector = await poseDetection.createDetector(
poseDetection.SupportedModels.MoveNet,
{ modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING }
);
const poses = await detector.estimatePoses(video);
return poses;
}
常见面试问题
Q1: 浏览器中运行 AI 模型有哪些方案?各自的特点是什么?
答案:
| 方案 | 开发者 | 加速后端 | 模型来源 | 适用场景 |
|---|---|---|---|---|
| TensorFlow.js | WebGPU/WebGL/WASM | TF 模型、TFJS 模型 | 训练+推理,生态最丰富 | |
| ONNX Runtime Web | Microsoft | WebGPU/WebNN/WASM | ONNX 格式(PyTorch/TF 转换) | 跨框架通用推理 |
| Transformers.js | Hugging Face | WebGPU/WASM | Hugging Face 模型 | NLP/CV 任务,开箱即用 |
| WebLLM | MLC-AI | WebGPU | Llama/Mistral/Phi 等 LLM | 浏览器中运行完整 LLM |
| MediaPipe | GPU delegate/WASM | 预训练视觉模型 | 实时视觉处理(手势/人脸/姿态) |
加速层优先级:WebGPU(GPU)> WebNN(NPU)> WebAssembly(CPU)。
选择建议:
- 快速原型:Transformers.js(pipeline API 最简单)
- 生产部署:ONNX Runtime Web(跨框架、执行提供者丰富)
- 实时视觉:MediaPipe(专为实时优化)
- 浏览器 LLM:WebLLM(唯一可行方案)
Q2: WebGPU Compute Shader 和 WebGL 在 AI 计算上有什么区别?
答案:
| 特性 | WebGPU Compute Shader | WebGL (用于 AI) |
|---|---|---|
| 设计目标 | 通用计算(GPGPU) | 图形渲染 |
| 计算模型 | 原生 Compute Shader,直接操作 buffer | 需要将计算映射为纹理渲染 |
| 数据类型 | f32, f16, i32, u32, 支持结构体 | 纹理像素(RGBA float) |
| 同步 | 原子操作、workgroup barrier | 无计算同步原语 |
| 内存访问 | Storage Buffer 直接读写 | 纹理采样,受限于纹理尺寸 |
| 性能 | 比 WebGL 快 3-10 倍(矩阵运算) | 受纹理映射开销限制 |
| API 复杂度 | 较高(类似 Vulkan) | 中等 |
WebGPU Compute Shader 的核心优势:
- 原生支持 GPGPU:不需要把矩阵乘法伪装为纹理操作
- workgroup 共享内存:线程组内共享数据,减少全局内存访问
- 灵活的数据布局:Storage Buffer 支持任意结构体
- 原子操作:支持线程安全的累加/比较等操作
Q3: 如何将 PyTorch/TensorFlow 模型转换并优化用于浏览器推理?
答案:
转换流程:
PyTorch 模型 (.pth)
→ torch.onnx.export() → model.onnx
→ 量化工具 → model_int8.onnx
→ 浏览器 ONNX Runtime Web 加载
TensorFlow 模型 (.pb / SavedModel)
→ tf2onnx → model.onnx → 同上
或
→ tensorflowjs_converter → TFJS 格式 → TensorFlow.js 加载
优化步骤:
- 模型转换:PyTorch 用
torch.onnx.export(),TensorFlow 用tf2onnx - 图优化:ONNX Runtime 的
GraphOptimizationLevel.ORT_ENABLE_ALL自动进行算子融合、常量折叠 - 量化:动态量化(INT8)最简单,
quantize_dynamic()一行代码 - 分片:大模型拆分为多个分片(每片 <50MB),支持并行下载
- 测试:量化后在目标硬件上验证精度损失是否可接受
关键配置:
// opset_version 选择 17+(支持更多操作)
// dynamic_axes 设为动态 batch(支持不同输入大小)
// 量化推荐 INT8 动态量化(最简单,精度损失最小)
Q4: 如何实现渐进式模型加载并展示进度?
答案:
渐进式模型加载包含四个阶段:
- 检查缓存:先查 Cache API 或 IndexedDB 是否有缓存
- 下载模型:使用
fetch+ReadableStream实现带进度的下载 - 缓存模型:下载后存入 Cache API(小模型)或 IndexedDB(大模型)
- 模型预热:首次推理较慢,后台运行空推理消除 JIT 编译延迟
核心实现:
// 带进度的下载
const response = await fetch(modelUrl);
const total = Number(response.headers.get('content-length'));
const reader = response.body!.getReader();
let loaded = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
loaded += value.length;
onProgress(loaded / total); // 更新进度
}
缓存选择:
- Cache API:适合小模型(<200MB),API 简单,自动管理
- IndexedDB:适合大模型(>200MB),无大小限制,手动管理
- Transformers.js 默认使用 Cache API,WebLLM 使用 IndexedDB
Q5: 对比 TensorFlow.js、ONNX Runtime Web 和 Transformers.js
答案:
| 维度 | TensorFlow.js | ONNX Runtime Web | Transformers.js |
|---|---|---|---|
| 开发者 | Microsoft | Hugging Face | |
| 模型格式 | TF/TFJS | ONNX | ONNX(自动转换) |
| 加速后端 | WebGPU/WebGL/WASM | WebGPU/WebNN/WASM | WebGPU/WASM |
| 模型训练 | 支持 | 不支持 | 不支持 |
| 模型生态 | TF Hub | ONNX Model Zoo | Hugging Face Hub(最大) |
| API 易用性 | 中等 | 低(需手动 Tensor) | 高(pipeline API) |
| 模型加载 | 需转为 TFJS 格式 | 需转为 ONNX 格式 | 自动下载转换 |
| 包大小 | ~500KB | ~200KB | ~300KB(不含模型) |
| WebNN 支持 | 不支持 | 支持 | 不支持 |
| 社区活跃度 | 高 | 高 | 极高 |
选择建议:
- 需要浏览器端训练 → TensorFlow.js
- 已有 PyTorch/TF 模型需部署 → ONNX Runtime Web
- 快速使用 NLP/CV 模型 → Transformers.js
- 需要 WebNN/NPU 加速 → ONNX Runtime Web
Q6: 如何在浏览器中实现实时目标检测?
答案:
实时目标检测需要解决三个核心问题:模型选择、帧处理循环、性能优化。
// 使用 Transformers.js 的目标检测 pipeline
import { pipeline } from '@huggingface/transformers';
// 1. 加载模型(一次性)
const detector = await pipeline(
'object-detection',
'Xenova/detr-resnet-50'
);
// 2. 帧处理循环
function startDetection(video: HTMLVideoElement, canvas: HTMLCanvasElement) {
const ctx = canvas.getContext('2d')!;
async function processFrame() {
// 将视频帧绘制到 canvas
ctx.drawImage(video, 0, 0);
// 推理
const results = await detector(canvas.toDataURL(), {
threshold: 0.7, // 置信度阈值
});
// 绘制检测框
for (const { label, score, box } of results) {
ctx.strokeStyle = '#00FF00';
ctx.lineWidth = 2;
ctx.strokeRect(box.xmin, box.ymin, box.xmax - box.xmin, box.ymax - box.ymin);
ctx.fillText(`${label} ${(score * 100).toFixed(0)}%`, box.xmin, box.ymin - 5);
}
requestAnimationFrame(processFrame);
}
processFrame();
}
性能优化要点:
- Worker 隔离:推理放在 Web Worker,避免阻塞 UI
- 帧率控制:不需要每帧都推理,可以每 2-3 帧推理一次
- 模型选择:DETR (~160MB) 精度高但慢,MobileNet-SSD (~10MB) 速度快但精度低
- 输入分辨率:降低输入图片分辨率可大幅提升速度
- WebGPU 加速:确保使用 WebGPU 后端而非 WASM
或使用 MediaPipe 获得更好的实时性能(专为实时优化,FPS 更高)。
Q7: 什么是混合推理?什么场景下应该使用?
答案:
混合推理是将端侧推理和云端推理结合的架构模式,根据任务复杂度、置信度、网络状态等因素动态选择推理位置。
三种路由策略:
| 策略 | 原理 | 示例 |
|---|---|---|
| 置信度路由 | 本地推理置信度低于阈值时转云端 | 情感分析置信度 <85% 转 GPT-4o |
| 任务复杂度路由 | 简单任务本地、复杂任务云端 | 分类/嵌入本地,生成/摘要云端 |
| 预处理+精处理 | 本地做数据预处理后发送云端 | 本地生成嵌入向量 → 云端 RAG |
适用场景:
- RAG 应用:本地嵌入 + 云端 LLM
- 图片审核:本地 NSFW 检测 + 可疑内容发云端人工审核
- 离线优先:有网络用云端,离线用本地小模型
- 成本优化:高频简单任务本地处理,减少 API 调用
Q8: 如何管理浏览器中的 AI 模型缓存(Cache API / IndexedDB)?
答案:
| 方案 | 适用场景 | 大小限制 | API 复杂度 |
|---|---|---|---|
| Cache API | 小模型(<200MB) | 浏览器配额(通常 >1GB) | 简单 |
| IndexedDB | 大模型(>200MB) | 更大配额 | 中等 |
| Origin Private File System | 超大模型 | 磁盘空间 | 较新 API |
关键实践:
// Cache API 缓存
const cache = await caches.open('ai-models');
const cached = await cache.match(modelUrl);
if (cached) return cached.arrayBuffer();
// 下载并缓存
const response = await fetch(modelUrl);
await cache.put(modelUrl, response.clone());
return response.arrayBuffer();
注意事项:
- 存储配额:使用
navigator.storage.estimate()检查剩余空间 - 版本管理:模型 URL 带版本号(如
model_v2_int8.onnx),更新时清理旧版本 - 用户提示:大模型下载前提示用户("即将下载 1.8GB 模型")
- 清理策略:定期清理不再使用的模型缓存
- 持久化存储:调用
navigator.storage.persist()防止浏览器自动清理
Q9: WebNN 和 WebGPU 有什么区别?各自在 AI 推理中的角色是什么?
答案:
| 维度 | WebGPU | WebNN |
|---|---|---|
| 设计目标 | 通用 GPU 编程(图形+计算) | 专用于神经网络推理 |
| 硬件访问 | GPU | NPU、GPU、CPU(统一接口) |
| API 粒度 | 低级(写 Shader、管理 Buffer) | 高级(定义计算图、执行推理) |
| 优化方式 | 手动(Shader 优化、内存管理) | 自动(OS 层面优化,利用专用硬件) |
| 功耗 | 较高(GPU 满载) | 较低(NPU 能效比高) |
| 浏览器支持 | Chrome 113+ 正式 | Chrome Origin Trial |
| 主要受益 | 所有需要 GPU 计算的任务 | 神经网络推理,尤其移动端 |
关系:WebNN 和 WebGPU 不是竞争关系而是互补。WebNN 通过操作系统的原生 ML 框架(DirectML/CoreML/NNAPI)利用 NPU 硬件,功耗更低;WebGPU 提供更灵活的 GPU 计算能力。ONNX Runtime Web 同时支持两者作为执行提供者。
未来趋势:移动设备上 NPU 普及(高通骁龙、Apple Neural Engine),WebNN 的价值将更加突出。
Q10: 端侧 AI 推理会阻塞主线程吗?如何优化?
答案:
会阻塞。AI 模型推理涉及大量计算,如果在主线程运行会导致页面卡顿和无响应。
解决方案(按优先级排序):
-
Web Worker(最重要):将模型加载和推理放在 Worker 线程
- 使用 Comlink 简化 Worker 通信(像调用本地函数一样)
- 大数据传输使用 Transferable Objects 或 SharedArrayBuffer 减少拷贝
-
WebGPU 加速:GPU 计算本身在 GPU 上执行,不阻塞 CPU 主线程
- 但 JS 层的数据准备和结果读取仍在主线程
- 仍建议配合 Worker 使用
-
模型量化:INT8/INT4 量化减少计算量,推理更快
-
帧率控制:实时检测不需要每帧推理,可每 2-3 帧推理一次
更多 Worker 使用细节参考 Web Workers。 更多性能优化策略参考 AI 应用性能优化。
Q11: Transformers.js 支持哪些 AI 任务?模型如何缓存?
答案:
支持的任务(使用 pipeline API):
| 任务类别 | pipeline 名称 | 典型模型 | 模型大小 |
|---|---|---|---|
| 文本分类 | text-classification | distilbert-sst-2 | ~67MB |
| 命名实体识别 | token-classification | bert-base-NER | ~110MB |
| 问答 | question-answering | distilbert-squad | ~67MB |
| 文本摘要 | summarization | distilbart-cnn | ~300MB |
| 翻译 | translation | opus-mt-xx-xx | ~150MB |
| 文本生成 | text2text-generation | flan-t5-small | ~77MB |
| 零样本分类 | zero-shot-classification | nli-deberta | ~50MB |
| 特征提取 | feature-extraction | all-MiniLM-L6-v2 | ~23MB |
| 图片分类 | image-classification | vit-base | ~86MB |
| 目标检测 | object-detection | detr-resnet-50 | ~160MB |
| 图像分割 | image-segmentation | detr-panoptic | ~160MB |
缓存机制:
// Transformers.js 默认使用 Cache API
// 首次下载模型后自动缓存,下次加载直接从缓存读取
// 查看缓存大小
const cache = await caches.open('transformers-cache');
const keys = await cache.keys();
console.log(`已缓存 ${keys.length} 个文件`);
// 预加载模型(不立即使用)
await pipeline('text-classification', 'Xenova/distilbert-base-uncased-finetuned-sst-2-english');
// 之后调用时直接从缓存加载,秒级就绪
Q12: 什么场景适合端侧 AI 而不是云端?
答案:
| 场景 | 原因 | 推荐方案 |
|---|---|---|
| 隐私敏感数据 | 医疗影像、金融数据不应离开设备 | ONNX Runtime Web + WebGPU |
| 实时交互 | AR 滤镜、手势识别需要 <50ms 延迟 | MediaPipe |
| 离线场景 | PWA、弱网环境、飞行模式 | Transformers.js + Cache API |
| 成本控制 | 高频低复杂度任务避免 API 费用 | 本地小模型 |
| 预处理 | 客户端做初步分类/嵌入再发云端 | 混合推理架构 |
| 合规要求 | GDPR 等法规限制数据出境 | 端侧推理 |
不适合端侧的场景:
- 需要大模型能力(>10B 参数):当前浏览器只能运行小模型
- 长文本生成/复杂推理:小模型质量不足
- 首次体验要求快:模型下载需要时间
Q13: 如何在前端实现"本地知识库"功能(不依赖后端向量库)?
答案:
前端本地知识库的核心架构是 IndexedDB 存储 + 浏览器端 Embedding + 余弦相似度搜索,完全在浏览器中运行,不需要后端向量数据库。
技术方案:
import { pipeline, type FeatureExtractionPipeline } from '@huggingface/transformers';
interface Document {
id: string;
content: string;
metadata?: Record<string, string>;
embedding?: Float32Array; // 向量存储在 IndexedDB
}
interface SearchResult {
document: Document;
score: number; // 余弦相似度 0-1
}
class LocalKnowledgeBase {
private embedder: FeatureExtractionPipeline | null = null;
private db: IDBDatabase | null = null;
private readonly DB_NAME = 'knowledge-base';
private readonly STORE_NAME = 'documents';
// 1. 初始化嵌入模型(~23MB,首次下载后缓存)
async init(): Promise<void> {
// all-MiniLM-L6-v2:384 维向量,轻量且效果好
this.embedder = await pipeline(
'feature-extraction',
'Xenova/all-MiniLM-L6-v2',
{ device: 'webgpu' } // 优先 WebGPU 加速
) as FeatureExtractionPipeline;
// 打开 IndexedDB
this.db = await new Promise((resolve, reject) => {
const request = indexedDB.open(this.DB_NAME, 1);
request.onupgradeneeded = () => {
const db = request.result;
if (!db.objectStoreNames.contains(this.STORE_NAME)) {
db.createObjectStore(this.STORE_NAME, { keyPath: 'id' });
}
};
request.onsuccess = () => resolve(request.result);
request.onerror = () => reject(request.error);
});
}
// 2. 添加文档:生成 Embedding 并存储
async addDocument(content: string, metadata?: Record<string, string>): Promise<string> {
if (!this.embedder || !this.db) throw new Error('未初始化');
const id = crypto.randomUUID();
// 生成文本的向量表示
const output = await this.embedder(content, { pooling: 'mean', normalize: true });
const embedding = new Float32Array(output.data as ArrayBuffer);
const doc: Document = { id, content, metadata, embedding };
// 存入 IndexedDB
await new Promise<void>((resolve, reject) => {
const tx = this.db!.transaction(this.STORE_NAME, 'readwrite');
tx.objectStore(this.STORE_NAME).put(doc);
tx.oncomplete = () => resolve();
tx.onerror = () => reject(tx.error);
});
return id;
}
// 3. 语义搜索:计算查询向量与所有文档的余弦相似度
async search(query: string, topK = 5): Promise<SearchResult[]> {
if (!this.embedder || !this.db) throw new Error('未初始化');
// 生成查询向量
const output = await this.embedder(query, { pooling: 'mean', normalize: true });
const queryVec = new Float32Array(output.data as ArrayBuffer);
// 从 IndexedDB 读取所有文档
const docs = await new Promise<Document[]>((resolve, reject) => {
const tx = this.db!.transaction(this.STORE_NAME, 'readonly');
const request = tx.objectStore(this.STORE_NAME).getAll();
request.onsuccess = () => resolve(request.result);
request.onerror = () => reject(request.error);
});
// 计算余弦相似度并排序
const results: SearchResult[] = docs
.filter(doc => doc.embedding)
.map(doc => ({
document: { ...doc, embedding: undefined },
score: cosineSimilarity(queryVec, doc.embedding!),
}))
.sort((a, b) => b.score - a.score)
.slice(0, topK);
return results;
}
}
// 余弦相似度(向量已归一化时等于点积)
function cosineSimilarity(a: Float32Array, b: Float32Array): number {
let dot = 0;
for (let i = 0; i < a.length; i++) dot += a[i] * b[i];
return dot; // 归一化后 dot product = cosine similarity
}
性能优化要点:
| 优化点 | 方案 | 效果 |
|---|---|---|
| Embedding 计算 | 放在 Web Worker 中 | 不阻塞 UI |
| 大量文档搜索 | 使用 HNSW 索引(如 hnswlib-wasm) | O(log n) 搜索,替代 O(n) 暴力搜索 |
| 文档分块 | 长文档按段落/句子拆分(~200-500 字/块) | 提高检索精度 |
| WebGPU 加速 | { device: 'webgpu' } | Embedding 速度提升 3-10 倍 |
局限性:
- 向量维度 384,文档量 >10 万时 IndexedDB 搜索变慢(建议引入 HNSW)
- 浏览器端只能用小型 Embedding 模型(~23MB),语义理解不如 OpenAI text-embedding-3
- 适合个人笔记、本地文档检索等场景;企业级 RAG 仍需后端向量数据库
更多 RAG 架构设计参考 RAG 检索增强生成。 更多向量搜索原理参考 向量搜索与 Embedding。
相关链接
- TensorFlow.js
- ONNX Runtime Web
- Transformers.js
- WebLLM
- MediaPipe
- WebGPU - MDN
- WebNN - W3C
- Comlink - 简化 Worker 通信
- Web Workers - Worker 线程详解
- AI 应用性能优化 - 性能优化策略
- 多模态交互 - 图片/语音/视频多模态 AI