## Why Run Machine Learning in the Browser?

Traditional ML inference pipeline: user uploads data → data sent to the server → model runs → result returned.

In-browser ML: user data never leaves the browser → model runs locally → instant result.
### Advantages

| Aspect | Server-side inference | In-browser inference |
|---|---|---|
| Privacy | Data uploaded to a server | Data never leaves the device |
| Latency | Network round trip | Near-instant response |
| Cost | Server compute | User's device compute |
| Offline | Requires a network | Works offline |
### Typical use cases

- Image classification / object detection
- Pose estimation
- Text sentiment analysis
- Real-time filters / effects
- Speech recognition (simple scenarios)
## Quick Start

### Installation

```bash
# npm
npm install @tensorflow/tfjs
```

Or via CDN (handy for development and quick tests):

```html
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
```
### Basic usage

```typescript
import * as tf from '@tensorflow/tfjs'

// Create a tensor
const tensor = tf.tensor([1, 2, 3, 4])
console.log(tensor.shape) // [4]

// Element-wise arithmetic
const a = tf.tensor([1, 2, 3])
const b = tf.tensor([4, 5, 6])
const sum = a.add(b)
sum.print() // [5, 7, 9]

// Matrix operations
const matrix = tf.tensor2d([[1, 2], [3, 4]])
const transposed = matrix.transpose()
transposed.print()

// Release the memory backing the tensor (important!)
tensor.dispose()
```
## Loading Pretrained Models

### Official pretrained models

TensorFlow.js ships several models that work out of the box:

```typescript
// Image classification
import * as mobilenet from '@tensorflow-models/mobilenet'

const model = await mobilenet.load()
const predictions = await model.classify(imageElement)
// [{ className: 'cat', probability: 0.95 }, ...]

// Object detection
import * as cocoSsd from '@tensorflow-models/coco-ssd'

const detector = await cocoSsd.load()
const objects = await detector.detect(imageElement)
// [{ class: 'person', bbox: [x, y, width, height], score: 0.89 }, ...]

// Pose estimation
import * as poseDetection from '@tensorflow-models/pose-detection'

const poseDetector = await poseDetection.createDetector(
  poseDetection.SupportedModels.MoveNet
)
const poses = await poseDetector.estimatePoses(videoElement)
```
### Loading a custom model

```typescript
// From a URL
const model = await tf.loadLayersModel('https://example.com/model/model.json')

// From Local Storage
const model = await tf.loadLayersModel('localstorage://my-model')

// From IndexedDB
const model = await tf.loadLayersModel('indexeddb://my-model')
```
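Loading from `localstorage://` or `indexeddb://` assumes the model was saved there in an earlier session. A minimal sketch of the save side (the scheme prefix mirrors the load call; storage quotas vary by browser, with Local Storage typically far smaller than IndexedDB):

```typescript
import * as tf from '@tensorflow/tfjs'

// Fetch the model once over the network...
const model = await tf.loadLayersModel('https://example.com/model/model.json')

// ...then persist it in the browser so later sessions can skip the download.
// IndexedDB is the better fit for all but tiny models.
await model.save('indexeddb://my-model')
// Alternatively, for very small models:
await model.save('localstorage://my-model')
```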
### Converting model formats

Convert a model trained in Python to the TensorFlow.js format:

```bash
# Install the converter
pip install tensorflowjs

# Convert a Keras model
tensorflowjs_converter --input_format keras \
  model.h5 \
  ./web_model

# Convert a SavedModel
tensorflowjs_converter --input_format tf_saved_model \
  ./saved_model \
  ./web_model
```
## Real-World Examples

### Example 1: Image classification component

```vue
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import * as mobilenet from '@tensorflow-models/mobilenet'

interface Prediction {
  className: string
  probability: number
}

const imageRef = ref<HTMLImageElement>()
const predictions = ref<Prediction[]>([])
const isLoading = ref(false)
const model = ref<mobilenet.MobileNet>()

// Load the model once on mount
onMounted(async () => {
  isLoading.value = true
  model.value = await mobilenet.load()
  isLoading.value = false
})

// Handle image upload
async function handleImageUpload(event: Event) {
  const file = (event.target as HTMLInputElement).files?.[0]
  if (!file) return

  const url = URL.createObjectURL(file)
  imageRef.value!.src = url

  // Wait for the image to finish loading
  await new Promise(resolve => {
    imageRef.value!.onload = resolve
  })
  URL.revokeObjectURL(url)

  // Classify
  isLoading.value = true
  predictions.value = await model.value!.classify(imageRef.value!)
  isLoading.value = false
}
</script>

<template>
  <div class="p-4">
    <input
      type="file"
      accept="image/*"
      @change="handleImageUpload"
      :disabled="!model"
    />
    <div class="mt-4">
      <img ref="imageRef" class="max-w-md" />
    </div>
    <div v-if="isLoading" class="mt-4">Analyzing...</div>
    <div v-if="predictions.length" class="mt-4 space-y-2">
      <div
        v-for="pred in predictions"
        :key="pred.className"
        class="flex items-center gap-2"
      >
        <span class="font-medium">{{ pred.className }}</span>
        <div class="flex-1 bg-gray-200 rounded-full h-2">
          <div
            class="bg-blue-500 h-2 rounded-full"
            :style="{ width: `${pred.probability * 100}%` }"
          ></div>
        </div>
        <span>{{ (pred.probability * 100).toFixed(1) }}%</span>
      </div>
    </div>
  </div>
</template>
```
### Example 2: Real-time pose detection

```vue
<script setup lang="ts">
import { ref, onMounted, onUnmounted } from 'vue'
import * as poseDetection from '@tensorflow-models/pose-detection'

const videoRef = ref<HTMLVideoElement>()
const canvasRef = ref<HTMLCanvasElement>()
const detector = ref<poseDetection.PoseDetector>()
let animationId: number
let stream: MediaStream | undefined

onMounted(async () => {
  // Load the model
  detector.value = await poseDetection.createDetector(
    poseDetection.SupportedModels.MoveNet,
    { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING }
  )

  // Open the camera
  stream = await navigator.mediaDevices.getUserMedia({ video: true })
  videoRef.value!.srcObject = stream
  videoRef.value!.onloadedmetadata = () => {
    videoRef.value!.play()
    detectPose()
  }
})

async function detectPose() {
  if (!detector.value || !videoRef.value || !canvasRef.value) return

  const poses = await detector.value.estimatePoses(videoRef.value)
  drawPoses(poses)
  animationId = requestAnimationFrame(detectPose)
}

function drawPoses(poses: poseDetection.Pose[]) {
  const ctx = canvasRef.value!.getContext('2d')!
  // Use the intrinsic video dimensions; the width/height attributes
  // of a hidden, unsized <video> element are 0
  const { videoWidth: width, videoHeight: height } = videoRef.value!
  canvasRef.value!.width = width
  canvasRef.value!.height = height

  // Draw the current video frame
  ctx.drawImage(videoRef.value!, 0, 0)

  // Draw keypoints
  for (const pose of poses) {
    for (const keypoint of pose.keypoints) {
      if (keypoint.score && keypoint.score > 0.5) {
        ctx.beginPath()
        ctx.arc(keypoint.x, keypoint.y, 5, 0, 2 * Math.PI)
        ctx.fillStyle = 'red'
        ctx.fill()
      }
    }

    // Draw skeleton edges
    const edges = poseDetection.util.getAdjacentPairs(
      poseDetection.SupportedModels.MoveNet
    )
    for (const [i, j] of edges) {
      const kp1 = pose.keypoints[i]
      const kp2 = pose.keypoints[j]
      if (kp1.score! > 0.5 && kp2.score! > 0.5) {
        ctx.beginPath()
        ctx.moveTo(kp1.x, kp1.y)
        ctx.lineTo(kp2.x, kp2.y)
        ctx.strokeStyle = 'green'
        ctx.lineWidth = 2
        ctx.stroke()
      }
    }
  }
}

onUnmounted(() => {
  cancelAnimationFrame(animationId)
  // Release the camera when the component unmounts
  stream?.getTracks().forEach(track => track.stop())
})
</script>

<template>
  <div class="relative">
    <video ref="videoRef" class="hidden" />
    <canvas ref="canvasRef" class="max-w-full" />
  </div>
</template>
```
### Example 3: Text sentiment analysis

```typescript
import * as tf from '@tensorflow/tfjs'

class SentimentAnalyzer {
  private model: tf.LayersModel | null = null
  private vocabulary: Map<string, number> = new Map()
  private maxLength = 100

  async load() {
    // Load the model and vocabulary
    this.model = await tf.loadLayersModel('/models/sentiment/model.json')
    const vocabResponse = await fetch('/models/sentiment/vocab.json')
    const vocab = await vocabResponse.json()
    this.vocabulary = new Map(Object.entries(vocab))
  }

  private tokenize(text: string): number[] {
    const words = text.toLowerCase().split(/\s+/)
    const tokens = words.map(word => this.vocabulary.get(word) || 0)

    // Pad or truncate to a fixed length
    if (tokens.length < this.maxLength) {
      return [...tokens, ...Array(this.maxLength - tokens.length).fill(0)]
    }
    return tokens.slice(0, this.maxLength)
  }

  async predict(text: string): Promise<{ sentiment: string; confidence: number }> {
    if (!this.model) throw new Error('Model not loaded')

    const tokens = this.tokenize(text)
    const input = tf.tensor2d([tokens], [1, this.maxLength])
    const prediction = this.model.predict(input) as tf.Tensor
    const score = (await prediction.data())[0]

    input.dispose()
    prediction.dispose()

    return {
      sentiment: score > 0.5 ? 'positive' : 'negative',
      confidence: score > 0.5 ? score : 1 - score
    }
  }
}

// Usage
const analyzer = new SentimentAnalyzer()
await analyzer.load()
const result = await analyzer.predict('This product is amazing!')
// { sentiment: 'positive', confidence: 0.92 }
```
## Performance Optimization

### 1. Use the WebGL backend

```typescript
import * as tf from '@tensorflow/tfjs'
import '@tensorflow/tfjs-backend-webgl'

// Make sure WebGL is the active backend
await tf.setBackend('webgl')
console.log('Backend:', tf.getBackend()) // 'webgl'
```
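WebGL is not always available (headless browsers, very old devices). `tf.setBackend` resolves to `false` when initialization fails, so a fallback chain can be sketched like this (a minimal example; in practice you might also try the WASM backend from the separate `@tensorflow/tfjs-backend-wasm` package before `cpu`):

```typescript
import * as tf from '@tensorflow/tfjs'
import '@tensorflow/tfjs-backend-webgl'

// Try the fastest backend first, then fall back.
async function pickBackend(): Promise<string> {
  for (const name of ['webgl', 'cpu']) {
    if (await tf.setBackend(name)) break
  }
  await tf.ready() // wait until the chosen backend is initialized
  return tf.getBackend()
}
```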
### 2. Memory management

```typescript
// Option 1: dispose manually
const tensor = tf.tensor([1, 2, 3])
// ...use tensor...
tensor.dispose()

// Option 2: tf.tidy disposes intermediates automatically
const result = tf.tidy(() => {
  const a = tf.tensor([1, 2, 3])
  const b = tf.tensor([4, 5, 6])
  return a.add(b) // only the returned tensor is kept alive
})

// Monitor memory usage
console.log('Tensors in memory:', tf.memory().numTensors)
```
### 3. Model quantization

Quantization shrinks the model file and speeds up loading:

```bash
# Quantize weights to uint8
tensorflowjs_converter --input_format keras \
  --quantize_uint8 \
  model.h5 \
  ./quantized_model
```
### 4. Offload inference to a Web Worker

Keep heavy inference off the main thread so the UI stays responsive:

```typescript
// worker.ts
import * as tf from '@tensorflow/tfjs'

let model: tf.LayersModel | null = null

self.onmessage = async (event) => {
  const { type, data } = event.data

  if (type === 'load') {
    model = await tf.loadLayersModel(data.modelUrl)
    self.postMessage({ type: 'loaded' })
  }

  if (type === 'predict') {
    const input = tf.tensor(data.input)
    const prediction = model!.predict(input) as tf.Tensor
    const result = await prediction.data()
    input.dispose()
    prediction.dispose()
    self.postMessage({ type: 'result', data: Array.from(result) })
  }
}
```

```typescript
// Main thread
const worker = new Worker(new URL('./worker.ts', import.meta.url))

worker.postMessage({ type: 'load', data: { modelUrl: '/model/model.json' } })

worker.onmessage = (event) => {
  if (event.data.type === 'result') {
    console.log('Prediction:', event.data.data)
  }
}

worker.postMessage({ type: 'predict', data: { input: [...] } })
```
### 5. Model warm-up

The first inference is slow (shader compilation, buffer allocation), so warm the model up ahead of time:

```typescript
async function warmupModel(model: tf.LayersModel, inputShape: number[]) {
  const dummyInput = tf.zeros(inputShape)
  const warmupResult = model.predict(dummyInput) as tf.Tensor
  await warmupResult.data() // wait for the computation to finish
  dummyInput.dispose()
  warmupResult.dispose()
  console.log('Model warmed up')
}
```
## Recommended Models

| Model | Use case | Package |
|---|---|---|
| MobileNet | Image classification | @tensorflow-models/mobilenet |
| COCO-SSD | Object detection | @tensorflow-models/coco-ssd |
| PoseNet/MoveNet | Pose estimation | @tensorflow-models/pose-detection |
| BlazeFace | Face detection | @tensorflow-models/blazeface |
| HandPose | Hand tracking | @tensorflow-models/handpose |
| Toxicity | Text toxicity detection | @tensorflow-models/toxicity |
| Universal Sentence Encoder | Text embeddings | @tensorflow-models/universal-sentence-encoder |
## Summary

TensorFlow.js puts machine learning within reach of frontend developers. Key takeaways:

| Aspect | Recommendation |
|---|---|
| Model choice | Prefer the official pretrained models |
| Performance | WebGL backend + careful memory management |
| User experience | Show loading progress; warm the model up |
| Heavy computation | Offload to a Web Worker |
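The "show loading progress" advice maps to the `onProgress` option accepted by the model loaders; a minimal sketch (the model URL is illustrative):

```typescript
import * as tf from '@tensorflow/tfjs'

// Report weight-download progress while the model loads.
// `onProgress` receives a fraction between 0 and 1.
const model = await tf.loadLayersModel('/model/model.json', {
  onProgress: (fraction) => {
    console.log(`Loading model: ${(fraction * 100).toFixed(0)}%`)
  },
})
```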
Remember: in-browser ML is best suited to lightweight inference; leave heavy training to the server.