跳到主要内容

Web AI 与端侧推理

问题

前端如何在浏览器中直接运行 AI 模型?WebGPU、WebNN、TensorFlow.js、ONNX Runtime Web、Transformers.js、WebLLM、MediaPipe 等技术的原理、适用场景和最佳实践是什么?如何优化模型加载、推理性能和用户体验?

答案

Web AI(端侧推理)是在浏览器中直接运行 AI 模型,不需要将数据发送到服务器。这带来了隐私保护、离线可用、低延迟等核心优势。随着 WebGPU、WebNN 等底层加速 API 的成熟,以及 Transformers.js、ONNX Runtime Web 等高层框架的发展,浏览器的 AI 能力正在快速提升。

一、Web AI 技术栈全景

技术说明加速后端状态
WebGPU现代 GPU 加速 API,原生支持 Compute ShaderGPUChrome 113+、Edge 113+、Firefox Nightly
WebNN直接访问 NPU/AI 加速器的标准 APINPU/GPU/CPUChrome 开发中(Origin Trial)
WebAssembly接近原生的 CPU 执行性能CPU所有现代浏览器
TensorFlow.jsGoogle 的 Web ML 框架,支持训练和推理WebGPU/WebGL/WASM成熟,生态丰富
ONNX Runtime Web微软跨平台推理引擎,支持多种模型格式WebGPU/WebNN/WASM生产就绪
Transformers.js直接运行 Hugging Face 上的模型WebGPU/WASMv3 活跃发展中
WebLLM浏览器运行 LLM(Llama、Mistral、Phi)WebGPU实验性,快速迭代
MediaPipeGoogle 实时视觉/手势/人脸检测方案WebGPU/WASM/GPU delegate生产就绪

二、WebGPU 深入 — Compute Shader 与 AI 加速

WebGPU 是 WebGL 的继任者,基于 Vulkan/Metal/D3D12 设计。与 WebGL 最大的区别在于 WebGPU 原生支持 Compute Shader(计算着色器),可以直接进行通用计算(GPGPU),而不需要把计算映射为图形渲染操作。

Compute Shader 基础 — 矩阵乘法示例

矩阵乘法是神经网络中最核心的运算。以下展示如何用 WebGPU Compute Shader 实现矩阵乘法:

lib/webgpu-matmul.ts
// WebGPU 矩阵乘法:C = A × B
async function gpuMatMul(
a: Float32Array,
b: Float32Array,
M: number, // A 的行数
K: number, // A 的列数 = B 的行数
N: number // B 的列数
): Promise<Float32Array> {
// 1. 获取 GPU 设备
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) throw new Error('WebGPU 不可用');
const device = await adapter.requestDevice();

// 2. 创建 GPU Buffer
const bufferA = device.createBuffer({
size: a.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
const bufferB = device.createBuffer({
size: b.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
const bufferC = device.createBuffer({
size: M * N * 4,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});
const bufferDims = device.createBuffer({
size: 3 * 4,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});

// 写入数据
device.queue.writeBuffer(bufferA, 0, a);
device.queue.writeBuffer(bufferB, 0, b);
device.queue.writeBuffer(bufferDims, 0, new Uint32Array([M, K, N]));

// 3. 编写 WGSL Compute Shader
const shaderModule = device.createShaderModule({
code: `
struct Dims {
M: u32,
K: u32,
N: u32,
}
@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@group(0) @binding(3) var<uniform> dims: Dims;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let row = id.x;
let col = id.y;
if (row >= dims.M || col >= dims.N) { return; }

var sum: f32 = 0.0;
for (var k: u32 = 0u; k < dims.K; k = k + 1u) {
sum = sum + A[row * dims.K + k] * B[k * dims.N + col];
}
C[row * dims.N + col] = sum;
}
`,
});

// 4. 创建计算管线
const pipeline = device.createComputePipeline({
layout: 'auto',
compute: { module: shaderModule, entryPoint: 'main' },
});

const bindGroup = device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: bufferA } },
{ binding: 1, resource: { buffer: bufferB } },
{ binding: 2, resource: { buffer: bufferC } },
{ binding: 3, resource: { buffer: bufferDims } },
],
});

// 5. 提交计算命令
const commandEncoder = device.createCommandEncoder();
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(pipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(Math.ceil(M / 16), Math.ceil(N / 16));
passEncoder.end();

// 读回结果
const readBuffer = device.createBuffer({
size: M * N * 4,
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
});
commandEncoder.copyBufferToBuffer(bufferC, 0, readBuffer, 0, M * N * 4);
device.queue.submit([commandEncoder.finish()]);

await readBuffer.mapAsync(GPUMapMode.READ);
const result = new Float32Array(readBuffer.getMappedRange().slice(0));
readBuffer.unmap();

return result;
}

WebGPU vs WebGL vs WASM 性能基准

基准测试(矩阵 1024x1024)WebGPUWebGLWASM (SIMD)CPU (JS)
矩阵乘法~8ms~30ms~120ms~800ms
ResNet-50 推理~15ms~60ms~200ms~1500ms
BERT-base 推理~25ms~100ms~350msN/A
加速比 (vs CPU)~50-100x~10-25x~5-8x1x
说明

以上数据来自 Chrome 团队和 ONNX Runtime 的公开基准测试,实际性能因硬件和模型差异会有变化。WebGPU 在大矩阵运算和 Transformer 类模型上优势尤其显著。

WebGPU 浏览器支持情况

浏览器支持状态备注
Chrome / Edge113+ 正式支持桌面端完整支持,Android Chrome 121+
FirefoxNightly 实验性需手动启用 dom.webgpu.enabled
Safari预览版实验性WebKit 正在实现中
iOS Safari暂不支持受限于 WebKit 进度
注意

WebGPU 目前在移动端支持仍然有限。生产环境需做好降级方案,检测 WebGPU 不可用时回退到 WASM 后端。

lib/webgpu-check.ts
// 运行时检测 WebGPU 支持
async function checkWebGPUSupport(): Promise<{
supported: boolean;
adapterInfo?: GPUAdapterInfo;
limits?: GPUSupportedLimits;
}> {
if (!('gpu' in navigator)) {
return { supported: false };
}

const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
return { supported: false };
}

const info = await adapter.requestAdapterInfo();
return {
supported: true,
adapterInfo: info,
limits: adapter.limits,
};
}

// 根据支持情况选择后端
async function selectBackend(): Promise<'webgpu' | 'wasm' | 'webgl'> {
const gpuSupport = await checkWebGPUSupport();
if (gpuSupport.supported) return 'webgpu';

// 检查 WebGL 2
const canvas = document.createElement('canvas');
const gl = canvas.getContext('webgl2');
if (gl) return 'webgl';

return 'wasm';
}

三、WebNN — 原生 AI 加速 API

WebNN(Web Neural Network API)是 W3C 正在制定的标准,允许 Web 应用直接访问设备上的 NPU(神经处理单元)、GPU 和 CPU,实现高效的神经网络推理。

WebNN 架构

WebNN 核心概念

lib/webnn-example.ts
// WebNN API 基础用法
async function webnnInference() {
// 1. 获取 ML Context(指定设备偏好)
const context = await navigator.ml.createContext({
deviceType: 'npu', // 'cpu' | 'gpu' | 'npu'
powerPreference: 'low-power', // 'default' | 'high-performance' | 'low-power'
});

// 2. 创建 GraphBuilder
const builder = new MLGraphBuilder(context);

// 3. 定义计算图(以简单的全连接层为例)
const inputDesc: MLOperandDescriptor = {
dataType: 'float32',
shape: [1, 784], // 28x28 图片展平
};
const input = builder.input('input', inputDesc);

// 权重和偏置(实际场景从模型文件加载)
const weightsData = new Float32Array(784 * 128); // 假设已填充
const weights = builder.constant(
{ dataType: 'float32', shape: [784, 128] },
weightsData
);
const biasData = new Float32Array(128);
const bias = builder.constant(
{ dataType: 'float32', shape: [128] },
biasData
);

// 4. 构建操作
const matmul = builder.matmul(input, weights);
const add = builder.add(matmul, bias);
const output = builder.relu(add);

// 5. 编译计算图
const graph = await builder.build({ output });

// 6. 执行推理
const inputBuffer = new Float32Array(784);
const outputBuffer = new Float32Array(128);

const results = await context.compute(graph, {
input: inputBuffer,
}, {
output: outputBuffer,
});

return results.output;
}

WebNN 支持的操作

操作类别包含操作
元素级add、sub、mul、div、pow、abs、ceil、floor、exp、log、sigmoid、relu、tanh、softmax
矩阵matmul、gemm
卷积conv2d、convTranspose2d
池化averagePool2d、maxPool2d、l2Pool2d
归一化batchNormalization、layerNormalization、instanceNormalization
变形reshape、transpose、concat、split、slice、pad、expand
注意力暂无原生 attention 操作,通过 matmul + softmax 组合
Chrome WebNN 实现状态

Chrome 正在通过 Origin Trial 方式逐步开放 WebNN API。在 Windows 上通过 DirectML 后端支持 NPU/GPU 加速,macOS 上通过 Core ML 后端支持 Apple Neural Engine。开发者可在 chrome://flags/#enable-web-machine-learning-neural-network-api 手动启用。

四、ONNX Runtime Web 深入

ONNX Runtime Web 是微软开发的跨平台推理引擎,支持 ONNX 格式模型。它是目前浏览器端最成熟的通用推理方案之一。

模型转换:从 PyTorch/TensorFlow 到 ONNX

scripts/convert-model.ts
// Node.js 脚本:PyTorch 模型导出为 ONNX(通常在 Python 中完成)
// 以下是 Python 伪代码的 TypeScript 注释说明

/*
# PyTorch → ONNX
import torch
import torch.onnx

model = MyModel()
model.load_state_dict(torch.load('model.pth'))
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
'model.onnx',
opset_version=17, # 推荐 opset 17+
input_names=['input'],
output_names=['output'],
dynamic_axes={ # 支持动态 batch
'input': {0: 'batch'},
'output': {0: 'batch'},
},
)

# TensorFlow → ONNX(使用 tf2onnx)
# python -m tf2onnx.convert --saved-model ./saved_model --output model.onnx --opset 17
*/

// 在浏览器中使用 ONNX Runtime Web
import * as ort from 'onnxruntime-web';

// 配置执行提供者
ort.env.wasm.wasmPaths = '/wasm/';

async function createSession(modelPath: string): Promise<ort.InferenceSession> {
const options: ort.InferenceSession.SessionOptions = {
executionProviders: [
'webgpu', // 首选 WebGPU
'wasm', // 降级到 WASM
],
graphOptimizationLevel: 'all',
enableCpuMemArena: true,
};

return ort.InferenceSession.create(modelPath, options);
}

// 完整的推理流程
async function runInference(
session: ort.InferenceSession,
inputData: Float32Array,
inputShape: number[]
): Promise<Float32Array> {
// 创建输入 Tensor
const inputTensor = new ort.Tensor('float32', inputData, inputShape);

// 获取输入输出名称
const inputName = session.inputNames[0];
const outputName = session.outputNames[0];

// 执行推理
const results = await session.run({ [inputName]: inputTensor });

return results[outputName].data as Float32Array;
}

ONNX 模型量化

量化是将模型权重从 FP32 转换为更低精度(INT8/INT4/FP16)的过程,可以显著减小模型体积和提升推理速度

lib/onnx-quantization.ts
/*
模型量化通常在 Python 中完成:

# INT8 动态量化(推荐,最简单)
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
model_input='model.onnx',
model_output='model_int8.onnx',
weight_type=QuantType.QInt8,
)

# INT4 量化(更小模型,适合 LLM 权重)
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
model_input='model.onnx',
model_output='model_int4.onnx',
weight_type=QuantType.QInt4,
)

# FP16 量化(保留更多精度)
from onnxruntime.transformers import float16
import onnx

model = onnx.load('model.onnx')
model_fp16 = float16.convert_float_to_float16(model)
onnx.save(model_fp16, 'model_fp16.onnx')
*/

// 量化效果对比
interface QuantizationComparison {
type: string;
sizeReduction: string;
speedup: string;
accuracyLoss: string;
useCase: string;
}

const QUANTIZATION_COMPARISON: QuantizationComparison[] = [
{
type: 'FP32(原始)',
sizeReduction: '基准',
speedup: '基准',
accuracyLoss: '无',
useCase: '训练、高精度需求',
},
{
type: 'FP16',
sizeReduction: '~50%',
speedup: '~1.5-2x',
accuracyLoss: '极小(<0.1%)',
useCase: 'GPU 推理、通用场景',
},
{
type: 'INT8',
sizeReduction: '~75%',
speedup: '~2-4x',
accuracyLoss: '小(<1%)',
useCase: '浏览器端推理(推荐)',
},
{
type: 'INT4',
sizeReduction: '~87.5%',
speedup: '~3-6x',
accuracyLoss: '中(1-3%)',
useCase: 'LLM 权重量化、极小模型',
},
];
量化类型体积缩减速度提升精度损失适用场景
FP32(原始)基准基准训练、高精度需求
FP16~50%~1.5-2x极小(<0.1%)GPU 推理、通用场景
INT8~75%~2-4x小(<1%)浏览器端推理(推荐)
INT4~87.5%~3-6x中(1-3%)LLM 权重量化、极小模型

ONNX Runtime Web 执行提供者(Execution Providers)

lib/onnx-providers.ts
import * as ort from 'onnxruntime-web';

// 不同执行提供者的特点和选择策略
interface ExecutionProvider {
name: string;
backend: string;
pros: string[];
cons: string[];
}

const PROVIDERS: ExecutionProvider[] = [
{
name: 'webgpu',
backend: 'GPU (WebGPU API)',
pros: ['最快的 GPU 加速', '支持 Compute Shader', '内存带宽大'],
cons: ['需要 Chrome 113+', '移动端支持有限'],
},
{
name: 'webnn',
backend: 'NPU/GPU/CPU (WebNN API)',
pros: ['可利用 NPU 加速', '低功耗', '未来标准方向'],
cons: ['浏览器支持极早期', '操作覆盖不完整'],
},
{
name: 'wasm',
backend: 'CPU (WebAssembly)',
pros: ['兼容性最好', '所有浏览器支持', 'SIMD 加速'],
cons: ['性能不如 GPU', '大模型推理较慢'],
},
{
name: 'webgl',
backend: 'GPU (WebGL 2)',
pros: ['广泛的 GPU 支持', '兼容旧浏览器'],
cons: ['计算需映射为纹理操作', '比 WebGPU 慢 3-5x'],
},
];

// 智能选择执行提供者
async function getOptimalProviders(): Promise<string[]> {
const providers: string[] = [];

// 优先 WebGPU
if ('gpu' in navigator) {
const adapter = await navigator.gpu.requestAdapter();
if (adapter) providers.push('webgpu');
}

// 其次 WebNN
if ('ml' in navigator) {
providers.push('webnn');
}

// 兜底 WASM(始终可用)
providers.push('wasm');

return providers;
}

五、Transformers.js 深入

Transformers.js 是 Hugging Face 官方的 JavaScript 版本,让你可以直接在浏览器中运行 Hugging Face 上的数千个模型。v3 版本支持 WebGPU 加速。

Pipeline API 全任务支持

lib/transformers-pipelines.ts
import { pipeline, env, Pipeline } from '@huggingface/transformers';

// 配置(v3 使用 @huggingface/transformers 包名)
env.allowLocalModels = false;

// ===== 自然语言处理任务 =====

// 1. 文本分类(情感分析)
async function textClassification(text: string) {
const classifier = await pipeline(
'text-classification',
'Xenova/distilbert-base-uncased-finetuned-sst-2-english'
);
return classifier(text);
// [{ label: 'POSITIVE', score: 0.9998 }]
}

// 2. 命名实体识别(NER)
async function tokenClassification(text: string) {
const ner = await pipeline(
'token-classification',
'Xenova/bert-base-NER'
);
return ner(text);
// [{ entity: 'B-PER', score: 0.99, word: 'John', ... }]
}

// 3. 问答
async function questionAnswering(question: string, context: string) {
const qa = await pipeline(
'question-answering',
'Xenova/distilbert-base-cased-distilled-squad'
);
return qa({ question, context });
// { answer: 'Paris', score: 0.98, start: 12, end: 17 }
}

// 4. 文本摘要
async function summarization(text: string) {
const summarizer = await pipeline(
'summarization',
'Xenova/distilbart-cnn-6-6'
);
return summarizer(text, { max_length: 100, min_length: 30 });
// [{ summary_text: '摘要内容...' }]
}

// 5. 翻译
async function translation(text: string) {
const translator = await pipeline(
'translation',
'Xenova/opus-mt-zh-en' // 中文到英文
);
return translator(text);
// [{ translation_text: 'translated text...' }]
}

// 6. 文本生成
async function textGeneration(prompt: string) {
const generator = await pipeline(
'text2text-generation',
'Xenova/flan-t5-small'
);
return generator(prompt, { max_new_tokens: 100 });
}

// 7. 零样本分类(不需要训练的分类)
async function zeroShotClassification(text: string, labels: string[]) {
const classifier = await pipeline(
'zero-shot-classification',
'Xenova/nli-deberta-v3-xsmall'
);
return classifier(text, labels);
// { labels: ['技术', '体育', '娱乐'], scores: [0.92, 0.05, 0.03] }
}

// 8. 特征提取(文本嵌入向量)
async function featureExtraction(text: string) {
const embedder = await pipeline(
'feature-extraction',
'Xenova/all-MiniLM-L6-v2'
);
const output = await embedder(text, { pooling: 'mean', normalize: true });
return Array.from(output.data); // 384 维向量
}

// ===== 计算机视觉任务 =====

// 9. 图片分类
async function imageClassification(imageUrl: string) {
const classifier = await pipeline(
'image-classification',
'Xenova/vit-base-patch16-224'
);
return classifier(imageUrl);
// [{ label: 'golden retriever', score: 0.89 }]
}

// 10. 目标检测
async function objectDetection(imageUrl: string) {
const detector = await pipeline(
'object-detection',
'Xenova/detr-resnet-50'
);
return detector(imageUrl);
// [{ label: 'cat', score: 0.98, box: { xmin, ymin, xmax, ymax } }]
}

// 11. 图像分割
async function imageSegmentation(imageUrl: string) {
const segmenter = await pipeline(
'image-segmentation',
'Xenova/detr-resnet-50-panoptic'
);
return segmenter(imageUrl);
// [{ label: 'cat', score: 0.99, mask: RawImage }]
}

模型缓存机制

Transformers.js 默认使用浏览器的 Cache API 缓存已下载的模型文件。理解缓存机制对优化加载体验至关重要:

lib/transformers-cache.ts
import { pipeline, env } from '@huggingface/transformers';

// 缓存配置
env.useBrowserCache = true; // 默认开启 Cache API 缓存
env.allowLocalModels = false; // 是否从本地加载

// 自定义缓存路径
env.cacheDir = '/models'; // 自定义缓存目录名

// 手动管理模型缓存
class ModelCacheManager {
private cacheName = 'transformers-cache';

// 查看已缓存的模型
async listCachedModels(): Promise<string[]> {
const cache = await caches.open(this.cacheName);
const keys = await cache.keys();
return keys.map(req => req.url);
}

// 获取缓存大小
async getCacheSize(): Promise<number> {
const cache = await caches.open(this.cacheName);
const keys = await cache.keys();
let totalSize = 0;

for (const key of keys) {
const response = await cache.match(key);
if (response) {
const blob = await response.blob();
totalSize += blob.size;
}
}
return totalSize;
}

// 清理指定模型缓存
async clearModelCache(modelName: string): Promise<void> {
const cache = await caches.open(this.cacheName);
const keys = await cache.keys();

for (const key of keys) {
if (key.url.includes(modelName)) {
await cache.delete(key);
}
}
}

// 预加载模型(提前下载不立即使用)
async preloadModel(
task: string,
modelId: string,
onProgress?: (progress: number) => void
): Promise<void> {
await pipeline(task as any, modelId, {
progress_callback: (data: { status: string; progress?: number }) => {
if (data.status === 'progress' && data.progress !== undefined) {
onProgress?.(data.progress);
}
},
});
}
}

六、WebLLM 深入 — 浏览器中运行 LLM

WebLLM 基于 Apache TVM 的 MLC(Machine Learning Compilation)技术,将 LLM 编译为高效的 WebGPU 代码,在浏览器中运行。

支持的模型及资源需求

模型量化模型大小VRAM 需求大约速度
Llama-3.1-8Bq4f16_1~4.5 GB~6 GB25-45 tok/s
Llama-3.2-3Bq4f16_1~1.8 GB~3 GB40-70 tok/s
Mistral-7Bq4f16_1~4.0 GB~5.5 GB25-45 tok/s
Phi-3.5-mini-3.8Bq4f16_1~2.2 GB~3.5 GB35-60 tok/s
Qwen2.5-1.5Bq4f16_1~1.0 GB~2 GB50-80 tok/s
SmolLM2-1.7Bq4f16_1~1.0 GB~2 GB50-80 tok/s
Gemma-2-2Bq4f16_1~1.5 GB~2.5 GB40-65 tok/s
性能说明

以上速度为在搭载 NVIDIA RTX 3060 (12GB) 或 Apple M1 Pro 等中高端 GPU 上的大致范围。实际速度取决于 GPU 型号、显存带宽、浏览器版本。低端 GPU 速度可能减半。

完整的 WebLLM 集成

lib/webllm-engine.ts
import * as webllm from '@mlc-ai/web-llm';

interface WebLLMConfig {
modelId: string;
onProgress?: (progress: { text: string; progress: number }) => void;
onReady?: () => void;
}

class WebLLMEngine {
private engine: webllm.MLCEngine | null = null;
private isReady = false;

async init(config: WebLLMConfig): Promise<void> {
this.engine = await webllm.CreateMLCEngine(config.modelId, {
initProgressCallback: (progress) => {
config.onProgress?.({
text: progress.text,
progress: progress.progress,
});
},
});
this.isReady = true;
config.onReady?.();
}

// OpenAI 兼容的聊天接口
async chat(
messages: Array<{ role: string; content: string }>,
options?: { temperature?: number; max_tokens?: number }
): Promise<string> {
if (!this.engine) throw new Error('引擎未初始化');

const response = await this.engine.chat.completions.create({
messages: messages as webllm.ChatCompletionMessageParam[],
temperature: options?.temperature ?? 0.7,
max_tokens: options?.max_tokens ?? 1024,
});

return response.choices[0]?.message?.content ?? '';
}

// 流式聊天
async *chatStream(
messages: Array<{ role: string; content: string }>,
options?: { temperature?: number; max_tokens?: number }
): AsyncGenerator<string> {
if (!this.engine) throw new Error('引擎未初始化');

const response = await this.engine.chat.completions.create({
messages: messages as webllm.ChatCompletionMessageParam[],
temperature: options?.temperature ?? 0.7,
max_tokens: options?.max_tokens ?? 1024,
stream: true,
});

for await (const chunk of response) {
const content = chunk.choices[0]?.delta?.content;
if (content) yield content;
}
}

// 获取性能统计
async getStats(): Promise<{ tokensPerSecond: number; totalTokens: number }> {
if (!this.engine) throw new Error('引擎未初始化');
const stats = await this.engine.runtimeStatsText();
// 解析统计文本
return { tokensPerSecond: 0, totalTokens: 0 }; // 简化示例
}

dispose(): void {
this.engine = null;
this.isReady = false;
}
}
components/LocalLLMChat.tsx
import { useState, useRef, useCallback } from 'react';

interface Message {
role: 'user' | 'assistant';
content: string;
}

export function LocalLLMChat() {
const [messages, setMessages] = useState<Message[]>([]);
const [input, setInput] = useState('');
const [loading, setLoading] = useState(false);
const [progress, setProgress] = useState<{ text: string; percent: number } | null>(null);
const engineRef = useRef<WebLLMEngine | null>(null);

const initEngine = useCallback(async () => {
const engine = new WebLLMEngine();
await engine.init({
modelId: 'Llama-3.2-3B-Instruct-q4f16_1-MLC',
onProgress: (p) => setProgress({ text: p.text, percent: p.progress * 100 }),
onReady: () => setProgress(null),
});
engineRef.current = engine;
}, []);

const sendMessage = useCallback(async () => {
if (!engineRef.current || !input.trim()) return;

const userMsg: Message = { role: 'user', content: input };
const newMessages = [...messages, userMsg];
setMessages(newMessages);
setInput('');
setLoading(true);

// 流式生成回复
let assistantContent = '';
setMessages([...newMessages, { role: 'assistant', content: '' }]);

for await (const chunk of engineRef.current.chatStream(newMessages)) {
assistantContent += chunk;
setMessages([
...newMessages,
{ role: 'assistant', content: assistantContent },
]);
}

setLoading(false);
}, [input, messages]);

return (
<div className="flex flex-col h-screen">
{/* 加载进度 */}
{progress && (
<div className="p-4 bg-blue-50">
<p className="text-sm">{progress.text}</p>
<div className="w-full bg-gray-200 rounded h-2 mt-1">
<div
className="bg-blue-500 h-2 rounded transition-all"
style={{ width: `${progress.percent}%` }}
/>
</div>
</div>
)}

{/* 消息列表 */}
<div className="flex-1 overflow-auto p-4">
{messages.map((msg, i) => (
<div key={i} className={`mb-4 ${msg.role === 'user' ? 'text-right' : ''}`}>
<div className={`inline-block p-3 rounded-lg ${
msg.role === 'user' ? 'bg-blue-500 text-white' : 'bg-gray-100'
}`}>
{msg.content}
</div>
</div>
))}
</div>

{/* 输入区 */}
<div className="p-4 border-t flex gap-2">
{!engineRef.current ? (
<button onClick={initEngine} className="btn-primary">
加载模型(~1.8GB)
</button>
) : (
<>
<input
value={input}
onChange={(e) => setInput(e.target.value)}
onKeyDown={(e) => e.key === 'Enter' && sendMessage()}
className="flex-1 border rounded px-3"
placeholder="输入消息..."
/>
<button onClick={sendMessage} disabled={loading}>
{loading ? '生成中...' : '发送'}
</button>
</>
)}
</div>
</div>
);
}
注意

WebLLM 需要 WebGPU 支持且模型下载较大(1-4GB)。首次加载需要较长时间用于下载和编译模型。适合对隐私要求极高或需离线使用的场景。建议在用户明确触发后才开始加载模型,并配合进度条提供反馈。

七、MediaPipe 实时视觉处理

MediaPipe 是 Google 开发的跨平台机器学习框架,专注于实时视觉处理任务,包括手势识别、人脸网格、姿态检测、目标检测等。其模型轻量级,推理速度极快,适合在浏览器中运行。

MediaPipe 支持的视觉任务

任务模型输出典型帧率
手势识别Hand Landmarker21 个手部关键点 + 手势分类30+ FPS
人脸网格Face Landmarker478 个面部关键点 + 表情分类30+ FPS
姿态检测Pose Landmarker33 个身体关键点30+ FPS
目标检测Object Detector边界框 + 类别 + 置信度25+ FPS
图像分类Image Classifier类别 + 置信度30+ FPS
图像分割Image Segmenter像素级分割掩码20+ FPS
文字识别Text RecognizerOCR 文字20+ FPS

手势识别 React 组件

components/HandGestureDetector.tsx
import { useEffect, useRef, useState, useCallback } from 'react';
import {
GestureRecognizer,
FilesetResolver,
GestureRecognizerResult,
} from '@mediapipe/tasks-vision';

interface DetectedGesture {
gesture: string;
confidence: number;
landmarks: Array<{ x: number; y: number; z: number }>;
}

export function HandGestureDetector() {
const videoRef = useRef<HTMLVideoElement>(null);
const canvasRef = useRef<HTMLCanvasElement>(null);
const [recognizer, setRecognizer] = useState<GestureRecognizer | null>(null);
const [gestures, setGestures] = useState<DetectedGesture[]>([]);
const [fps, setFps] = useState(0);
const frameCountRef = useRef(0);
const lastTimeRef = useRef(performance.now());

// 初始化手势识别器
useEffect(() => {
async function init() {
const vision = await FilesetResolver.forVisionTasks(
'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm'
);

const gestureRecognizer = await GestureRecognizer.createFromOptions(vision, {
baseOptions: {
modelAssetPath:
'https://storage.googleapis.com/mediapipe-models/gesture_recognizer/gesture_recognizer/float16/1/gesture_recognizer.task',
delegate: 'GPU', // 使用 GPU 加速
},
runningMode: 'VIDEO',
numHands: 2,
});

setRecognizer(gestureRecognizer);
}
init();
}, []);

// 开启摄像头
useEffect(() => {
if (!videoRef.current) return;
navigator.mediaDevices
.getUserMedia({ video: { width: 640, height: 480, facingMode: 'user' } })
.then((stream) => {
videoRef.current!.srcObject = stream;
});
}, []);

// 实时检测循环
const detect = useCallback(() => {
if (!recognizer || !videoRef.current || !canvasRef.current) return;

const video = videoRef.current;
const canvas = canvasRef.current;
const ctx = canvas.getContext('2d')!;

const processFrame = () => {
if (video.readyState < 2) {
requestAnimationFrame(processFrame);
return;
}

// 执行手势识别
const result = recognizer.recognizeForVideo(video, performance.now());
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.drawImage(video, 0, 0);

// 绘制手部关键点和连线
drawHandLandmarks(ctx, result);

// 更新识别结果
const detected: DetectedGesture[] = [];
if (result.gestures.length > 0) {
result.gestures.forEach((gestureList, handIndex) => {
const topGesture = gestureList[0];
detected.push({
gesture: topGesture.categoryName,
confidence: topGesture.score,
landmarks: result.landmarks[handIndex],
});
});
}
setGestures(detected);

// 计算 FPS
frameCountRef.current++;
const now = performance.now();
if (now - lastTimeRef.current >= 1000) {
setFps(frameCountRef.current);
frameCountRef.current = 0;
lastTimeRef.current = now;
}

requestAnimationFrame(processFrame);
};

processFrame();
}, [recognizer]);

useEffect(() => {
detect();
}, [detect]);

return (
<div className="relative">
<video ref={videoRef} autoPlay muted playsInline className="hidden" />
<canvas ref={canvasRef} width={640} height={480} />

{/* 识别结果 */}
<div className="absolute top-2 left-2 bg-black/70 text-white p-2 rounded">
<p>FPS: {fps}</p>
{gestures.map((g, i) => (
<p key={i}>
{i + 1}: {g.gesture}{(g.confidence * 100).toFixed(1)}%
</p>
))}
</div>
</div>
);
}

// 绘制手部关键点
function drawHandLandmarks(
ctx: CanvasRenderingContext2D,
result: GestureRecognizerResult
): void {
const { landmarks } = result;
if (!landmarks.length) return;

// 手部连线定义(简化)
const connections = [
[0, 1], [1, 2], [2, 3], [3, 4], // 拇指
[0, 5], [5, 6], [6, 7], [7, 8], // 食指
[0, 9], [9, 10], [10, 11], [11, 12], // 中指
[0, 13], [13, 14], [14, 15], [15, 16], // 无名指
[0, 17], [17, 18], [18, 19], [19, 20], // 小指
[5, 9], [9, 13], [13, 17], // 手掌
];

for (const handLandmarks of landmarks) {
// 画连线
ctx.strokeStyle = '#00FF00';
ctx.lineWidth = 2;
for (const [start, end] of connections) {
const p1 = handLandmarks[start];
const p2 = handLandmarks[end];
ctx.beginPath();
ctx.moveTo(p1.x * ctx.canvas.width, p1.y * ctx.canvas.height);
ctx.lineTo(p2.x * ctx.canvas.width, p2.y * ctx.canvas.height);
ctx.stroke();
}

// 画关键点
ctx.fillStyle = '#FF0000';
for (const point of handLandmarks) {
ctx.beginPath();
ctx.arc(
point.x * ctx.canvas.width,
point.y * ctx.canvas.height,
4, 0, 2 * Math.PI
);
ctx.fill();
}
}
}

人脸网格检测

lib/face-mesh.ts
import {
FaceLandmarker,
FilesetResolver,
FaceLandmarkerResult,
} from '@mediapipe/tasks-vision';

async function createFaceDetector(): Promise<FaceLandmarker> {
const vision = await FilesetResolver.forVisionTasks(
'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm'
);

return FaceLandmarker.createFromOptions(vision, {
baseOptions: {
modelAssetPath:
'https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task',
delegate: 'GPU',
},
runningMode: 'VIDEO',
numFaces: 1,
outputFaceBlendshapes: true, // 输出表情混合形状(52 个表情参数)
outputFacialTransformationMatrixes: true, // 输出面部变换矩阵
});
}

// 表情分类(基于 blendshapes)
function classifyExpression(
blendshapes: Array<{ categoryName: string; score: number }>
): string {
const shapes: Record<string, number> = {};
for (const bs of blendshapes) {
shapes[bs.categoryName] = bs.score;
}

// 基于关键 blendshapes 判断表情
if (shapes['mouthSmileLeft'] > 0.5 && shapes['mouthSmileRight'] > 0.5) {
return '微笑';
}
if (shapes['browDownLeft'] > 0.5 && shapes['browDownRight'] > 0.5) {
return '皱眉';
}
if (shapes['jawOpen'] > 0.5) {
return '张嘴';
}
if (shapes['eyeBlinkLeft'] > 0.5 && shapes['eyeBlinkRight'] > 0.5) {
return '眨眼';
}
return '正常';
}

八、模型优化策略

在浏览器中运行 AI 模型,模型体积和推理速度是关键瓶颈。以下是主要的优化策略:

1. 量化(Quantization)

量化方式说明体积变化精度影响
动态量化权重 INT8 + 激活运行时量化减少 ~75%极小
静态量化权重和激活都预先量化减少 ~75%小,需校准数据
量化感知训练 (QAT)训练过程中模拟量化减少 ~75%最小
GPTQ/AWQLLM 专用权重量化减少 ~75-87.5%中等

2. 模型剪枝(Pruning)

lib/model-optimization.ts
// 模型剪枝的概念说明(剪枝通常在 Python 训练端完成)
/*
模型剪枝分为:

1. 非结构化剪枝(Unstructured Pruning)
- 将小于阈值的权重置零
- 需要稀疏矩阵计算支持
- 压缩率高但硬件加速困难

2. 结构化剪枝(Structured Pruning)
- 移除整个神经元/通道/注意力头
- 直接减少计算量
- 对硬件友好
*/

// 浏览器端可以加载已剪枝的模型
// 剪枝后的模型体积更小、推理更快
interface ModelOptimizationPlan {
strategy: string;
originalSize: string;
optimizedSize: string;
speedup: string;
accuracyLoss: string;
}

const OPTIMIZATION_STRATEGIES: ModelOptimizationPlan[] = [
{
strategy: '仅量化(INT8)',
originalSize: '100 MB',
optimizedSize: '25 MB',
speedup: '2-4x',
accuracyLoss: '<1%',
},
{
strategy: '量化(INT4)+ 剪枝(50%)',
originalSize: '100 MB',
optimizedSize: '6 MB',
speedup: '4-8x',
accuracyLoss: '2-5%',
},
{
strategy: '知识蒸馏(小模型)',
originalSize: '400 MB (teacher)',
optimizedSize: '25 MB (student)',
speedup: '10-20x',
accuracyLoss: '3-8%',
},
{
strategy: '全套优化(蒸馏+量化+剪枝)',
originalSize: '400 MB',
optimizedSize: '3 MB',
speedup: '50-100x',
accuracyLoss: '5-15%',
},
];

3. 知识蒸馏(Knowledge Distillation)

模型大小 vs 精度权衡建议
  • 文本分类/情感分析:DistilBERT(66MB)足够,无需 BERT-base
  • 目标检测:MobileNet-SSD(~10MB)> YOLO-v8n(~12MB)> DETR(~160MB)
  • 文本嵌入:all-MiniLM-L6-v2(~23MB)是体积和质量的最佳平衡
  • LLM:Qwen2.5-1.5B-q4(~1GB)是当前浏览器中运行 LLM 的实用下限

九、渐进式模型加载

模型文件通常较大(10MB - 4GB),需要精心设计加载策略以提升用户体验。

lib/progressive-model-loading.ts
// 渐进式模型加载器 — 带进度、缓存、预热
class ProgressiveModelLoader {
private cacheStoreName = 'ai-models';

// 带进度的模型下载
async downloadWithProgress(
url: string,
onProgress: (loaded: number, total: number) => void
): Promise<ArrayBuffer> {
const response = await fetch(url);
const contentLength = Number(response.headers.get('content-length')) || 0;
const reader = response.body!.getReader();
const chunks: Uint8Array[] = [];
let loaded = 0;

while (true) {
const { done, value } = await reader.read();
if (done) break;

chunks.push(value);
loaded += value.length;
onProgress(loaded, contentLength);
}

// 合并所有 chunks
const buffer = new Uint8Array(loaded);
let offset = 0;
for (const chunk of chunks) {
buffer.set(chunk, offset);
offset += chunk.length;
}
return buffer.buffer;
}

// 使用 Cache API 缓存模型
async loadWithCache(
modelUrl: string,
onProgress: (stage: string, progress: number) => void
): Promise<ArrayBuffer> {
const cache = await caches.open(this.cacheStoreName);
const cached = await cache.match(modelUrl);

if (cached) {
onProgress('从缓存加载', 1);
return cached.arrayBuffer();
}

onProgress('下载模型', 0);
const buffer = await this.downloadWithProgress(modelUrl, (loaded, total) => {
onProgress('下载模型', total > 0 ? loaded / total : 0);
});

// 存入缓存
await cache.put(modelUrl, new Response(buffer.slice(0)));
onProgress('缓存完成', 1);

return buffer;
}

// 使用 IndexedDB 缓存(适合大模型,不受 Cache API 大小限制)
async loadWithIndexedDB(
modelUrl: string,
modelId: string,
onProgress: (stage: string, progress: number) => void
): Promise<ArrayBuffer> {
// 先检查 IndexedDB
const cachedBuffer = await this.getFromIDB(modelId);
if (cachedBuffer) {
onProgress('从 IndexedDB 加载', 1);
return cachedBuffer;
}

onProgress('下载模型', 0);
const buffer = await this.downloadWithProgress(modelUrl, (loaded, total) => {
onProgress('下载模型', total > 0 ? loaded / total : 0);
});

// 存入 IndexedDB
await this.saveToIDB(modelId, buffer);
onProgress('缓存完成', 1);

return buffer;
}

private getFromIDB(key: string): Promise<ArrayBuffer | null> {
return new Promise((resolve, reject) => {
const request = indexedDB.open('ModelCache', 1);
request.onupgradeneeded = () => {
request.result.createObjectStore('models');
};
request.onsuccess = () => {
const tx = request.result.transaction('models', 'readonly');
const store = tx.objectStore('models');
const getReq = store.get(key);
getReq.onsuccess = () => resolve(getReq.result ?? null);
getReq.onerror = () => reject(getReq.error);
};
});
}

private saveToIDB(key: string, data: ArrayBuffer): Promise<void> {
return new Promise((resolve, reject) => {
const request = indexedDB.open('ModelCache', 1);
request.onupgradeneeded = () => {
request.result.createObjectStore('models');
};
request.onsuccess = () => {
const tx = request.result.transaction('models', 'readwrite');
const store = tx.objectStore('models');
store.put(data, key);
tx.oncomplete = () => resolve();
tx.onerror = () => reject(tx.error);
};
});
}
}

加载进度 UI 组件

components/ModelLoadingUI.tsx
import { useState, useCallback } from 'react';

interface LoadingState {
stage: string;
progress: number;
error?: string;
}

export function ModelLoadingUI({
onLoaded,
}: {
onLoaded: (buffer: ArrayBuffer) => void;
}) {
const [state, setState] = useState<LoadingState | null>(null);
const loader = new ProgressiveModelLoader();

const startLoading = useCallback(async (modelUrl: string) => {
try {
const buffer = await loader.loadWithCache(modelUrl, (stage, progress) => {
setState({ stage, progress });
});

// 模型预热(首次推理通常较慢,先运行一次空推理)
setState({ stage: '模型预热中...', progress: 1 });
onLoaded(buffer);
} catch (error) {
setState({
stage: '加载失败',
progress: 0,
error: (error as Error).message,
});
}
}, [onLoaded]);

if (!state) {
return (
<button onClick={() => startLoading('/models/model.onnx')}>
加载 AI 模型
</button>
);
}

return (
<div className="w-full max-w-md mx-auto p-4">
<p className="text-sm text-gray-600 mb-2">{state.stage}</p>
<div className="w-full bg-gray-200 rounded-full h-3">
<div
className="bg-blue-500 h-3 rounded-full transition-all duration-300"
style={{ width: `${state.progress * 100}%` }}
/>
</div>
<p className="text-xs text-gray-400 mt-1">
{(state.progress * 100).toFixed(1)}%
</p>
{state.error && (
<p className="text-red-500 text-sm mt-2">{state.error}</p>
)}
</div>
);
}

模型预热策略

lib/model-warmup.ts
// 模型预热:首次推理通常比后续慢 2-5 倍(JIT 编译、GPU 管线初始化等)
// 在后台预先运行一次推理可以消除这个延迟

async function warmupONNXModel(session: ort.InferenceSession): Promise<void> {
const inputName = session.inputNames[0];
const inputShape = session.inputNames.length > 0
? (session as any)._model?.graph?.input?.[0]?.type?.tensorType?.shape?.dim?.map(
(d: any) => d.dimValue || 1
) ?? [1, 3, 224, 224]
: [1, 3, 224, 224];

// 创建全零的 dummy 输入
const dummyData = new Float32Array(
inputShape.reduce((a: number, b: number) => a * b, 1)
);
const dummyTensor = new ort.Tensor('float32', dummyData, inputShape);

// 运行一次推理(丢弃结果)
await session.run({ [inputName]: dummyTensor });
}

十、混合推理架构

混合推理(Hybrid Inference)将端侧推理和云端推理结合,利用各自优势:

lib/hybrid-inference.ts
import { pipeline } from '@huggingface/transformers';

interface InferenceResult {
source: 'local' | 'cloud';
result: unknown;
latency: number;
confidence?: number;
}

class HybridInferenceEngine {
private localClassifier: Awaited<ReturnType<typeof pipeline>> | null = null;
private localEmbedder: Awaited<ReturnType<typeof pipeline>> | null = null;
private confidenceThreshold = 0.85;

async init(): Promise<void> {
// 并行加载本地模型
const [classifier, embedder] = await Promise.all([
pipeline('text-classification', 'Xenova/distilbert-base-uncased-finetuned-sst-2-english'),
pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2'),
]);
this.localClassifier = classifier;
this.localEmbedder = embedder;
}

// 策略1:置信度路由 — 本地推理置信度低则转云端
async classifyWithFallback(text: string): Promise<InferenceResult> {
const start = performance.now();

// 先尝试本地推理
if (this.localClassifier) {
const localResult = await this.localClassifier(text) as any[];
const topResult = localResult[0];

if (topResult.score >= this.confidenceThreshold) {
// 置信度足够,直接返回本地结果
return {
source: 'local',
result: topResult,
latency: performance.now() - start,
confidence: topResult.score,
};
}
}

// 本地置信度不足或模型未加载,回退到云端
const cloudResult = await this.callCloudAPI(text);
return {
source: 'cloud',
result: cloudResult,
latency: performance.now() - start,
};
}

// 策略2:本地预处理 + 云端精处理
async analyzeWithPreprocessing(text: string): Promise<InferenceResult> {
const start = performance.now();

// 本地生成文本嵌入(用于 RAG 检索)
let embedding: number[] | null = null;
if (this.localEmbedder) {
const output = await this.localEmbedder(text, {
pooling: 'mean',
normalize: true,
});
embedding = Array.from(output.data);
}

// 将嵌入向量发送给云端(用于相似度搜索/RAG)
const cloudResult = await this.callCloudAPIWithEmbedding(text, embedding);
return {
source: 'cloud',
result: cloudResult,
latency: performance.now() - start,
};
}

// 策略3:任务复杂度路由
async smartRoute(task: {
type: 'classify' | 'generate' | 'summarize' | 'embed';
input: string;
}): Promise<InferenceResult> {
const start = performance.now();

switch (task.type) {
case 'classify':
case 'embed':
// 分类和嵌入在本地完成
return this.classifyWithFallback(task.input);

case 'generate':
case 'summarize':
// 生成和摘要发送到云端
const result = await this.callCloudAPI(task.input);
return {
source: 'cloud',
result,
latency: performance.now() - start,
};

default:
throw new Error(`未知任务类型: ${task.type}`);
}
}

private async callCloudAPI(text: string): Promise<unknown> {
const response = await fetch('/api/ai/inference', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text }),
});
return response.json();
}

private async callCloudAPIWithEmbedding(
text: string,
embedding: number[] | null
): Promise<unknown> {
const response = await fetch('/api/ai/inference', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text, embedding }),
});
return response.json();
}
}
混合推理适用场景
  • 表单校验/文本分类:本地推理,低延迟,无成本
  • 搜索建议:本地嵌入 + 本地向量搜索
  • RAG 应用:本地嵌入生成 + 云端 LLM 生成回答
  • 图片预筛:本地目标检测/分类 + 云端 Vision 精分析
  • 敏感数据:本地脱敏/分类后再发送云端

十一、Web Worker 深度集成

AI 推理计算量大,必须放在 Web Worker 中避免阻塞主线程。以下是使用 SharedArrayBuffer 和 Comlink 的进阶方案。

Comlink 是 Google Chrome 团队开发的库,将 Worker 通信封装为简单的函数调用:

workers/ai-worker-comlink.ts
// Worker 端
import * as Comlink from 'comlink';
import { pipeline } from '@huggingface/transformers';

class AIWorkerAPI {
private classifier: Awaited<ReturnType<typeof pipeline>> | null = null;
private embedder: Awaited<ReturnType<typeof pipeline>> | null = null;

async initClassifier(
onProgress: (progress: number) => void
): Promise<void> {
this.classifier = await pipeline(
'text-classification',
'Xenova/distilbert-base-uncased-finetuned-sst-2-english',
{
progress_callback: (data: { status: string; progress?: number }) => {
if (data.progress !== undefined) onProgress(data.progress);
},
}
);
}

async initEmbedder(): Promise<void> {
this.embedder = await pipeline(
'feature-extraction',
'Xenova/all-MiniLM-L6-v2'
);
}

async classify(text: string): Promise<Array<{ label: string; score: number }>> {
if (!this.classifier) throw new Error('分类器未初始化');
return this.classifier(text) as any;
}

async embed(text: string): Promise<Float32Array> {
if (!this.embedder) throw new Error('嵌入模型未初始化');
const output = await this.embedder(text, { pooling: 'mean', normalize: true });
return output.data as Float32Array;
}

// 批量处理(利用 SharedArrayBuffer 减少数据拷贝)
async batchEmbed(
texts: string[],
resultBuffer: SharedArrayBuffer // 共享内存
): Promise<void> {
if (!this.embedder) throw new Error('嵌入模型未初始化');

const resultView = new Float32Array(resultBuffer);
const embeddingDim = 384; // all-MiniLM-L6-v2 的维度

for (let i = 0; i < texts.length; i++) {
const output = await this.embedder(texts[i], {
pooling: 'mean',
normalize: true,
});
const embedding = output.data as Float32Array;
resultView.set(embedding, i * embeddingDim);
}
}
}

Comlink.expose(new AIWorkerAPI());
hooks/useAIWorkerComlink.ts
import { useEffect, useRef, useState, useCallback } from 'react';
import * as Comlink from 'comlink';

interface AIWorkerAPI {
initClassifier(onProgress: (progress: number) => void): Promise<void>;
initEmbedder(): Promise<void>;
classify(text: string): Promise<Array<{ label: string; score: number }>>;
embed(text: string): Promise<Float32Array>;
batchEmbed(texts: string[], resultBuffer: SharedArrayBuffer): Promise<void>;
}

export function useAIWorker() {
const workerRef = useRef<Worker | null>(null);
const apiRef = useRef<Comlink.Remote<AIWorkerAPI> | null>(null);
const [isReady, setIsReady] = useState(false);
const [loadProgress, setLoadProgress] = useState(0);

useEffect(() => {
const worker = new Worker(
new URL('../workers/ai-worker-comlink.ts', import.meta.url),
{ type: 'module' }
);
workerRef.current = worker;
apiRef.current = Comlink.wrap<AIWorkerAPI>(worker);

// 初始化模型
apiRef.current
.initClassifier(Comlink.proxy((progress: number) => {
setLoadProgress(progress);
}))
.then(() => setIsReady(true));

return () => worker.terminate();
}, []);

const classify = useCallback(async (text: string) => {
if (!apiRef.current) throw new Error('Worker 未就绪');
return apiRef.current.classify(text); // 像调用本地函数一样简单
}, []);

const batchEmbed = useCallback(async (texts: string[]) => {
if (!apiRef.current) throw new Error('Worker 未就绪');

const embeddingDim = 384;
// 使用 SharedArrayBuffer 避免大数据拷贝
const buffer = new SharedArrayBuffer(texts.length * embeddingDim * 4);
await apiRef.current.batchEmbed(texts, buffer);
return new Float32Array(buffer);
}, []);

return { isReady, loadProgress, classify, batchEmbed };
}
SharedArrayBuffer 安全要求

使用 SharedArrayBuffer 需要页面启用 COOP/COEP 安全头:

Cross-Origin-Opener-Policy: same-origin
Cross-Origin-Embedder-Policy: require-corp

这些头可能导致部分第三方资源加载失败,需要在 CDN 资源上配置 Cross-Origin-Resource-Policy: cross-origin

十二、端侧 vs 云端对比

维度端侧推理云端推理
隐私数据不离开设备数据需发送到服务器
延迟低(无网络延迟)高(网络 + 排队)
离线支持需要网络
模型能力小模型(<10B 参数)大模型(100B+ 参数)
费用免费(使用用户算力)按 token 付费
首次加载慢(需下载模型 10MB-4GB)快(仅发送请求)
兼容性WebGPU 要求较新浏览器全平台
维护模型更新需用户重新下载服务端透明更新
适用场景实时检测、隐私数据、离线、高频低复杂度任务复杂推理、长文本生成、RAG

十三、TensorFlow.js

TensorFlow.js 是 Google 开发的最成熟的 Web ML 框架,拥有丰富的预训练模型生态:

lib/tfjs-example.ts
import * as tf from '@tensorflow/tfjs';

// 1. 加载预训练模型
async function loadModel() {
const model = await tf.loadLayersModel('/models/sentiment/model.json');
return model;
}

// 2. 图片分类
import * as mobilenet from '@tensorflow-models/mobilenet';

async function classifyImage(imageElement: HTMLImageElement) {
const model = await mobilenet.load({ version: 2, alpha: 1.0 });
const predictions = await model.classify(imageElement);

return predictions.map(p => ({
className: p.className,
probability: (p.probability * 100).toFixed(1) + '%',
}));
// [{ className: '金毛猎犬', probability: '89.2%' }]
}

// 3. 姿态检测
import * as poseDetection from '@tensorflow-models/pose-detection';

async function detectPose(video: HTMLVideoElement) {
const detector = await poseDetection.createDetector(
poseDetection.SupportedModels.MoveNet,
{ modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING }
);
const poses = await detector.estimatePoses(video);
return poses;
}

常见面试问题

Q1: 浏览器中运行 AI 模型有哪些方案?各自的特点是什么?

答案

方案开发者加速后端模型来源适用场景
TensorFlow.jsGoogleWebGPU/WebGL/WASMTF 模型、TFJS 模型训练+推理,生态最丰富
ONNX Runtime WebMicrosoftWebGPU/WebNN/WASMONNX 格式(PyTorch/TF 转换)跨框架通用推理
Transformers.jsHugging FaceWebGPU/WASMHugging Face 模型NLP/CV 任务,开箱即用
WebLLMMLC-AIWebGPULlama/Mistral/Phi 等 LLM浏览器中运行完整 LLM
MediaPipeGoogleGPU delegate/WASM预训练视觉模型实时视觉处理(手势/人脸/姿态)

加速层优先级:WebGPU(GPU)> WebNN(NPU)> WebAssembly(CPU)

选择建议:

  • 快速原型:Transformers.js(pipeline API 最简单)
  • 生产部署:ONNX Runtime Web(跨框架、执行提供者丰富)
  • 实时视觉:MediaPipe(专为实时优化)
  • 浏览器 LLM:WebLLM(唯一可行方案)

Q2: WebGPU Compute Shader 和 WebGL 在 AI 计算上有什么区别?

答案

特性WebGPU Compute ShaderWebGL (用于 AI)
设计目标通用计算(GPGPU)图形渲染
计算模型原生 Compute Shader,直接操作 buffer需要将计算映射为纹理渲染
数据类型f32, f16, i32, u32, 支持结构体纹理像素(RGBA float)
同步原子操作、workgroup barrier无计算同步原语
内存访问Storage Buffer 直接读写纹理采样,受限于纹理尺寸
性能比 WebGL 快 3-10 倍(矩阵运算)受纹理映射开销限制
API 复杂度较高(类似 Vulkan)中等

WebGPU Compute Shader 的核心优势:

  1. 原生支持 GPGPU:不需要把矩阵乘法伪装为纹理操作
  2. workgroup 共享内存:线程组内共享数据,减少全局内存访问
  3. 灵活的数据布局:Storage Buffer 支持任意结构体
  4. 原子操作:支持线程安全的累加/比较等操作

Q3: 如何将 PyTorch/TensorFlow 模型转换并优化用于浏览器推理?

答案

转换流程

PyTorch 模型 (.pth)
→ torch.onnx.export() → model.onnx
→ 量化工具 → model_int8.onnx
→ 浏览器 ONNX Runtime Web 加载

TensorFlow 模型 (.pb / SavedModel)
→ tf2onnx → model.onnx → 同上

→ tensorflowjs_converter → TFJS 格式 → TensorFlow.js 加载

优化步骤

  1. 模型转换:PyTorch 用 torch.onnx.export(),TensorFlow 用 tf2onnx
  2. 图优化:ONNX Runtime 的 GraphOptimizationLevel.ORT_ENABLE_ALL 自动进行算子融合、常量折叠
  3. 量化:动态量化(INT8)最简单,quantize_dynamic() 一行代码
  4. 分片:大模型拆分为多个分片(每片 <50MB),支持并行下载
  5. 测试:量化后在目标硬件上验证精度损失是否可接受

关键配置

// opset_version 选择 17+(支持更多操作)
// dynamic_axes 设为动态 batch(支持不同输入大小)
// 量化推荐 INT8 动态量化(最简单,精度损失最小)

Q4: 如何实现渐进式模型加载并展示进度?

答案

渐进式模型加载包含四个阶段:

  1. 检查缓存:先查 Cache API 或 IndexedDB 是否有缓存
  2. 下载模型:使用 fetch + ReadableStream 实现带进度的下载
  3. 缓存模型:下载后存入 Cache API(小模型)或 IndexedDB(大模型)
  4. 模型预热:首次推理较慢,后台运行空推理消除 JIT 编译延迟

核心实现:

// 带进度的下载
const response = await fetch(modelUrl);
const total = Number(response.headers.get('content-length'));
const reader = response.body!.getReader();
let loaded = 0;

while (true) {
const { done, value } = await reader.read();
if (done) break;
loaded += value.length;
onProgress(loaded / total); // 更新进度
}

缓存选择

  • Cache API:适合小模型(<200MB),API 简单,自动管理
  • IndexedDB:适合大模型(>200MB),无大小限制,手动管理
  • Transformers.js 默认使用 Cache API,WebLLM 使用 IndexedDB

Q5: 对比 TensorFlow.js、ONNX Runtime Web 和 Transformers.js

答案

维度TensorFlow.jsONNX Runtime WebTransformers.js
开发者GoogleMicrosoftHugging Face
模型格式TF/TFJSONNXONNX(自动转换)
加速后端WebGPU/WebGL/WASMWebGPU/WebNN/WASMWebGPU/WASM
模型训练支持不支持不支持
模型生态TF HubONNX Model ZooHugging Face Hub(最大)
API 易用性中等低(需手动 Tensor)高(pipeline API)
模型加载需转为 TFJS 格式需转为 ONNX 格式自动下载转换
包大小~500KB~200KB~300KB(不含模型)
WebNN 支持不支持支持不支持
社区活跃度极高

选择建议

  • 需要浏览器端训练 → TensorFlow.js
  • 已有 PyTorch/TF 模型需部署 → ONNX Runtime Web
  • 快速使用 NLP/CV 模型 → Transformers.js
  • 需要 WebNN/NPU 加速 → ONNX Runtime Web

Q6: 如何在浏览器中实现实时目标检测?

答案

实时目标检测需要解决三个核心问题:模型选择帧处理循环性能优化

// 使用 Transformers.js 的目标检测 pipeline
import { pipeline } from '@huggingface/transformers';

// 1. 加载模型(一次性)
const detector = await pipeline(
'object-detection',
'Xenova/detr-resnet-50'
);

// 2. 帧处理循环
function startDetection(video: HTMLVideoElement, canvas: HTMLCanvasElement) {
const ctx = canvas.getContext('2d')!;

async function processFrame() {
// 将视频帧绘制到 canvas
ctx.drawImage(video, 0, 0);

// 推理
const results = await detector(canvas.toDataURL(), {
threshold: 0.7, // 置信度阈值
});

// 绘制检测框
for (const { label, score, box } of results) {
ctx.strokeStyle = '#00FF00';
ctx.lineWidth = 2;
ctx.strokeRect(box.xmin, box.ymin, box.xmax - box.xmin, box.ymax - box.ymin);
ctx.fillText(`${label} ${(score * 100).toFixed(0)}%`, box.xmin, box.ymin - 5);
}

requestAnimationFrame(processFrame);
}

processFrame();
}

性能优化要点

  1. Worker 隔离:推理放在 Web Worker,避免阻塞 UI
  2. 帧率控制:不需要每帧都推理,可以每 2-3 帧推理一次
  3. 模型选择:DETR (~160MB) 精度高但慢,MobileNet-SSD (~10MB) 速度快但精度低
  4. 输入分辨率:降低输入图片分辨率可大幅提升速度
  5. WebGPU 加速:确保使用 WebGPU 后端而非 WASM

或使用 MediaPipe 获得更好的实时性能(专为实时优化,FPS 更高)。

Q7: 什么是混合推理?什么场景下应该使用?

答案

混合推理是将端侧推理和云端推理结合的架构模式,根据任务复杂度、置信度、网络状态等因素动态选择推理位置。

三种路由策略

策略原理示例
置信度路由本地推理置信度低于阈值时转云端情感分析置信度 <85% 转 GPT-4o
任务复杂度路由简单任务本地、复杂任务云端分类/嵌入本地,生成/摘要云端
预处理+精处理本地做数据预处理后发送云端本地生成嵌入向量 → 云端 RAG

适用场景

  1. RAG 应用:本地嵌入 + 云端 LLM
  2. 图片审核:本地 NSFW 检测 + 可疑内容发云端人工审核
  3. 离线优先:有网络用云端,离线用本地小模型
  4. 成本优化:高频简单任务本地处理,减少 API 调用

Q8: 如何管理浏览器中的 AI 模型缓存(Cache API / IndexedDB)?

答案

方案适用场景大小限制API 复杂度
Cache API小模型(<200MB)浏览器配额(通常 >1GB)简单
IndexedDB大模型(>200MB)更大配额中等
Origin Private File System超大模型磁盘空间较新 API

关键实践:

// Cache API 缓存
const cache = await caches.open('ai-models');
const cached = await cache.match(modelUrl);
if (cached) return cached.arrayBuffer();

// 下载并缓存
const response = await fetch(modelUrl);
await cache.put(modelUrl, response.clone());
return response.arrayBuffer();

注意事项

  1. 存储配额:使用 navigator.storage.estimate() 检查剩余空间
  2. 版本管理:模型 URL 带版本号(如 model_v2_int8.onnx),更新时清理旧版本
  3. 用户提示:大模型下载前提示用户("即将下载 1.8GB 模型")
  4. 清理策略:定期清理不再使用的模型缓存
  5. 持久化存储:调用 navigator.storage.persist() 防止浏览器自动清理

Q9: WebNN 和 WebGPU 有什么区别?各自在 AI 推理中的角色是什么?

答案

维度WebGPUWebNN
设计目标通用 GPU 编程(图形+计算)专用于神经网络推理
硬件访问GPUNPU、GPU、CPU(统一接口)
API 粒度低级(写 Shader、管理 Buffer)高级(定义计算图、执行推理)
优化方式手动(Shader 优化、内存管理)自动(OS 层面优化,利用专用硬件)
功耗较高(GPU 满载)较低(NPU 能效比高)
浏览器支持Chrome 113+ 正式Chrome Origin Trial
主要受益所有需要 GPU 计算的任务神经网络推理,尤其移动端

关系:WebNN 和 WebGPU 不是竞争关系而是互补。WebNN 通过操作系统的原生 ML 框架(DirectML/CoreML/NNAPI)利用 NPU 硬件,功耗更低;WebGPU 提供更灵活的 GPU 计算能力。ONNX Runtime Web 同时支持两者作为执行提供者。

未来趋势:移动设备上 NPU 普及(高通骁龙、Apple Neural Engine),WebNN 的价值将更加突出。

Q10: 端侧 AI 推理会阻塞主线程吗?如何优化?

答案

会阻塞。AI 模型推理涉及大量计算,如果在主线程运行会导致页面卡顿和无响应。

解决方案(按优先级排序):

  1. Web Worker(最重要):将模型加载和推理放在 Worker 线程

    • 使用 Comlink 简化 Worker 通信(像调用本地函数一样)
    • 大数据传输使用 Transferable ObjectsSharedArrayBuffer 减少拷贝
  2. WebGPU 加速:GPU 计算本身在 GPU 上执行,不阻塞 CPU 主线程

    • 但 JS 层的数据准备和结果读取仍在主线程
    • 仍建议配合 Worker 使用
  3. 模型量化:INT8/INT4 量化减少计算量,推理更快

  4. 帧率控制:实时检测不需要每帧推理,可每 2-3 帧推理一次

更多 Worker 使用细节参考 Web Workers。 更多性能优化策略参考 AI 应用性能优化

Q11: Transformers.js 支持哪些 AI 任务?模型如何缓存?

答案

支持的任务(使用 pipeline API):

任务类别pipeline 名称典型模型模型大小
文本分类text-classificationdistilbert-sst-2~67MB
命名实体识别token-classificationbert-base-NER~110MB
问答question-answeringdistilbert-squad~67MB
文本摘要summarizationdistilbart-cnn~300MB
翻译translationopus-mt-xx-xx~150MB
文本生成text2text-generationflan-t5-small~77MB
零样本分类zero-shot-classificationnli-deberta~50MB
特征提取feature-extractionall-MiniLM-L6-v2~23MB
图片分类image-classificationvit-base~86MB
目标检测object-detectiondetr-resnet-50~160MB
图像分割image-segmentationdetr-panoptic~160MB

缓存机制

// Transformers.js 默认使用 Cache API
// 首次下载模型后自动缓存,下次加载直接从缓存读取

// 查看缓存大小
const cache = await caches.open('transformers-cache');
const keys = await cache.keys();
console.log(`已缓存 ${keys.length} 个文件`);

// 预加载模型(不立即使用)
await pipeline('text-classification', 'Xenova/distilbert-base-uncased-finetuned-sst-2-english');
// 之后调用时直接从缓存加载,秒级就绪

Q12: 什么场景适合端侧 AI 而不是云端?

答案

场景原因推荐方案
隐私敏感数据医疗影像、金融数据不应离开设备ONNX Runtime Web + WebGPU
实时交互AR 滤镜、手势识别需要 <50ms 延迟MediaPipe
离线场景PWA、弱网环境、飞行模式Transformers.js + Cache API
成本控制高频低复杂度任务避免 API 费用本地小模型
预处理客户端做初步分类/嵌入再发云端混合推理架构
合规要求GDPR 等法规限制数据出境端侧推理

不适合端侧的场景

  • 需要大模型能力(>10B 参数):当前浏览器只能运行小模型
  • 长文本生成/复杂推理:小模型质量不足
  • 首次体验要求快:模型下载需要时间

Q13: 如何在前端实现"本地知识库"功能(不依赖后端向量库)?

答案

前端本地知识库的核心架构是 IndexedDB 存储 + 浏览器端 Embedding + 余弦相似度搜索,完全在浏览器中运行,不需要后端向量数据库。

技术方案

lib/local-knowledge-base.ts
import { pipeline, type FeatureExtractionPipeline } from '@huggingface/transformers';

interface Document {
id: string;
content: string;
metadata?: Record<string, string>;
embedding?: Float32Array; // 向量存储在 IndexedDB
}

interface SearchResult {
document: Document;
score: number; // 余弦相似度 0-1
}

class LocalKnowledgeBase {
private embedder: FeatureExtractionPipeline | null = null;
private db: IDBDatabase | null = null;
private readonly DB_NAME = 'knowledge-base';
private readonly STORE_NAME = 'documents';

// 1. 初始化嵌入模型(~23MB,首次下载后缓存)
async init(): Promise<void> {
// all-MiniLM-L6-v2:384 维向量,轻量且效果好
this.embedder = await pipeline(
'feature-extraction',
'Xenova/all-MiniLM-L6-v2',
{ device: 'webgpu' } // 优先 WebGPU 加速
) as FeatureExtractionPipeline;

// 打开 IndexedDB
this.db = await new Promise((resolve, reject) => {
const request = indexedDB.open(this.DB_NAME, 1);
request.onupgradeneeded = () => {
const db = request.result;
if (!db.objectStoreNames.contains(this.STORE_NAME)) {
db.createObjectStore(this.STORE_NAME, { keyPath: 'id' });
}
};
request.onsuccess = () => resolve(request.result);
request.onerror = () => reject(request.error);
});
}

// 2. 添加文档:生成 Embedding 并存储
async addDocument(content: string, metadata?: Record<string, string>): Promise<string> {
if (!this.embedder || !this.db) throw new Error('未初始化');

const id = crypto.randomUUID();
// 生成文本的向量表示
const output = await this.embedder(content, { pooling: 'mean', normalize: true });
const embedding = new Float32Array(output.data as ArrayBuffer);

const doc: Document = { id, content, metadata, embedding };

// 存入 IndexedDB
await new Promise<void>((resolve, reject) => {
const tx = this.db!.transaction(this.STORE_NAME, 'readwrite');
tx.objectStore(this.STORE_NAME).put(doc);
tx.oncomplete = () => resolve();
tx.onerror = () => reject(tx.error);
});

return id;
}

// 3. 语义搜索:计算查询向量与所有文档的余弦相似度
async search(query: string, topK = 5): Promise<SearchResult[]> {
if (!this.embedder || !this.db) throw new Error('未初始化');

// 生成查询向量
const output = await this.embedder(query, { pooling: 'mean', normalize: true });
const queryVec = new Float32Array(output.data as ArrayBuffer);

// 从 IndexedDB 读取所有文档
const docs = await new Promise<Document[]>((resolve, reject) => {
const tx = this.db!.transaction(this.STORE_NAME, 'readonly');
const request = tx.objectStore(this.STORE_NAME).getAll();
request.onsuccess = () => resolve(request.result);
request.onerror = () => reject(request.error);
});

// 计算余弦相似度并排序
const results: SearchResult[] = docs
.filter(doc => doc.embedding)
.map(doc => ({
document: { ...doc, embedding: undefined },
score: cosineSimilarity(queryVec, doc.embedding!),
}))
.sort((a, b) => b.score - a.score)
.slice(0, topK);

return results;
}
}

// 余弦相似度(向量已归一化时等于点积)
function cosineSimilarity(a: Float32Array, b: Float32Array): number {
let dot = 0;
for (let i = 0; i < a.length; i++) dot += a[i] * b[i];
return dot; // 归一化后 dot product = cosine similarity
}

性能优化要点

优化点方案效果
Embedding 计算放在 Web Worker 中不阻塞 UI
大量文档搜索使用 HNSW 索引(如 hnswlib-wasmO(log n) 搜索,替代 O(n) 暴力搜索
文档分块长文档按段落/句子拆分(~200-500 字/块)提高检索精度
WebGPU 加速{ device: 'webgpu' }Embedding 速度提升 3-10 倍

局限性

  • 向量维度 384,文档量 >10 万时 IndexedDB 搜索变慢(建议引入 HNSW)
  • 浏览器端只能用小型 Embedding 模型(~23MB),语义理解不如 OpenAI text-embedding-3
  • 适合个人笔记、本地文档检索等场景;企业级 RAG 仍需后端向量数据库

更多 RAG 架构设计参考 RAG 检索增强生成。 更多向量搜索原理参考 向量搜索与 Embedding

相关链接