跳到主要内容

Web AI 与端侧推理

问题

前端如何在浏览器中直接运行 AI 模型?WebGPU、WebNN、TensorFlow.js、ONNX Runtime Web、Transformers.js、WebLLM、MediaPipe 等技术的原理、适用场景和最佳实践是什么?如何优化模型加载、推理性能和用户体验?

面试速答版

前端如何在浏览器中直接运行 AI 模型? Web AI(端侧推理)= 把模型直接跑在浏览器里,数据不离开用户设备

  • 核心优势:隐私(不上传敏感数据)、离线可用、零延迟、无 API 成本。
  • 加速底座:WebGPU(最快,接近原生 GPU)、WebNN(调专用 NPU)、WASM SIMD(CPU 兜底)。
  • 典型场景:实时图像分割(背景虚化)、本地翻译、语音识别、轻量级 LLM 对话。

WebGPU、WebNN、TensorFlow.js、ONNX、Transformers.js、WebLLM、MediaPipe 怎么选? 按抽象层级和场景对号入座:

  • 底层加速WebGPU 通用计算、WebNN 专用 NN、WASM 兼容兜底。
  • 通用框架TensorFlow.js(Google 生态)、ONNX Runtime Web(跨框架最广)。
  • NLP 高层Transformers.js(Hugging Face 模型一行接入,BERT / Whisper / 小 LLM 都行)。
  • 本地大模型WebLLM / llama.cpp(WebGPU 跑 7B 模型可达 20 token/s)。
  • 视觉 / 实时MediaPipe(人脸 / 手势 / 姿态识别,已有现成 solution)。

如何优化模型加载、推理性能和用户体验? 按生命周期分别下手:

  • 加载:模型量化(INT8 / INT4 减小 4x 体积)、CDN + Service Worker 缓存、首屏只下载必需的;用 OPFS 把模型存本地避免重下。
  • 推理:批处理、Worker 线程隔离避免阻塞主线程、复用 Tensor 不要每次都 new。
  • 体验:首次加载显示进度条 + 模型大小说明;推理慢时降级走云端 API;不支持 WebGPU 自动 fallback 到 WASM。

答案

Web AI(端侧推理)是在浏览器中直接运行 AI 模型,不需要将数据发送到服务器。这带来了隐私保护、离线可用、低延迟等核心优势。随着 WebGPU、WebNN 等底层加速 API 的成熟,以及 Transformers.js、ONNX Runtime Web 等高层框架的发展,浏览器的 AI 能力正在快速提升。

一、Web AI 技术栈全景

技术说明加速后端状态
WebGPU现代 GPU 加速 API,原生支持 Compute ShaderGPUChrome 113+、Edge 113+、Firefox Nightly
WebNN直接访问 NPU/AI 加速器的标准 APINPU/GPU/CPUChrome 开发中(Origin Trial)
WebAssembly接近原生的 CPU 执行性能CPU所有现代浏览器
TensorFlow.jsGoogle 的 Web ML 框架,支持训练和推理WebGPU/WebGL/WASM成熟,生态丰富
ONNX Runtime Web微软跨平台推理引擎,支持多种模型格式WebGPU/WebNN/WASM生产就绪
Transformers.js直接运行 Hugging Face 上的模型WebGPU/WASMv3 活跃发展中
WebLLM浏览器运行 LLM(Llama、Mistral、Phi)WebGPU实验性,快速迭代
MediaPipeGoogle 实时视觉/手势/人脸检测方案WebGPU/WASM/GPU delegate生产就绪

二、WebGPU 深入 — Compute Shader 与 AI 加速

WebGPU 是 WebGL 的继任者,基于 Vulkan/Metal/D3D12 设计。与 WebGL 最大的区别在于 WebGPU 原生支持 Compute Shader(计算着色器),可以直接进行通用计算(GPGPU),而不需要把计算映射为图形渲染操作。

Compute Shader 基础 — 矩阵乘法示例

矩阵乘法是神经网络中最核心的运算。以下展示如何用 WebGPU Compute Shader 实现矩阵乘法:

lib/webgpu-matmul.ts
// WebGPU 矩阵乘法:C = A × B
async function gpuMatMul(
a: Float32Array,
b: Float32Array,
M: number, // A 的行数
K: number, // A 的列数 = B 的行数
N: number // B 的列数
): Promise<Float32Array> {
// 1. 获取 GPU 设备
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) throw new Error('WebGPU 不可用');
const device = await adapter.requestDevice();

// 2. 创建 GPU Buffer
const bufferA = device.createBuffer({
size: a.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
const bufferB = device.createBuffer({
size: b.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
const bufferC = device.createBuffer({
size: M * N * 4,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});
const bufferDims = device.createBuffer({
size: 3 * 4,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});

// 写入数据
device.queue.writeBuffer(bufferA, 0, a);
device.queue.writeBuffer(bufferB, 0, b);
device.queue.writeBuffer(bufferDims, 0, new Uint32Array([M, K, N]));

// 3. 编写 WGSL Compute Shader
const shaderModule = device.createShaderModule({
code: `
struct Dims {
M: u32,
K: u32,
N: u32,
}
@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@group(0) @binding(3) var<uniform> dims: Dims;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let row = id.x;
let col = id.y;
if (row >= dims.M || col >= dims.N) { return; }

var sum: f32 = 0.0;
for (var k: u32 = 0u; k < dims.K; k = k + 1u) {
sum = sum + A[row * dims.K + k] * B[k * dims.N + col];
}
C[row * dims.N + col] = sum;
}
`,
});

// 4. 创建计算管线
const pipeline = device.createComputePipeline({
layout: 'auto',
compute: { module: shaderModule, entryPoint: 'main' },
});

const bindGroup = device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: bufferA } },
{ binding: 1, resource: { buffer: bufferB } },
{ binding: 2, resource: { buffer: bufferC } },
{ binding: 3, resource: { buffer: bufferDims } },
],
});

// 5. 提交计算命令
const commandEncoder = device.createCommandEncoder();
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(pipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(Math.ceil(M / 16), Math.ceil(N / 16));
passEncoder.end();

// 读回结果
const readBuffer = device.createBuffer({
size: M * N * 4,
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
});
commandEncoder.copyBufferToBuffer(bufferC, 0, readBuffer, 0, M * N * 4);
device.queue.submit([commandEncoder.finish()]);

await readBuffer.mapAsync(GPUMapMode.READ);
const result = new Float32Array(readBuffer.getMappedRange().slice(0));
readBuffer.unmap();

return result;
}

WebGPU vs WebGL vs WASM 性能基准

基准测试(矩阵 1024x1024)WebGPUWebGLWASM (SIMD)CPU (JS)
矩阵乘法~8ms~30ms~120ms~800ms
ResNet-50 推理~15ms~60ms~200ms~1500ms
BERT-base 推理~25ms~100ms~350msN/A
加速比 (vs CPU)~50-100x~10-25x~5-8x1x
说明

以上数据来自 Chrome 团队和 ONNX Runtime 的公开基准测试,实际性能因硬件和模型差异会有变化。WebGPU 在大矩阵运算和 Transformer 类模型上优势尤其显著。

WebGPU 浏览器支持情况

浏览器支持状态备注
Chrome / Edge113+ 正式支持桌面端完整支持,Android Chrome 121+
FirefoxNightly 实验性需手动启用 dom.webgpu.enabled
Safari预览版实验性WebKit 正在实现中
iOS Safari暂不支持受限于 WebKit 进度
注意

WebGPU 目前在移动端支持仍然有限。生产环境需做好降级方案,检测 WebGPU 不可用时回退到 WASM 后端。

lib/webgpu-check.ts
// 运行时检测 WebGPU 支持
async function checkWebGPUSupport(): Promise<{
supported: boolean;
adapterInfo?: GPUAdapterInfo;
limits?: GPUSupportedLimits;
}> {
if (!('gpu' in navigator)) {
return { supported: false };
}

const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
return { supported: false };
}

const info = await adapter.requestAdapterInfo();
return {
supported: true,
adapterInfo: info,
limits: adapter.limits,
};
}

// 根据支持情况选择后端
async function selectBackend(): Promise<'webgpu' | 'wasm' | 'webgl'> {
const gpuSupport = await checkWebGPUSupport();
if (gpuSupport.supported) return 'webgpu';

// 检查 WebGL 2
const canvas = document.createElement('canvas');
const gl = canvas.getContext('webgl2');
if (gl) return 'webgl';

return 'wasm';
}

三、WebNN — 原生 AI 加速 API

WebNN(Web Neural Network API)是 W3C 正在制定的标准,允许 Web 应用直接访问设备上的 NPU(神经处理单元)、GPU 和 CPU,实现高效的神经网络推理。

WebNN 架构

WebNN 核心概念

lib/webnn-example.ts
// WebNN API 基础用法
async function webnnInference() {
// 1. 获取 ML Context(指定设备偏好)
const context = await navigator.ml.createContext({
deviceType: 'npu', // 'cpu' | 'gpu' | 'npu'
powerPreference: 'low-power', // 'default' | 'high-performance' | 'low-power'
});

// 2. 创建 GraphBuilder
const builder = new MLGraphBuilder(context);

// 3. 定义计算图(以简单的全连接层为例)
const inputDesc: MLOperandDescriptor = {
dataType: 'float32',
shape: [1, 784], // 28x28 图片展平
};
const input = builder.input('input', inputDesc);

// 权重和偏置(实际场景从模型文件加载)
const weightsData = new Float32Array(784 * 128); // 假设已填充
const weights = builder.constant(
{ dataType: 'float32', shape: [784, 128] },
weightsData
);
const biasData = new Float32Array(128);
const bias = builder.constant(
{ dataType: 'float32', shape: [128] },
biasData
);

// 4. 构建操作
const matmul = builder.matmul(input, weights);
const add = builder.add(matmul, bias);
const output = builder.relu(add);

// 5. 编译计算图
const graph = await builder.build({ output });

// 6. 执行推理
const inputBuffer = new Float32Array(784);
const outputBuffer = new Float32Array(128);

const results = await context.compute(graph, {
input: inputBuffer,
}, {
output: outputBuffer,
});

return results.output;
}

WebNN 支持的操作

操作类别包含操作
元素级add、sub、mul、div、pow、abs、ceil、floor、exp、log、sigmoid、relu、tanh、softmax
矩阵matmul、gemm
卷积conv2d、convTranspose2d
池化averagePool2d、maxPool2d、l2Pool2d
归一化batchNormalization、layerNormalization、instanceNormalization
变形reshape、transpose、concat、split、slice、pad、expand
注意力暂无原生 attention 操作,通过 matmul + softmax 组合
Chrome WebNN 实现状态

Chrome 正在通过 Origin Trial 方式逐步开放 WebNN API。在 Windows 上通过 DirectML 后端支持 NPU/GPU 加速,macOS 上通过 Core ML 后端支持 Apple Neural Engine。开发者可在 chrome://flags/#enable-web-machine-learning-neural-network-api 手动启用。

四、ONNX Runtime Web 深入

ONNX Runtime Web 是微软开发的跨平台推理引擎,支持 ONNX 格式模型。它是目前浏览器端最成熟的通用推理方案之一。

模型转换:从 PyTorch/TensorFlow 到 ONNX

scripts/convert-model.ts
// Node.js 脚本:PyTorch 模型导出为 ONNX(通常在 Python 中完成)
// 以下是 Python 伪代码的 TypeScript 注释说明

/*
# PyTorch → ONNX
import torch
import torch.onnx

model = MyModel()
model.load_state_dict(torch.load('model.pth'))
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
'model.onnx',
opset_version=17, # 推荐 opset 17+
input_names=['input'],
output_names=['output'],
dynamic_axes={ # 支持动态 batch
'input': {0: 'batch'},
'output': {0: 'batch'},
},
)

# TensorFlow → ONNX(使用 tf2onnx)
# python -m tf2onnx.convert --saved-model ./saved_model --output model.onnx --opset 17
*/

// 在浏览器中使用 ONNX Runtime Web
import * as ort from 'onnxruntime-web';

// 配置执行提供者
ort.env.wasm.wasmPaths = '/wasm/';

async function createSession(modelPath: string): Promise<ort.InferenceSession> {
const options: ort.InferenceSession.SessionOptions = {
executionProviders: [
'webgpu', // 首选 WebGPU
'wasm', // 降级到 WASM
],
graphOptimizationLevel: 'all',
enableCpuMemArena: true,
};

return ort.InferenceSession.create(modelPath, options);
}

// 完整的推理流程
async function runInference(
session: ort.InferenceSession,
inputData: Float32Array,
inputShape: number[]
): Promise<Float32Array> {
// 创建输入 Tensor
const inputTensor = new ort.Tensor('float32', inputData, inputShape);

// 获取输入输出名称
const inputName = session.inputNames[0];
const outputName = session.outputNames[0];

// 执行推理
const results = await session.run({ [inputName]: inputTensor });

return results[outputName].data as Float32Array;
}

ONNX 模型量化

量化是将模型权重从 FP32 转换为更低精度(INT8/INT4/FP16)的过程,可以显著减小模型体积和提升推理速度

lib/onnx-quantization.ts
/*
模型量化通常在 Python 中完成:

# INT8 动态量化(推荐,最简单)
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
model_input='model.onnx',
model_output='model_int8.onnx',
weight_type=QuantType.QInt8,
)

# INT4 量化(更小模型,适合 LLM 权重)
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
model_input='model.onnx',
model_output='model_int4.onnx',
weight_type=QuantType.QInt4,
)

# FP16 量化(保留更多精度)
from onnxruntime.transformers import float16
import onnx

model = onnx.load('model.onnx')
model_fp16 = float16.convert_float_to_float16(model)
onnx.save(model_fp16, 'model_fp16.onnx')
*/

// 量化效果对比
interface QuantizationComparison {
type: string;
sizeReduction: string;
speedup: string;
accuracyLoss: string;
useCase: string;
}

const QUANTIZATION_COMPARISON: QuantizationComparison[] = [
{
type: 'FP32(原始)',
sizeReduction: '基准',
speedup: '基准',
accuracyLoss: '无',
useCase: '训练、高精度需求',
},
{
type: 'FP16',
sizeReduction: '~50%',
speedup: '~1.5-2x',
accuracyLoss: '极小(<0.1%)',
useCase: 'GPU 推理、通用场景',
},
{
type: 'INT8',
sizeReduction: '~75%',
speedup: '~2-4x',
accuracyLoss: '小(<1%)',
useCase: '浏览器端推理(推荐)',
},
{
type: 'INT4',
sizeReduction: '~87.5%',
speedup: '~3-6x',
accuracyLoss: '中(1-3%)',
useCase: 'LLM 权重量化、极小模型',
},
];
量化类型体积缩减速度提升精度损失适用场景
FP32(原始)基准基准训练、高精度需求
FP16~50%~1.5-2x极小(<0.1%)GPU 推理、通用场景
INT8~75%~2-4x小(<1%)浏览器端推理(推荐)
INT4~87.5%~3-6x中(1-3%)LLM 权重量化、极小模型

ONNX Runtime Web 执行提供者(Execution Providers)

lib/onnx-providers.ts
import * as ort from 'onnxruntime-web';

// 不同执行提供者的特点和选择策略
interface ExecutionProvider {
name: string;
backend: string;
pros: string[];
cons: string[];
}

const PROVIDERS: ExecutionProvider[] = [
{
name: 'webgpu',
backend: 'GPU (WebGPU API)',
pros: ['最快的 GPU 加速', '支持 Compute Shader', '内存带宽大'],
cons: ['需要 Chrome 113+', '移动端支持有限'],
},
{
name: 'webnn',
backend: 'NPU/GPU/CPU (WebNN API)',
pros: ['可利用 NPU 加速', '低功耗', '未来标准方向'],
cons: ['浏览器支持极早期', '操作覆盖不完整'],
},
{
name: 'wasm',
backend: 'CPU (WebAssembly)',
pros: ['兼容性最好', '所有浏览器支持', 'SIMD 加速'],
cons: ['性能不如 GPU', '大模型推理较慢'],
},
{
name: 'webgl',
backend: 'GPU (WebGL 2)',
pros: ['广泛的 GPU 支持', '兼容旧浏览器'],
cons: ['计算需映射为纹理操作', '比 WebGPU 慢 3-5x'],
},
];

// 智能选择执行提供者
async function getOptimalProviders(): Promise<string[]> {
const providers: string[] = [];

// 优先 WebGPU
if ('gpu' in navigator) {
const adapter = await navigator.gpu.requestAdapter();
if (adapter) providers.push('webgpu');
}

// 其次 WebNN
if ('ml' in navigator) {
providers.push('webnn');
}

// 兜底 WASM(始终可用)
providers.push('wasm');

return providers;
}

五、Transformers.js 深入

Transformers.js 是 Hugging Face 官方的 JavaScript 版本,让你可以直接在浏览器中运行 Hugging Face 上的数千个模型。v3 版本支持 WebGPU 加速。

Pipeline API 全任务支持

lib/transformers-pipelines.ts
import { pipeline, env, Pipeline } from '@huggingface/transformers';

// 配置(v3 使用 @huggingface/transformers 包名)
env.allowLocalModels = false;

// ===== 自然语言处理任务 =====

// 1. 文本分类(情感分析)
async function textClassification(text: string) {
const classifier = await pipeline(
'text-classification',
'Xenova/distilbert-base-uncased-finetuned-sst-2-english'
);
return classifier(text);
// [{ label: 'POSITIVE', score: 0.9998 }]
}

// 2. 命名实体识别(NER)
async function tokenClassification(text: string) {
const ner = await pipeline(
'token-classification',
'Xenova/bert-base-NER'
);
return ner(text);
// [{ entity: 'B-PER', score: 0.99, word: 'John', ... }]
}

// 3. 问答
async function questionAnswering(question: string, context: string) {
const qa = await pipeline(
'question-answering',
'Xenova/distilbert-base-cased-distilled-squad'
);
return qa({ question, context });
// { answer: 'Paris', score: 0.98, start: 12, end: 17 }
}

// 4. 文本摘要
async function summarization(text: string) {
const summarizer = await pipeline(
'summarization',
'Xenova/distilbart-cnn-6-6'
);
return summarizer(text, { max_length: 100, min_length: 30 });
// [{ summary_text: '摘要内容...' }]
}

// 5. 翻译
async function translation(text: string) {
const translator = await pipeline(
'translation',
'Xenova/opus-mt-zh-en' // 中文到英文
);
return translator(text);
// [{ translation_text: 'translated text...' }]
}

// 6. 文本生成
async function textGeneration(prompt: string) {
const generator = await pipeline(
'text2text-generation',
'Xenova/flan-t5-small'
);
return generator(prompt, { max_new_tokens: 100 });
}

// 7. 零样本分类(不需要训练的分类)
async function zeroShotClassification(text: string, labels: string[]) {
const classifier = await pipeline(
'zero-shot-classification',
'Xenova/nli-deberta-v3-xsmall'
);
return classifier(text, labels);
// { labels: ['技术', '体育', '娱乐'], scores: [0.92, 0.05, 0.03] }
}

// 8. 特征提取(文本嵌入向量)
async function featureExtraction(text: string) {
const embedder = await pipeline(
'feature-extraction',
'Xenova/all-MiniLM-L6-v2'
);
const output = await embedder(text, { pooling: 'mean', normalize: true });
return Array.from(output.data); // 384 维向量
}

// ===== 计算机视觉任务 =====

// 9. 图片分类
async function imageClassification(imageUrl: string) {
const classifier = await pipeline(
'image-classification',
'Xenova/vit-base-patch16-224'
);
return classifier(imageUrl);
// [{ label: 'golden retriever', score: 0.89 }]
}

// 10. 目标检测
async function objectDetection(imageUrl: string) {
const detector = await pipeline(
'object-detection',
'Xenova/detr-resnet-50'
);
return detector(imageUrl);
// [{ label: 'cat', score: 0.98, box: { xmin, ymin, xmax, ymax } }]
}

// 11. 图像分割
async function imageSegmentation(imageUrl: string) {
const segmenter = await pipeline(
'image-segmentation',
'Xenova/detr-resnet-50-panoptic'
);
return segmenter(imageUrl);
// [{ label: 'cat', score: 0.99, mask: RawImage }]
}

模型缓存机制

Transformers.js 默认使用浏览器的 Cache API 缓存已下载的模型文件。理解缓存机制对优化加载体验至关重要:

lib/transformers-cache.ts
import { pipeline, env } from '@huggingface/transformers';

// 缓存配置
env.useBrowserCache = true; // 默认开启 Cache API 缓存
env.allowLocalModels = false; // 是否从本地加载

// 自定义缓存路径
env.cacheDir = '/models'; // 自定义缓存目录名

// 手动管理模型缓存
class ModelCacheManager {
private cacheName = 'transformers-cache';

// 查看已缓存的模型
async listCachedModels(): Promise<string[]> {
const cache = await caches.open(this.cacheName);
const keys = await cache.keys();
return keys.map(req => req.url);
}

// 获取缓存大小
async getCacheSize(): Promise<number> {
const cache = await caches.open(this.cacheName);
const keys = await cache.keys();
let totalSize = 0;

for (const key of keys) {
const response = await cache.match(key);
if (response) {
const blob = await response.blob();
totalSize += blob.size;
}
}
return totalSize;
}

// 清理指定模型缓存
async clearModelCache(modelName: string): Promise<void> {
const cache = await caches.open(this.cacheName);
const keys = await cache.keys();

for (const key of keys) {
if (key.url.includes(modelName)) {
await cache.delete(key);
}
}
}

// 预加载模型(提前下载不立即使用)
async preloadModel(
task: string,
modelId: string,
onProgress?: (progress: number) => void
): Promise<void> {
await pipeline(task as any, modelId, {
progress_callback: (data: { status: string; progress?: number }) => {
if (data.status === 'progress' && data.progress !== undefined) {
onProgress?.(data.progress);
}
},
});
}
}

六、WebLLM 深入 — 浏览器中运行 LLM

WebLLM 基于 Apache TVM 的 MLC(Machine Learning Compilation)技术,将 LLM 编译为高效的 WebGPU 代码,在浏览器中运行。

支持的模型及资源需求

模型量化模型大小VRAM 需求大约速度
Llama-3.1-8Bq4f16_1~4.5 GB~6 GB25-45 tok/s
Llama-3.2-3Bq4f16_1~1.8 GB~3 GB40-70 tok/s
Mistral-7Bq4f16_1~4.0 GB~5.5 GB25-45 tok/s
Phi-3.5-mini-3.8Bq4f16_1~2.2 GB~3.5 GB35-60 tok/s
Qwen2.5-1.5Bq4f16_1~1.0 GB~2 GB50-80 tok/s
SmolLM2-1.7Bq4f16_1~1.0 GB~2 GB50-80 tok/s
Gemma-2-2Bq4f16_1~1.5 GB~2.5 GB40-65 tok/s
性能说明

以上速度为在搭载 NVIDIA RTX 3060 (12GB) 或 Apple M1 Pro 等中高端 GPU 上的大致范围。实际速度取决于 GPU 型号、显存带宽、浏览器版本。低端 GPU 速度可能减半。

完整的 WebLLM 集成

lib/webllm-engine.ts
import * as webllm from '@mlc-ai/web-llm';

interface WebLLMConfig {
modelId: string;
onProgress?: (progress: { text: string; progress: number }) => void;
onReady?: () => void;
}

class WebLLMEngine {
private engine: webllm.MLCEngine | null = null;
private isReady = false;

async init(config: WebLLMConfig): Promise<void> {
this.engine = await webllm.CreateMLCEngine(config.modelId, {
initProgressCallback: (progress) => {
config.onProgress?.({
text: progress.text,
progress: progress.progress,
});
},
});
this.isReady = true;
config.onReady?.();
}

// OpenAI 兼容的聊天接口
async chat(
messages: Array<{ role: string; content: string }>,
options?: { temperature?: number; max_tokens?: number }
): Promise<string> {
if (!this.engine) throw new Error('引擎未初始化');

const response = await this.engine.chat.completions.create({
messages: messages as webllm.ChatCompletionMessageParam[],
temperature: options?.temperature ?? 0.7,
max_tokens: options?.max_tokens ?? 1024,
});

return response.choices[0]?.message?.content ?? '';
}

// 流式聊天
async *chatStream(
messages: Array<{ role: string; content: string }>,
options?: { temperature?: number; max_tokens?: number }
): AsyncGenerator<string> {
if (!this.engine) throw new Error('引擎未初始化');

const response = await this.engine.chat.completions.create({
messages: messages as webllm.ChatCompletionMessageParam[],
temperature: options?.temperature ?? 0.7,
max_tokens: options?.max_tokens ?? 1024,
stream: true,
});

for await (const chunk of response) {
const content = chunk.choices[0]?.delta?.content;
if (content) yield content;
}
}

// 获取性能统计
async getStats(): Promise<{ tokensPerSecond: number; totalTokens: number }> {
if (!this.engine) throw new Error('引擎未初始化');
const stats = await this.engine.runtimeStatsText();
// 解析统计文本
return { tokensPerSecond: 0, totalTokens: 0 }; // 简化示例
}

dispose(): void {
this.engine = null;
this.isReady = false;
}
}
components/LocalLLMChat.tsx
import { useState, useRef, useCallback } from 'react';

interface Message {
role: 'user' | 'assistant';
content: string;
}

export function LocalLLMChat() {
const [messages, setMessages] = useState<Message[]>([]);
const [input, setInput] = useState('');
const [loading, setLoading] = useState(false);
const [progress, setProgress] = useState<{ text: string; percent: number } | null>(null);
const engineRef = useRef<WebLLMEngine | null>(null);

const initEngine = useCallback(async () => {
const engine = new WebLLMEngine();
await engine.init({
modelId: 'Llama-3.2-3B-Instruct-q4f16_1-MLC',
onProgress: (p) => setProgress({ text: p.text, percent: p.progress * 100 }),
onReady: () => setProgress(null),
});
engineRef.current = engine;
}, []);

const sendMessage = useCallback(async () => {
if (!engineRef.current || !input.trim()) return;

const userMsg: Message = { role: 'user', content: input };
const newMessages = [...messages, userMsg];
setMessages(newMessages);
setInput('');
setLoading(true);

// 流式生成回复
let assistantContent = '';
setMessages([...newMessages, { role: 'assistant', content: '' }]);

for await (const chunk of engineRef.current.chatStream(newMessages)) {
assistantContent += chunk;
setMessages([
...newMessages,
{ role: 'assistant', content: assistantContent },
]);
}

setLoading(false);
}, [input, messages]);

return (
<div className="flex flex-col h-screen">
{/* 加载进度 */}
{progress && (
<div className="p-4 bg-blue-50">
<p className="text-sm">{progress.text}</p>
<div className="w-full bg-gray-200 rounded h-2 mt-1">
<div
className="bg-blue-500 h-2 rounded transition-all"
style={{ width: `${progress.percent}%` }}
/>
</div>
</div>
)}

{/* 消息列表 */}
<div className="flex-1 overflow-auto p-4">
{messages.map((msg, i) => (
<div key={i} className={`mb-4 ${msg.role === 'user' ? 'text-right' : ''}`}>
<div className={`inline-block p-3 rounded-lg ${
msg.role === 'user' ? 'bg-blue-500 text-white' : 'bg-gray-100'
}`}>
{msg.content}
</div>
</div>
))}
</div>

{/* 输入区 */}
<div className="p-4 border-t flex gap-2">
{!engineRef.current ? (
<button onClick={initEngine} className="btn-primary">
加载模型(~1.8GB)
</button>
) : (
<>
<input
value={input}
onChange={(e) => setInput(e.target.value)}
onKeyDown={(e) => e.key === 'Enter' && sendMessage()}
className="flex-1 border rounded px-3"
placeholder="输入消息..."
/>
<button onClick={sendMessage} disabled={loading}>
{loading ? '生成中...' : '发送'}
</button>
</>
)}
</div>
</div>
);
}
注意

WebLLM 需要 WebGPU 支持且模型下载较大(1-4GB)。首次加载需要较长时间用于下载和编译模型。适合对隐私要求极高或需离线使用的场景。建议在用户明确触发后才开始加载模型,并配合进度条提供反馈。

七、MediaPipe 实时视觉处理

MediaPipe 是 Google 开发的跨平台机器学习框架,专注于实时视觉处理任务,包括手势识别、人脸网格、姿态检测、目标检测等。其模型轻量级,推理速度极快,适合在浏览器中运行。

MediaPipe 支持的视觉任务

任务模型输出典型帧率
手势识别Hand Landmarker21 个手部关键点 + 手势分类30+ FPS
人脸网格Face Landmarker478 个面部关键点 + 表情分类30+ FPS
姿态检测Pose Landmarker33 个身体关键点30+ FPS
目标检测Object Detector边界框 + 类别 + 置信度25+ FPS
图像分类Image Classifier类别 + 置信度30+ FPS
图像分割Image Segmenter像素级分割掩码20+ FPS
文字识别Text RecognizerOCR 文字20+ FPS

手势识别 React 组件

components/HandGestureDetector.tsx
import { useEffect, useRef, useState, useCallback } from 'react';
import {
GestureRecognizer,
FilesetResolver,
GestureRecognizerResult,
} from '@mediapipe/tasks-vision';

interface DetectedGesture {
gesture: string;
confidence: number;
landmarks: Array<{ x: number; y: number; z: number }>;
}

export function HandGestureDetector() {
const videoRef = useRef<HTMLVideoElement>(null);
const canvasRef = useRef<HTMLCanvasElement>(null);
const [recognizer, setRecognizer] = useState<GestureRecognizer | null>(null);
const [gestures, setGestures] = useState<DetectedGesture[]>([]);
const [fps, setFps] = useState(0);
const frameCountRef = useRef(0);
const lastTimeRef = useRef(performance.now());

// 初始化手势识别器
useEffect(() => {
async function init() {
const vision = await FilesetResolver.forVisionTasks(
'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm'
);

const gestureRecognizer = await GestureRecognizer.createFromOptions(vision, {
baseOptions: {
modelAssetPath:
'https://storage.googleapis.com/mediapipe-models/gesture_recognizer/gesture_recognizer/float16/1/gesture_recognizer.task',
delegate: 'GPU', // 使用 GPU 加速
},
runningMode: 'VIDEO',
numHands: 2,
});

setRecognizer(gestureRecognizer);
}
init();
}, []);

// 开启摄像头
useEffect(() => {
if (!videoRef.current) return;
navigator.mediaDevices
.getUserMedia({ video: { width: 640, height: 480, facingMode: 'user' } })
.then((stream) => {
videoRef.current!.srcObject = stream;
});
}, []);

// 实时检测循环
const detect = useCallback(() => {
if (!recognizer || !videoRef.current || !canvasRef.current) return;

const video = videoRef.current;
const canvas = canvasRef.current;
const ctx = canvas.getContext('2d')!;

const processFrame = () => {
if (video.readyState < 2) {
requestAnimationFrame(processFrame);
return;
}

// 执行手势识别
const result = recognizer.recognizeForVideo(video, performance.now());
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.drawImage(video, 0, 0);

// 绘制手部关键点和连线
drawHandLandmarks(ctx, result);

// 更新识别结果
const detected: DetectedGesture[] = [];
if (result.gestures.length > 0) {
result.gestures.forEach((gestureList, handIndex) => {
const topGesture = gestureList[0];
detected.push({
gesture: topGesture.categoryName,
confidence: topGesture.score,
landmarks: result.landmarks[handIndex],
});
});
}
setGestures(detected);

// 计算 FPS
frameCountRef.current++;
const now = performance.now();
if (now - lastTimeRef.current >= 1000) {
setFps(frameCountRef.current);
frameCountRef.current = 0;
lastTimeRef.current = now;
}

requestAnimationFrame(processFrame);
};

processFrame();
}, [recognizer]);

useEffect(() => {
detect();
}, [detect]);

return (
<div className="relative">
<video ref={videoRef} autoPlay muted playsInline className="hidden" />
<canvas ref={canvasRef} width={640} height={480} />

{/* 识别结果 */}
<div className="absolute top-2 left-2 bg-black/70 text-white p-2 rounded">
<p>FPS: {fps}</p>
{gestures.map((g, i) => (
<p key={i}>
{i + 1}: {g.gesture}{(g.confidence * 100).toFixed(1)}%
</p>
))}
</div>
</div>
);
}

// 绘制手部关键点
function drawHandLandmarks(
ctx: CanvasRenderingContext2D,
result: GestureRecognizerResult
): void {
const { landmarks } = result;
if (!landmarks.length) return;

// 手部连线定义(简化)
const connections = [
[0, 1], [1, 2], [2, 3], [3, 4], // 拇指
[0, 5], [5, 6], [6, 7], [7, 8], // 食指
[0, 9], [9, 10], [10, 11], [11, 12], // 中指
[0, 13], [13, 14], [14, 15], [15, 16], // 无名指
[0, 17], [17, 18], [18, 19], [19, 20], // 小指
[5, 9], [9, 13], [13, 17], // 手掌
];

for (const handLandmarks of landmarks) {
// 画连线
ctx.strokeStyle = '#00FF00';
ctx.lineWidth = 2;
for (const [start, end] of connections) {
const p1 = handLandmarks[start];
const p2 = handLandmarks[end];
ctx.beginPath();
ctx.moveTo(p1.x * ctx.canvas.width, p1.y * ctx.canvas.height);
ctx.lineTo(p2.x * ctx.canvas.width, p2.y * ctx.canvas.height);
ctx.stroke();
}

// 画关键点
ctx.fillStyle = '#FF0000';
for (const point of handLandmarks) {
ctx.beginPath();
ctx.arc(
point.x * ctx.canvas.width,
point.y * ctx.canvas.height,
4, 0, 2 * Math.PI
);
ctx.fill();
}
}
}

人脸网格检测

lib/face-mesh.ts
import {
FaceLandmarker,
FilesetResolver,
FaceLandmarkerResult,
} from '@mediapipe/tasks-vision';

async function createFaceDetector(): Promise<FaceLandmarker> {
const vision = await FilesetResolver.forVisionTasks(
'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm'
);

return FaceLandmarker.createFromOptions(vision, {
baseOptions: {
modelAssetPath:
'https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task',
delegate: 'GPU',
},
runningMode: 'VIDEO',
numFaces: 1,
outputFaceBlendshapes: true, // 输出表情混合形状(52 个表情参数)
outputFacialTransformationMatrixes: true, // 输出面部变换矩阵
});
}

// 表情分类(基于 blendshapes)
function classifyExpression(
blendshapes: Array<{ categoryName: string; score: number }>
): string {
const shapes: Record<string, number> = {};
for (const bs of blendshapes) {
shapes[bs.categoryName] = bs.score;
}

// 基于关键 blendshapes 判断表情
if (shapes['mouthSmileLeft'] > 0.5 && shapes['mouthSmileRight'] > 0.5) {
return '微笑';
}
if (shapes['browDownLeft'] > 0.5 && shapes['browDownRight'] > 0.5) {
return '皱眉';
}
if (shapes['jawOpen'] > 0.5) {
return '张嘴';
}
if (shapes['eyeBlinkLeft'] > 0.5 && shapes['eyeBlinkRight'] > 0.5) {
return '眨眼';
}
return '正常';
}

八、模型优化策略

在浏览器中运行 AI 模型,模型体积和推理速度是关键瓶颈。以下是主要的优化策略:

1. 量化(Quantization)

量化方式说明体积变化精度影响
动态量化权重 INT8 + 激活运行时量化减少 ~75%极小
静态量化权重和激活都预先量化减少 ~75%小,需校准数据
量化感知训练 (QAT)训练过程中模拟量化减少 ~75%最小
GPTQ/AWQLLM 专用权重量化减少 ~75-87.5%中等

2. 模型剪枝(Pruning)

lib/model-optimization.ts
// 模型剪枝的概念说明(剪枝通常在 Python 训练端完成)
/*
模型剪枝分为:

1. 非结构化剪枝(Unstructured Pruning)
- 将小于阈值的权重置零
- 需要稀疏矩阵计算支持
- 压缩率高但硬件加速困难

2. 结构化剪枝(Structured Pruning)
- 移除整个神经元/通道/注意力头
- 直接减少计算量
- 对硬件友好
*/

// 浏览器端可以加载已剪枝的模型
// 剪枝后的模型体积更小、推理更快
interface ModelOptimizationPlan {
strategy: string;
originalSize: string;
optimizedSize: string;
speedup: string;
accuracyLoss: string;
}

const OPTIMIZATION_STRATEGIES: ModelOptimizationPlan[] = [
{
strategy: '仅量化(INT8)',
originalSize: '100 MB',
optimizedSize: '25 MB',
speedup: '2-4x',
accuracyLoss: '<1%',
},
{
strategy: '量化(INT4)+ 剪枝(50%)',
originalSize: '100 MB',
optimizedSize: '6 MB',
speedup: '4-8x',
accuracyLoss: '2-5%',
},
{
strategy: '知识蒸馏(小模型)',
originalSize: '400 MB (teacher)',
optimizedSize: '25 MB (student)',
speedup: '10-20x',
accuracyLoss: '3-8%',
},
{
strategy: '全套优化(蒸馏+量化+剪枝)',
originalSize: '400 MB',
optimizedSize: '3 MB',
speedup: '50-100x',
accuracyLoss: '5-15%',
},
];

3. 知识蒸馏(Knowledge Distillation)

模型大小 vs 精度权衡建议
  • 文本分类/情感分析:DistilBERT(66MB)足够,无需 BERT-base
  • 目标检测:MobileNet-SSD(~10MB)> YOLO-v8n(~12MB)> DETR(~160MB)
  • 文本嵌入:all-MiniLM-L6-v2(~23MB)是体积和质量的最佳平衡
  • LLM:Qwen2.5-1.5B-q4(~1GB)是当前浏览器中运行 LLM 的实用下限

九、渐进式模型加载

模型文件通常较大(10MB - 4GB),需要精心设计加载策略以提升用户体验。

lib/progressive-model-loading.ts
// 渐进式模型加载器 — 带进度、缓存、预热
class ProgressiveModelLoader {
private cacheStoreName = 'ai-models';

// 带进度的模型下载
async downloadWithProgress(
url: string,
onProgress: (loaded: number, total: number) => void
): Promise<ArrayBuffer> {
const response = await fetch(url);
const contentLength = Number(response.headers.get('content-length')) || 0;
const reader = response.body!.getReader();
const chunks: Uint8Array[] = [];
let loaded = 0;

while (true) {
const { done, value } = await reader.read();
if (done) break;

chunks.push(value);
loaded += value.length;
onProgress(loaded, contentLength);
}

// 合并所有 chunks
const buffer = new Uint8Array(loaded);
let offset = 0;
for (const chunk of chunks) {
buffer.set(chunk, offset);
offset += chunk.length;
}
return buffer.buffer;
}

// 使用 Cache API 缓存模型
async loadWithCache(
modelUrl: string,
onProgress: (stage: string, progress: number) => void
): Promise<ArrayBuffer> {
const cache = await caches.open(this.cacheStoreName);
const cached = await cache.match(modelUrl);

if (cached) {
onProgress('从缓存加载', 1);
return cached.arrayBuffer();
}

onProgress('下载模型', 0);
const buffer = await this.downloadWithProgress(modelUrl, (loaded, total) => {
onProgress('下载模型', total > 0 ? loaded / total : 0);
});

// 存入缓存
await cache.put(modelUrl, new Response(buffer.slice(0)));
onProgress('缓存完成', 1);

return buffer;
}

// 使用 IndexedDB 缓存(适合大模型,不受 Cache API 大小限制)
async loadWithIndexedDB(
modelUrl: string,
modelId: string,
onProgress: (stage: string, progress: number) => void
): Promise<ArrayBuffer> {
// 先检查 IndexedDB
const cachedBuffer = await this.getFromIDB(modelId);
if (cachedBuffer) {
onProgress('从 IndexedDB 加载', 1);
return cachedBuffer;
}

onProgress('下载模型', 0);
const buffer = await this.downloadWithProgress(modelUrl, (loaded, total) => {
onProgress('下载模型', total > 0 ? loaded / total : 0);
});

// 存入 IndexedDB
await this.saveToIDB(modelId, buffer);
onProgress('缓存完成', 1);

return buffer;
}

private getFromIDB(key: string): Promise<ArrayBuffer | null> {
return new Promise((resolve, reject) => {
const request = indexedDB.open('ModelCache', 1);
request.onupgradeneeded = () => {
request.result.createObjectStore('models');
};
request.onsuccess = () => {
const tx = request.result.transaction('models', 'readonly');
const store = tx.objectStore('models');
const getReq = store.get(key);
getReq.onsuccess = () => resolve(getReq.result ?? null);
getReq.onerror = () => reject(getReq.error);
};
});
}

private saveToIDB(key: string, data: ArrayBuffer): Promise<void> {
return new Promise((resolve, reject) => {
const request = indexedDB.open('ModelCache', 1);
request.onupgradeneeded = () => {
request.result.createObjectStore('models');
};
request.onsuccess = () => {
const tx = request.result.transaction('models', 'readwrite');
const store = tx.objectStore('models');
store.put(data, key);
tx.oncomplete = () => resolve();
tx.onerror = () => reject(tx.error);
};
});
}
}

加载进度 UI 组件

components/ModelLoadingUI.tsx
import { useState, useCallback } from 'react';

interface LoadingState {
stage: string;
progress: number;
error?: string;
}

export function ModelLoadingUI({
onLoaded,
}: {
onLoaded: (buffer: ArrayBuffer) => void;
}) {
const [state, setState] = useState<LoadingState | null>(null);
const loader = new ProgressiveModelLoader();

const startLoading = useCallback(async (modelUrl: string) => {
try {
const buffer = await loader.loadWithCache(modelUrl, (stage, progress) => {
setState({ stage, progress });
});

// 模型预热(首次推理通常较慢,先运行一次空推理)
setState({ stage: '模型预热中...', progress: 1 });
onLoaded(buffer);
} catch (error) {
setState({
stage: '加载失败',
progress: 0,
error: (error as Error).message,
});
}
}, [onLoaded]);

if (!state) {
return (
<button onClick={() => startLoading('/models/model.onnx')}>
加载 AI 模型
</button>
);
}

return (
<div className="w-full max-w-md mx-auto p-4">
<p className="text-sm text-gray-600 mb-2">{state.stage}</p>
<div className="w-full bg-gray-200 rounded-full h-3">
<div
className="bg-blue-500 h-3 rounded-full transition-all duration-300"
style={{ width: `${state.progress * 100}%` }}
/>
</div>
<p className="text-xs text-gray-400 mt-1">
{(state.progress * 100).toFixed(1)}%
</p>
{state.error && (
<p className="text-red-500 text-sm mt-2">{state.error}</p>
)}
</div>
);
}

模型预热策略

lib/model-warmup.ts
// 模型预热:首次推理通常比后续慢 2-5 倍(JIT 编译、GPU 管线初始化等)
// 在后台预先运行一次推理可以消除这个延迟

async function warmupONNXModel(session: ort.InferenceSession): Promise<void> {
const inputName = session.inputNames[0];
const inputShape = session.inputNames.length > 0
? (session as any)._model?.graph?.input?.[0]?.type?.tensorType?.shape?.dim?.map(
(d: any) => d.dimValue || 1
) ?? [1, 3, 224, 224]
: [1, 3, 224, 224];

// 创建全零的 dummy 输入
const dummyData = new Float32Array(
inputShape.reduce((a: number, b: number) => a * b, 1)
);
const dummyTensor = new ort.Tensor('float32', dummyData, inputShape);

// 运行一次推理(丢弃结果)
await session.run({ [inputName]: dummyTensor });
}

十、混合推理架构

混合推理(Hybrid Inference)将端侧推理和云端推理结合,利用各自优势:

lib/hybrid-inference.ts
import { pipeline } from '@huggingface/transformers';

interface InferenceResult {
source: 'local' | 'cloud';
result: unknown;
latency: number;
confidence?: number;
}

class HybridInferenceEngine {
private localClassifier: Awaited<ReturnType<typeof pipeline>> | null = null;
private localEmbedder: Awaited<ReturnType<typeof pipeline>> | null = null;
private confidenceThreshold = 0.85;

async init(): Promise<void> {
// 并行加载本地模型
const [classifier, embedder] = await Promise.all([
pipeline('text-classification', 'Xenova/distilbert-base-uncased-finetuned-sst-2-english'),
pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2'),
]);
this.localClassifier = classifier;
this.localEmbedder = embedder;
}

// 策略1:置信度路由 — 本地推理置信度低则转云端
async classifyWithFallback(text: string): Promise<InferenceResult> {
const start = performance.now();

// 先尝试本地推理
if (this.localClassifier) {
const localResult = await this.localClassifier(text) as any[];
const topResult = localResult[0];

if (topResult.score >= this.confidenceThreshold) {
// 置信度足够,直接返回本地结果
return {
source: 'local',
result: topResult,
latency: performance.now() - start,
confidence: topResult.score,
};
}
}

// 本地置信度不足或模型未加载,回退到云端
const cloudResult = await this.callCloudAPI(text);
return {
source: 'cloud',
result: cloudResult,
latency: performance.now() - start,
};
}

// 策略2:本地预处理 + 云端精处理
async analyzeWithPreprocessing(text: string): Promise<InferenceResult> {
const start = performance.now();

// 本地生成文本嵌入(用于 RAG 检索)
let embedding: number[] | null = null;
if (this.localEmbedder) {
const output = await this.localEmbedder(text, {
pooling: 'mean',
normalize: true,
});
embedding = Array.from(output.data);
}

// 将嵌入向量发送给云端(用于相似度搜索/RAG)
const cloudResult = await this.callCloudAPIWithEmbedding(text, embedding);
return {
source: 'cloud',
result: cloudResult,
latency: performance.now() - start,
};
}

// 策略3:任务复杂度路由
async smartRoute(task: {
type: 'classify' | 'generate' | 'summarize' | 'embed';
input: string;
}): Promise<InferenceResult> {
const start = performance.now();

switch (task.type) {
case 'classify':
case 'embed':
// 分类和嵌入在本地完成
return this.classifyWithFallback(task.input);

case 'generate':
case 'summarize':
// 生成和摘要发送到云端
const result = await this.callCloudAPI(task.input);
return {
source: 'cloud',
result,
latency: performance.now() - start,
};

default:
throw new Error(`未知任务类型: ${task.type}`);
}
}

private async callCloudAPI(text: string): Promise<unknown> {
const response = await fetch('/api/ai/inference', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text }),
});
return response.json();
}

private async callCloudAPIWithEmbedding(
text: string,
embedding: number[] | null
): Promise<unknown> {
const response = await fetch('/api/ai/inference', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text, embedding }),
});
return response.json();
}
}
混合推理适用场景
  • 表单校验/文本分类:本地推理,低延迟,无成本
  • 搜索建议:本地嵌入 + 本地向量搜索
  • RAG 应用:本地嵌入生成 + 云端 LLM 生成回答
  • 图片预筛:本地目标检测/分类 + 云端 Vision 精分析
  • 敏感数据:本地脱敏/分类后再发送云端

十一、Web Worker 深度集成

AI 推理计算量大,必须放在 Web Worker 中避免阻塞主线程。以下是使用 SharedArrayBuffer 和 Comlink 的进阶方案。

Comlink 是 Google Chrome 团队开发的库,将 Worker 通信封装为简单的函数调用:

workers/ai-worker-comlink.ts
// Worker 端
import * as Comlink from 'comlink';
import { pipeline } from '@huggingface/transformers';

class AIWorkerAPI {
private classifier: Awaited<ReturnType<typeof pipeline>> | null = null;
private embedder: Awaited<ReturnType<typeof pipeline>> | null = null;

async initClassifier(
onProgress: (progress: number) => void
): Promise<void> {
this.classifier = await pipeline(
'text-classification',
'Xenova/distilbert-base-uncased-finetuned-sst-2-english',
{
progress_callback: (data: { status: string; progress?: number }) => {
if (data.progress !== undefined) onProgress(data.progress);
},
}
);
}

async initEmbedder(): Promise<void> {
this.embedder = await pipeline(
'feature-extraction',
'Xenova/all-MiniLM-L6-v2'
);
}

async classify(text: string): Promise<Array<{ label: string; score: number }>> {
if (!this.classifier) throw new Error('分类器未初始化');
return this.classifier(text) as any;
}

async embed(text: string): Promise<Float32Array> {
if (!this.embedder) throw new Error('嵌入模型未初始化');
const output = await this.embedder(text, { pooling: 'mean', normalize: true });
return output.data as Float32Array;
}

// 批量处理(利用 SharedArrayBuffer 减少数据拷贝)
async batchEmbed(
texts: string[],
resultBuffer: SharedArrayBuffer // 共享内存
): Promise<void> {
if (!this.embedder) throw new Error('嵌入模型未初始化');

const resultView = new Float32Array(resultBuffer);
const embeddingDim = 384; // all-MiniLM-L6-v2 的维度

for (let i = 0; i < texts.length; i++) {
const output = await this.embedder(texts[i], {
pooling: 'mean',
normalize: true,
});
const embedding = output.data as Float32Array;
resultView.set(embedding, i * embeddingDim);
}
}
}

Comlink.expose(new AIWorkerAPI());
hooks/useAIWorkerComlink.ts
import { useEffect, useRef, useState, useCallback } from 'react';
import * as Comlink from 'comlink';

interface AIWorkerAPI {
initClassifier(onProgress: (progress: number) => void): Promise<void>;
initEmbedder(): Promise<void>;
classify(text: string): Promise<Array<{ label: string; score: number }>>;
embed(text: string): Promise<Float32Array>;
batchEmbed(texts: string[], resultBuffer: SharedArrayBuffer): Promise<void>;
}

export function useAIWorker() {
const workerRef = useRef<Worker | null>(null);
const apiRef = useRef<Comlink.Remote<AIWorkerAPI> | null>(null);
const [isReady, setIsReady] = useState(false);
const [loadProgress, setLoadProgress] = useState(0);

useEffect(() => {
const worker = new Worker(
new URL('../workers/ai-worker-comlink.ts', import.meta.url),
{ type: 'module' }
);
workerRef.current = worker;
apiRef.current = Comlink.wrap<AIWorkerAPI>(worker);

// 初始化模型
apiRef.current
.initClassifier(Comlink.proxy((progress: number) => {
setLoadProgress(progress);
}))
.then(() => setIsReady(true));

return () => worker.terminate();
}, []);

const classify = useCallback(async (text: string) => {
if (!apiRef.current) throw new Error('Worker 未就绪');
return apiRef.current.classify(text); // 像调用本地函数一样简单
}, []);

const batchEmbed = useCallback(async (texts: string[]) => {
if (!apiRef.current) throw new Error('Worker 未就绪');

const embeddingDim = 384;
// 使用 SharedArrayBuffer 避免大数据拷贝
const buffer = new SharedArrayBuffer(texts.length * embeddingDim * 4);
await apiRef.current.batchEmbed(texts, buffer);
return new Float32Array(buffer);
}, []);

return { isReady, loadProgress, classify, batchEmbed };
}
SharedArrayBuffer 安全要求

使用 SharedArrayBuffer 需要页面启用 COOP/COEP 安全头:

Cross-Origin-Opener-Policy: same-origin
Cross-Origin-Embedder-Policy: require-corp

这些头可能导致部分第三方资源加载失败,需要在 CDN 资源上配置 Cross-Origin-Resource-Policy: cross-origin

十二、端侧 vs 云端对比

维度端侧推理云端推理
隐私数据不离开设备数据需发送到服务器
延迟低(无网络延迟)高(网络 + 排队)
离线支持需要网络
模型能力小模型(<10B 参数)大模型(100B+ 参数)
费用免费(使用用户算力)按 token 付费
首次加载慢(需下载模型 10MB-4GB)快(仅发送请求)
兼容性WebGPU 要求较新浏览器全平台
维护模型更新需用户重新下载服务端透明更新
适用场景实时检测、隐私数据、离线、高频低复杂度任务复杂推理、长文本生成、RAG

十三、TensorFlow.js

TensorFlow.js 是 Google 开发的最成熟的 Web ML 框架,拥有丰富的预训练模型生态:

lib/tfjs-example.ts
import * as tf from '@tensorflow/tfjs';

// 1. 加载预训练模型
async function loadModel() {
const model = await tf.loadLayersModel('/models/sentiment/model.json');
return model;
}

// 2. 图片分类
import * as mobilenet from '@tensorflow-models/mobilenet';

async function classifyImage(imageElement: HTMLImageElement) {
const model = await mobilenet.load({ version: 2, alpha: 1.0 });
const predictions = await model.classify(imageElement);

return predictions.map(p => ({
className: p.className,
probability: (p.probability * 100).toFixed(1) + '%',
}));
// [{ className: '金毛猎犬', probability: '89.2%' }]
}

// 3. 姿态检测
import * as poseDetection from '@tensorflow-models/pose-detection';

async function detectPose(video: HTMLVideoElement) {
const detector = await poseDetection.createDetector(
poseDetection.SupportedModels.MoveNet,
{ modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING }
);
const poses = await detector.estimatePoses(video);
return poses;
}

常见面试问题

Q1: 浏览器中运行 AI 模型有哪些方案?各自的特点是什么?

答案

方案开发者加速后端模型来源适用场景
TensorFlow.jsGoogleWebGPU/WebGL/WASMTF 模型、TFJS 模型训练+推理,生态最丰富
ONNX Runtime WebMicrosoftWebGPU/WebNN/WASMONNX 格式(PyTorch/TF 转换)跨框架通用推理
Transformers.jsHugging FaceWebGPU/WASMHugging Face 模型NLP/CV 任务,开箱即用
WebLLMMLC-AIWebGPULlama/Mistral/Phi 等 LLM浏览器中运行完整 LLM
MediaPipeGoogleGPU delegate/WASM预训练视觉模型实时视觉处理(手势/人脸/姿态)

加速层优先级:WebGPU(GPU)> WebNN(NPU)> WebAssembly(CPU)

选择建议:

  • 快速原型:Transformers.js(pipeline API 最简单)
  • 生产部署:ONNX Runtime Web(跨框架、执行提供者丰富)
  • 实时视觉:MediaPipe(专为实时优化)
  • 浏览器 LLM:WebLLM(唯一可行方案)

Q2: WebGPU Compute Shader 和 WebGL 在 AI 计算上有什么区别?

答案

特性WebGPU Compute ShaderWebGL (用于 AI)
设计目标通用计算(GPGPU)图形渲染
计算模型原生 Compute Shader,直接操作 buffer需要将计算映射为纹理渲染
数据类型f32, f16, i32, u32, 支持结构体纹理像素(RGBA float)
同步原子操作、workgroup barrier无计算同步原语
内存访问Storage Buffer 直接读写纹理采样,受限于纹理尺寸
性能比 WebGL 快 3-10 倍(矩阵运算)受纹理映射开销限制
API 复杂度较高(类似 Vulkan)中等

WebGPU Compute Shader 的核心优势:

  1. 原生支持 GPGPU:不需要把矩阵乘法伪装为纹理操作
  2. workgroup 共享内存:线程组内共享数据,减少全局内存访问
  3. 灵活的数据布局:Storage Buffer 支持任意结构体
  4. 原子操作:支持线程安全的累加/比较等操作

Q3: 如何将 PyTorch/TensorFlow 模型转换并优化用于浏览器推理?

答案

转换流程

PyTorch 模型 (.pth)
→ torch.onnx.export() → model.onnx
→ 量化工具 → model_int8.onnx
→ 浏览器 ONNX Runtime Web 加载

TensorFlow 模型 (.pb / SavedModel)
→ tf2onnx → model.onnx → 同上

→ tensorflowjs_converter → TFJS 格式 → TensorFlow.js 加载

优化步骤

  1. 模型转换:PyTorch 用 torch.onnx.export(),TensorFlow 用 tf2onnx
  2. 图优化:ONNX Runtime 的 GraphOptimizationLevel.ORT_ENABLE_ALL 自动进行算子融合、常量折叠
  3. 量化:动态量化(INT8)最简单,quantize_dynamic() 一行代码
  4. 分片:大模型拆分为多个分片(每片 <50MB),支持并行下载
  5. 测试:量化后在目标硬件上验证精度损失是否可接受

关键配置

// opset_version 选择 17+(支持更多操作)
// dynamic_axes 设为动态 batch(支持不同输入大小)
// 量化推荐 INT8 动态量化(最简单,精度损失最小)

Q4: 如何实现渐进式模型加载并展示进度?

答案

渐进式模型加载包含四个阶段:

  1. 检查缓存:先查 Cache API 或 IndexedDB 是否有缓存
  2. 下载模型:使用 fetch + ReadableStream 实现带进度的下载
  3. 缓存模型:下载后存入 Cache API(小模型)或 IndexedDB(大模型)
  4. 模型预热:首次推理较慢,后台运行空推理消除 JIT 编译延迟

核心实现:

// 带进度的下载
const response = await fetch(modelUrl);
const total = Number(response.headers.get('content-length'));
const reader = response.body!.getReader();
let loaded = 0;

while (true) {
const { done, value } = await reader.read();
if (done) break;
loaded += value.length;
onProgress(loaded / total); // 更新进度
}

缓存选择

  • Cache API:适合小模型(<200MB),API 简单,自动管理
  • IndexedDB:适合大模型(>200MB),无大小限制,手动管理
  • Transformers.js 默认使用 Cache API,WebLLM 使用 IndexedDB

Q5: 对比 TensorFlow.js、ONNX Runtime Web 和 Transformers.js

答案

维度TensorFlow.jsONNX Runtime WebTransformers.js
开发者GoogleMicrosoftHugging Face
模型格式TF/TFJSONNXONNX(自动转换)
加速后端WebGPU/WebGL/WASMWebGPU/WebNN/WASMWebGPU/WASM
模型训练支持不支持不支持
模型生态TF HubONNX Model ZooHugging Face Hub(最大)
API 易用性中等低(需手动 Tensor)高(pipeline API)
模型加载需转为 TFJS 格式需转为 ONNX 格式自动下载转换
包大小~500KB~200KB~300KB(不含模型)
WebNN 支持不支持支持不支持
社区活跃度极高

选择建议

  • 需要浏览器端训练 → TensorFlow.js
  • 已有 PyTorch/TF 模型需部署 → ONNX Runtime Web
  • 快速使用 NLP/CV 模型 → Transformers.js
  • 需要 WebNN/NPU 加速 → ONNX Runtime Web

Q6: 如何在浏览器中实现实时目标检测?

答案

实时目标检测需要解决三个核心问题:模型选择帧处理循环性能优化

// 使用 Transformers.js 的目标检测 pipeline
import { pipeline } from '@huggingface/transformers';

// 1. 加载模型(一次性)
const detector = await pipeline(
'object-detection',
'Xenova/detr-resnet-50'
);

// 2. 帧处理循环
function startDetection(video: HTMLVideoElement, canvas: HTMLCanvasElement) {
const ctx = canvas.getContext('2d')!;

async function processFrame() {
// 将视频帧绘制到 canvas
ctx.drawImage(video, 0, 0);

// 推理
const results = await detector(canvas.toDataURL(), {
threshold: 0.7, // 置信度阈值
});

// 绘制检测框
for (const { label, score, box } of results) {
ctx.strokeStyle = '#00FF00';
ctx.lineWidth = 2;
ctx.strokeRect(box.xmin, box.ymin, box.xmax - box.xmin, box.ymax - box.ymin);
ctx.fillText(`${label} ${(score * 100).toFixed(0)}%`, box.xmin, box.ymin - 5);
}

requestAnimationFrame(processFrame);
}

processFrame();
}

性能优化要点

  1. Worker 隔离:推理放在 Web Worker,避免阻塞 UI
  2. 帧率控制:不需要每帧都推理,可以每 2-3 帧推理一次
  3. 模型选择:DETR (~160MB) 精度高但慢,MobileNet-SSD (~10MB) 速度快但精度低
  4. 输入分辨率:降低输入图片分辨率可大幅提升速度
  5. WebGPU 加速:确保使用 WebGPU 后端而非 WASM

或使用 MediaPipe 获得更好的实时性能(专为实时优化,FPS 更高)。

Q7: 什么是混合推理?什么场景下应该使用?

答案

混合推理是将端侧推理和云端推理结合的架构模式,根据任务复杂度、置信度、网络状态等因素动态选择推理位置。

三种路由策略

策略原理示例
置信度路由本地推理置信度低于阈值时转云端情感分析置信度 <85% 转 GPT-4o
任务复杂度路由简单任务本地、复杂任务云端分类/嵌入本地,生成/摘要云端
预处理+精处理本地做数据预处理后发送云端本地生成嵌入向量 → 云端 RAG

适用场景

  1. RAG 应用:本地嵌入 + 云端 LLM
  2. 图片审核:本地 NSFW 检测 + 可疑内容发云端人工审核
  3. 离线优先:有网络用云端,离线用本地小模型
  4. 成本优化:高频简单任务本地处理,减少 API 调用

Q8: 如何管理浏览器中的 AI 模型缓存(Cache API / IndexedDB)?

答案

方案适用场景大小限制API 复杂度
Cache API小模型(<200MB)浏览器配额(通常 >1GB)简单
IndexedDB大模型(>200MB)更大配额中等
Origin Private File System超大模型磁盘空间较新 API

关键实践:

// Cache API 缓存
const cache = await caches.open('ai-models');
const cached = await cache.match(modelUrl);
if (cached) return cached.arrayBuffer();

// 下载并缓存
const response = await fetch(modelUrl);
await cache.put(modelUrl, response.clone());
return response.arrayBuffer();

注意事项

  1. 存储配额:使用 navigator.storage.estimate() 检查剩余空间
  2. 版本管理:模型 URL 带版本号(如 model_v2_int8.onnx),更新时清理旧版本
  3. 用户提示:大模型下载前提示用户("即将下载 1.8GB 模型")
  4. 清理策略:定期清理不再使用的模型缓存
  5. 持久化存储:调用 navigator.storage.persist() 防止浏览器自动清理

Q9: WebNN 和 WebGPU 有什么区别?各自在 AI 推理中的角色是什么?

答案

维度WebGPUWebNN
设计目标通用 GPU 编程(图形+计算)专用于神经网络推理
硬件访问GPUNPU、GPU、CPU(统一接口)
API 粒度低级(写 Shader、管理 Buffer)高级(定义计算图、执行推理)
优化方式手动(Shader 优化、内存管理)自动(OS 层面优化,利用专用硬件)
功耗较高(GPU 满载)较低(NPU 能效比高)
浏览器支持Chrome 113+ 正式Chrome Origin Trial
主要受益所有需要 GPU 计算的任务神经网络推理,尤其移动端

关系:WebNN 和 WebGPU 不是竞争关系而是互补。WebNN 通过操作系统的原生 ML 框架(DirectML/CoreML/NNAPI)利用 NPU 硬件,功耗更低;WebGPU 提供更灵活的 GPU 计算能力。ONNX Runtime Web 同时支持两者作为执行提供者。

未来趋势:移动设备上 NPU 普及(高通骁龙、Apple Neural Engine),WebNN 的价值将更加突出。

Q10: 端侧 AI 推理会阻塞主线程吗?如何优化?

答案

会阻塞。AI 模型推理涉及大量计算,如果在主线程运行会导致页面卡顿和无响应。

解决方案(按优先级排序):

  1. Web Worker(最重要):将模型加载和推理放在 Worker 线程

    • 使用 Comlink 简化 Worker 通信(像调用本地函数一样)
    • 大数据传输使用 Transferable ObjectsSharedArrayBuffer 减少拷贝
  2. WebGPU 加速:GPU 计算本身在 GPU 上执行,不阻塞 CPU 主线程

    • 但 JS 层的数据准备和结果读取仍在主线程
    • 仍建议配合 Worker 使用
  3. 模型量化:INT8/INT4 量化减少计算量,推理更快

  4. 帧率控制:实时检测不需要每帧推理,可每 2-3 帧推理一次

更多 Worker 使用细节参考 Web Workers。 更多性能优化策略参考 AI 应用性能优化

Q11: Transformers.js 支持哪些 AI 任务?模型如何缓存?

答案

支持的任务(使用 pipeline API):

任务类别pipeline 名称典型模型模型大小
文本分类text-classificationdistilbert-sst-2~67MB
命名实体识别token-classificationbert-base-NER~110MB
问答question-answeringdistilbert-squad~67MB
文本摘要summarizationdistilbart-cnn~300MB
翻译translationopus-mt-xx-xx~150MB
文本生成text2text-generationflan-t5-small~77MB
零样本分类zero-shot-classificationnli-deberta~50MB
特征提取feature-extractionall-MiniLM-L6-v2~23MB
图片分类image-classificationvit-base~86MB
目标检测object-detectiondetr-resnet-50~160MB
图像分割image-segmentationdetr-panoptic~160MB

缓存机制

// Transformers.js 默认使用 Cache API
// 首次下载模型后自动缓存,下次加载直接从缓存读取

// 查看缓存大小
const cache = await caches.open('transformers-cache');
const keys = await cache.keys();
console.log(`已缓存 ${keys.length} 个文件`);

// 预加载模型(不立即使用)
await pipeline('text-classification', 'Xenova/distilbert-base-uncased-finetuned-sst-2-english');
// 之后调用时直接从缓存加载,秒级就绪

Q12: 什么场景适合端侧 AI 而不是云端?

答案

场景原因推荐方案
隐私敏感数据医疗影像、金融数据不应离开设备ONNX Runtime Web + WebGPU
实时交互AR 滤镜、手势识别需要 <50ms 延迟MediaPipe
离线场景PWA、弱网环境、飞行模式Transformers.js + Cache API
成本控制高频低复杂度任务避免 API 费用本地小模型
预处理客户端做初步分类/嵌入再发云端混合推理架构
合规要求GDPR 等法规限制数据出境端侧推理

不适合端侧的场景

  • 需要大模型能力(>10B 参数):当前浏览器只能运行小模型
  • 长文本生成/复杂推理:小模型质量不足
  • 首次体验要求快:模型下载需要时间

Q13: 如何在前端实现"本地知识库"功能(不依赖后端向量库)?

答案

前端本地知识库的核心架构是 IndexedDB 存储 + 浏览器端 Embedding + 余弦相似度搜索,完全在浏览器中运行,不需要后端向量数据库。

技术方案

lib/local-knowledge-base.ts
import { pipeline, type FeatureExtractionPipeline } from '@huggingface/transformers';

interface Document {
id: string;
content: string;
metadata?: Record<string, string>;
embedding?: Float32Array; // 向量存储在 IndexedDB
}

interface SearchResult {
document: Document;
score: number; // 余弦相似度 0-1
}

class LocalKnowledgeBase {
private embedder: FeatureExtractionPipeline | null = null;
private db: IDBDatabase | null = null;
private readonly DB_NAME = 'knowledge-base';
private readonly STORE_NAME = 'documents';

// 1. 初始化嵌入模型(~23MB,首次下载后缓存)
async init(): Promise<void> {
// all-MiniLM-L6-v2:384 维向量,轻量且效果好
this.embedder = await pipeline(
'feature-extraction',
'Xenova/all-MiniLM-L6-v2',
{ device: 'webgpu' } // 优先 WebGPU 加速
) as FeatureExtractionPipeline;

// 打开 IndexedDB
this.db = await new Promise((resolve, reject) => {
const request = indexedDB.open(this.DB_NAME, 1);
request.onupgradeneeded = () => {
const db = request.result;
if (!db.objectStoreNames.contains(this.STORE_NAME)) {
db.createObjectStore(this.STORE_NAME, { keyPath: 'id' });
}
};
request.onsuccess = () => resolve(request.result);
request.onerror = () => reject(request.error);
});
}

// 2. 添加文档:生成 Embedding 并存储
async addDocument(content: string, metadata?: Record<string, string>): Promise<string> {
if (!this.embedder || !this.db) throw new Error('未初始化');

const id = crypto.randomUUID();
// 生成文本的向量表示
const output = await this.embedder(content, { pooling: 'mean', normalize: true });
const embedding = new Float32Array(output.data as ArrayBuffer);

const doc: Document = { id, content, metadata, embedding };

// 存入 IndexedDB
await new Promise<void>((resolve, reject) => {
const tx = this.db!.transaction(this.STORE_NAME, 'readwrite');
tx.objectStore(this.STORE_NAME).put(doc);
tx.oncomplete = () => resolve();
tx.onerror = () => reject(tx.error);
});

return id;
}

// 3. 语义搜索:计算查询向量与所有文档的余弦相似度
async search(query: string, topK = 5): Promise<SearchResult[]> {
if (!this.embedder || !this.db) throw new Error('未初始化');

// 生成查询向量
const output = await this.embedder(query, { pooling: 'mean', normalize: true });
const queryVec = new Float32Array(output.data as ArrayBuffer);

// 从 IndexedDB 读取所有文档
const docs = await new Promise<Document[]>((resolve, reject) => {
const tx = this.db!.transaction(this.STORE_NAME, 'readonly');
const request = tx.objectStore(this.STORE_NAME).getAll();
request.onsuccess = () => resolve(request.result);
request.onerror = () => reject(request.error);
});

// 计算余弦相似度并排序
const results: SearchResult[] = docs
.filter(doc => doc.embedding)
.map(doc => ({
document: { ...doc, embedding: undefined },
score: cosineSimilarity(queryVec, doc.embedding!),
}))
.sort((a, b) => b.score - a.score)
.slice(0, topK);

return results;
}
}

// 余弦相似度(向量已归一化时等于点积)
function cosineSimilarity(a: Float32Array, b: Float32Array): number {
let dot = 0;
for (let i = 0; i < a.length; i++) dot += a[i] * b[i];
return dot; // 归一化后 dot product = cosine similarity
}

性能优化要点

优化点方案效果
Embedding 计算放在 Web Worker 中不阻塞 UI
大量文档搜索使用 HNSW 索引(如 hnswlib-wasmO(log n) 搜索,替代 O(n) 暴力搜索
文档分块长文档按段落/句子拆分(~200-500 字/块)提高检索精度
WebGPU 加速{ device: 'webgpu' }Embedding 速度提升 3-10 倍

局限性

  • 向量维度 384,文档量 >10 万时 IndexedDB 搜索变慢(建议引入 HNSW)
  • 浏览器端只能用小型 Embedding 模型(~23MB),语义理解不如 OpenAI text-embedding-3
  • 适合个人笔记、本地文档检索等场景;企业级 RAG 仍需后端向量数据库

更多 RAG 架构设计参考 RAG 检索增强生成。 更多向量搜索原理参考 向量搜索与 Embedding

相关链接