TensorFlow.js Front-End Machine Learning: Running AI Models in the Browser

HTMLPAGE Team
16 min read

A systematic introduction to machine learning in the browser with TensorFlow.js, covering model loading, inference, and performance optimization.

#TensorFlow.js #MachineLearning #BrowserAI #FrontendML

Why Run Machine Learning in the Browser?

Traditional ML inference pipeline: the user uploads data → it is sent to a server → the model runs → the result comes back

Browser-side ML: user data never leaves the browser → the model runs locally → instant results

Advantages

| Aspect  | Server-side inference        | Browser-side inference     |
|---------|------------------------------|----------------------------|
| Privacy | Data is uploaded to a server | Data never leaves the device |
| Latency | Network round trip           | Instant response           |
| Cost    | Server compute               | User's device compute      |
| Offline | Requires a network connection | Works offline             |

Use Cases

  • Image classification / object detection
  • Pose estimation
  • Text sentiment analysis
  • Real-time filters and effects
  • Speech recognition (simple scenarios)

Quick Start

Installation

# npm
npm install @tensorflow/tfjs

# CDN (for development and testing)
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
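
When loading from the CDN there is no import step; the bundle registers a global tf object:

<script>
  // The CDN bundle exposes a global `tf`; wait for the backend to initialize
  tf.ready().then(() => console.log('Backend:', tf.getBackend()))
</script>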

Basic Usage

import * as tf from '@tensorflow/tfjs'

// Create a tensor
const tensor = tf.tensor([1, 2, 3, 4])
console.log(tensor.shape) // [4]

// Basic operations
const a = tf.tensor([1, 2, 3])
const b = tf.tensor([4, 5, 6])
const sum = a.add(b)
sum.print() // [5, 7, 9]

// Matrix operations
const matrix = tf.tensor2d([[1, 2], [3, 4]])
const transposed = matrix.transpose()
transposed.print()

// Release memory when done (important!)
tensor.dispose()

Loading Pre-Trained Models

Official Pre-Trained Models

TensorFlow.js ships several models that work out of the box:

// Image classification
import * as mobilenet from '@tensorflow-models/mobilenet'

const model = await mobilenet.load()
const predictions = await model.classify(imageElement)
// [{ className: 'cat', probability: 0.95 }, ...]

// Object detection
import * as cocoSsd from '@tensorflow-models/coco-ssd'

const detector = await cocoSsd.load()
const objects = await detector.detect(imageElement)
// [{ class: 'person', bbox: [x, y, width, height], score: 0.89 }, ...]

// Pose estimation
import * as poseDetection from '@tensorflow-models/pose-detection'

const detector = await poseDetection.createDetector(
  poseDetection.SupportedModels.MoveNet
)
const poses = await detector.estimatePoses(videoElement)

Loading Custom Models

// Load from a URL
const model = await tf.loadLayersModel('https://example.com/model/model.json')

// Load from Local Storage
const model = await tf.loadLayersModel('localstorage://my-model')

// Load from IndexedDB
const model = await tf.loadLayersModel('indexeddb://my-model')
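
The Local Storage and IndexedDB URLs above only resolve if a model was saved there first; model.save() accepts the same URL schemes. A minimal sketch, assuming a model instance is already in scope:

// Persist the model in the browser so it can be reloaded later, even offline
await model.save('indexeddb://my-model')

// Remove it when it is no longer needed
await tf.io.removeModel('indexeddb://my-model')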

Model Format Conversion

Convert a model trained in Python to the TensorFlow.js format:

# Install the converter
pip install tensorflowjs

# Convert a Keras model
tensorflowjs_converter --input_format keras \
  model.h5 \
  ./web_model

# Convert a SavedModel
tensorflowjs_converter --input_format tf_saved_model \
  ./saved_model \
  ./web_model
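
Note that a converted SavedModel is a graph model, so it is loaded with tf.loadGraphModel() rather than tf.loadLayersModel():

// Keras (.h5) conversions load as layers models;
// SavedModel conversions load as graph models
const graphModel = await tf.loadGraphModel('/web_model/model.json')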

Hands-On Examples

Example 1: Image Classification Component

<script setup lang="ts">
import { ref, onMounted } from 'vue'
import '@tensorflow/tfjs'
import * as mobilenet from '@tensorflow-models/mobilenet'

// Shape of the entries returned by model.classify()
interface Prediction {
  className: string
  probability: number
}

const imageRef = ref<HTMLImageElement>()
const predictions = ref<Prediction[]>([])
const isLoading = ref(false)
const model = ref<mobilenet.MobileNet>()

// Load the model once the component is mounted
onMounted(async () => {
  isLoading.value = true
  model.value = await mobilenet.load()
  isLoading.value = false
})

// Handle image upload
async function handleImageUpload(event: Event) {
  const file = (event.target as HTMLInputElement).files?.[0]
  if (!file) return

  // Register the onload handler before setting src to avoid a race
  const loaded = new Promise(resolve => {
    imageRef.value!.onload = resolve
  })
  imageRef.value!.src = URL.createObjectURL(file)
  await loaded

  // Classify the image
  isLoading.value = true
  predictions.value = await model.value!.classify(imageRef.value!)
  isLoading.value = false
}
</script>

<template>
  <div class="p-4">
    <input 
      type="file" 
      accept="image/*" 
      @change="handleImageUpload"
      :disabled="!model"
    />
    
    <div class="mt-4">
      <img ref="imageRef" class="max-w-md" />
    </div>

    <div v-if="isLoading" class="mt-4">Analyzing...</div>

    <div v-if="predictions.length" class="mt-4 space-y-2">
      <div 
        v-for="pred in predictions" 
        :key="pred.className"
        class="flex items-center gap-2"
      >
        <span class="font-medium">{{ pred.className }}</span>
        <div class="flex-1 bg-gray-200 rounded-full h-2">
          <div 
            class="bg-blue-500 h-2 rounded-full" 
            :style="{ width: `${pred.probability * 100}%` }"
          ></div>
        </div>
        <span>{{ (pred.probability * 100).toFixed(1) }}%</span>
      </div>
    </div>
  </div>
</template>

Example 2: Real-Time Pose Detection

<script setup lang="ts">
import { ref, onMounted, onUnmounted } from 'vue'
import '@tensorflow/tfjs'
import * as poseDetection from '@tensorflow-models/pose-detection'

const videoRef = ref<HTMLVideoElement>()
const canvasRef = ref<HTMLCanvasElement>()
const detector = ref<poseDetection.PoseDetector>()
let animationId: number

onMounted(async () => {
  // Load the MoveNet model
  detector.value = await poseDetection.createDetector(
    poseDetection.SupportedModels.MoveNet,
    { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING }
  )

  // Open the camera
  const stream = await navigator.mediaDevices.getUserMedia({ video: true })
  videoRef.value!.srcObject = stream
  
  videoRef.value!.onloadedmetadata = () => {
    videoRef.value!.play()
    detectPose()
  }
})

async function detectPose() {
  if (!detector.value || !videoRef.value || !canvasRef.value) return

  const poses = await detector.value.estimatePoses(videoRef.value)
  drawPoses(poses)

  animationId = requestAnimationFrame(detectPose)
}

function drawPoses(poses: poseDetection.Pose[]) {
  const ctx = canvasRef.value!.getContext('2d')!
  // Use the video's intrinsic dimensions, not its layout size
  const width = videoRef.value!.videoWidth
  const height = videoRef.value!.videoHeight

  canvasRef.value!.width = width
  canvasRef.value!.height = height

  // Draw the current video frame
  ctx.drawImage(videoRef.value!, 0, 0)

  // Draw keypoints
  for (const pose of poses) {
    for (const keypoint of pose.keypoints) {
      if (keypoint.score && keypoint.score > 0.5) {
        ctx.beginPath()
        ctx.arc(keypoint.x, keypoint.y, 5, 0, 2 * Math.PI)
        ctx.fillStyle = 'red'
        ctx.fill()
      }
    }

    // Draw the skeleton connections
    const edges = poseDetection.util.getAdjacentPairs(
      poseDetection.SupportedModels.MoveNet
    )
    
    for (const [i, j] of edges) {
      const kp1 = pose.keypoints[i]
      const kp2 = pose.keypoints[j]
      
      if (kp1.score! > 0.5 && kp2.score! > 0.5) {
        ctx.beginPath()
        ctx.moveTo(kp1.x, kp1.y)
        ctx.lineTo(kp2.x, kp2.y)
        ctx.strokeStyle = 'green'
        ctx.lineWidth = 2
        ctx.stroke()
      }
    }
  }
}

onUnmounted(() => {
  cancelAnimationFrame(animationId)
  // Release the camera when the component is destroyed
  const stream = videoRef.value?.srcObject as MediaStream | null
  stream?.getTracks().forEach(track => track.stop())
})
</script>

<template>
  <div class="relative">
    <video ref="videoRef" class="hidden" />
    <canvas ref="canvasRef" class="max-w-full" />
  </div>
</template>

Example 3: Text Sentiment Analysis

import * as tf from '@tensorflow/tfjs'

class SentimentAnalyzer {
  private model: tf.LayersModel | null = null
  private vocabulary: Map<string, number> = new Map()
  private maxLength = 100

  async load() {
    // Load the model and the vocabulary
    this.model = await tf.loadLayersModel('/models/sentiment/model.json')
    const vocabResponse = await fetch('/models/sentiment/vocab.json')
    const vocab = await vocabResponse.json()
    this.vocabulary = new Map(Object.entries(vocab))
  }

  private tokenize(text: string): number[] {
    const words = text.toLowerCase().split(/\s+/)
    const tokens = words.map(word => this.vocabulary.get(word) || 0)
    
    // Pad or truncate to a fixed length
    if (tokens.length < this.maxLength) {
      return [...tokens, ...Array(this.maxLength - tokens.length).fill(0)]
    }
    return tokens.slice(0, this.maxLength)
  }

  async predict(text: string): Promise<{ sentiment: string; confidence: number }> {
    if (!this.model) throw new Error('Model not loaded')

    const tokens = this.tokenize(text)
    const input = tf.tensor2d([tokens], [1, this.maxLength])

    const prediction = this.model.predict(input) as tf.Tensor
    const score = (await prediction.data())[0]

    input.dispose()
    prediction.dispose()

    return {
      sentiment: score > 0.5 ? 'positive' : 'negative',
      confidence: score > 0.5 ? score : 1 - score
    }
  }
}

// Usage
const analyzer = new SentimentAnalyzer()
await analyzer.load()

const result = await analyzer.predict('This product is amazing!')
// { sentiment: 'positive', confidence: 0.92 }

Performance Optimization

1. Use the WebGL Backend

import * as tf from '@tensorflow/tfjs'
import '@tensorflow/tfjs-backend-webgl'

// Make sure the WebGL backend is active
await tf.setBackend('webgl')
console.log('Backend:', tf.getBackend()) // 'webgl'
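
If WebGL is unavailable (headless browsers, some WebViews), you can fall back to another backend. A sketch, assuming the optional @tensorflow/tfjs-backend-wasm package is also installed:

import '@tensorflow/tfjs-backend-wasm'

// Try the fastest backend first; tf.setBackend() resolves to false if it fails
for (const backend of ['webgl', 'wasm', 'cpu']) {
  if (await tf.setBackend(backend)) break
}
console.log('Active backend:', tf.getBackend())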

2. Memory Management

// Option 1: dispose manually
const tensor = tf.tensor([1, 2, 3])
// ...use the tensor...
tensor.dispose()

// Option 2: tf.tidy disposes intermediates automatically
const result = tf.tidy(() => {
  const a = tf.tensor([1, 2, 3])
  const b = tf.tensor([4, 5, 6])
  return a.add(b) // only the returned tensor is kept
})

// Monitor memory usage
console.log('Tensors in memory:', tf.memory().numTensors)
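
One caveat: tf.tidy() only accepts synchronous functions, so any await (such as reading values with data()) has to happen outside the tidy block, with manual disposal of the returned tensor:

// tf.tidy cannot wrap an async function; read values afterwards and dispose manually
const logits = tf.tidy(() => tf.tensor([1, 2, 3]).square())
const values = await logits.data() // async read outside the tidy scope
logits.dispose()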

3. Model Quantization

Quantization shrinks the model and speeds up loading:

# Quantize weights to uint8
tensorflowjs_converter --input_format keras \
  --quantize_uint8 \
  model.h5 \
  ./quantized_model
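
Loading a quantized model requires no code changes; the weights are dequantized automatically at load time:

// A quantized model loads through the same API as the original
const quantizedModel = await tf.loadLayersModel('/quantized_model/model.json')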

4. Use a Web Worker

Keep inference off the main thread to avoid blocking the UI:

// worker.ts
import * as tf from '@tensorflow/tfjs'

let model: tf.LayersModel | null = null

self.onmessage = async (event) => {
  const { type, data } = event.data

  if (type === 'load') {
    model = await tf.loadLayersModel(data.modelUrl)
    self.postMessage({ type: 'loaded' })
  }

  if (type === 'predict') {
    const input = tf.tensor(data.input)
    const prediction = model!.predict(input) as tf.Tensor
    const result = await prediction.data()
    
    input.dispose()
    prediction.dispose()
    
    self.postMessage({ type: 'result', data: Array.from(result) })
  }
}

// On the main thread (type: 'module' is needed because the worker uses ES imports)
const worker = new Worker(new URL('./worker.ts', import.meta.url), { type: 'module' })

worker.postMessage({ type: 'load', data: { modelUrl: '/model/model.json' } })

worker.onmessage = (event) => {
  if (event.data.type === 'result') {
    console.log('Prediction:', event.data.data)
  }
}

worker.postMessage({ type: 'predict', data: { input: [...] } })
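
In practice it helps to wrap the message round trip in a Promise so callers can simply await a prediction. A minimal sketch built on the message protocol above (the helper name is illustrative):

// Hypothetical helper: sends one 'predict' request and resolves with the next 'result' message
function predictInWorker(worker: Worker, input: number[]): Promise<number[]> {
  return new Promise(resolve => {
    const onMessage = (event: MessageEvent) => {
      if (event.data.type === 'result') {
        worker.removeEventListener('message', onMessage)
        resolve(event.data.data)
      }
    }
    worker.addEventListener('message', onMessage)
    worker.postMessage({ type: 'predict', data: { input } })
  })
}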

5. Model Warm-Up

The first inference is noticeably slower because shaders are compiled lazily, so warm the model up ahead of time:

async function warmupModel(model: tf.LayersModel, inputShape: number[]) {
  const dummyInput = tf.zeros(inputShape)
  const warmupResult = model.predict(dummyInput) as tf.Tensor
  
  await warmupResult.data() // wait for the computation to finish
  
  dummyInput.dispose()
  warmupResult.dispose()
  
  console.log('Model warmed up')
}
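
Typical usage is a batch-of-one input shape that matches the model, for example [1, 224, 224, 3] for a 224×224 RGB image classifier (the shape here is an assumption; use your model's actual input shape):

// Warm up an image model that expects 224×224 RGB input
await warmupModel(model, [1, 224, 224, 3])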

Recommended Models

| Model | Purpose | Package |
|-------|---------|---------|
| MobileNet | Image classification | @tensorflow-models/mobilenet |
| COCO-SSD | Object detection | @tensorflow-models/coco-ssd |
| PoseNet / MoveNet | Pose detection | @tensorflow-models/pose-detection |
| BlazeFace | Face detection | @tensorflow-models/blazeface |
| HandPose | Hand tracking | @tensorflow-models/handpose |
| Toxicity | Text toxicity detection | @tensorflow-models/toxicity |
| Universal Sentence Encoder | Text embeddings | @tensorflow-models/universal-sentence-encoder |

Summary

TensorFlow.js brings machine learning to the front end. Key points:

| Aspect | Recommendation |
|--------|----------------|
| Model selection | Prefer the official pre-trained models |
| Performance | WebGL backend + careful memory management |
| User experience | Show loading progress and warm up the model |
| Heavy computation | Move it into a Web Worker |

Remember: browser-side ML is best suited to lightweight inference; heavy training still belongs on the server.

