• 用 TensorFlow.js 在浏览器中训练一个计算机视觉模型(手写数字分类器)



    我们在《在浏览器中运行 TensorFlow.js 来训练模型并给出预测结果(Iris 数据集)》中已经对 TensorFlow.js 的使用有了大致的了解,现在我们进一步来看如何训练一个图片数据集,并做一些可视化工作。文章代码可从《AI and Machine Learning for Coders》一书 GitHub 找到。

    在使用浏览器时,每当我们在一个 URL 上打开一个资源时,就会建立一个 HTTP 连接。我们用这个连接把命令传给服务器,然后服务器就会把结果回传。当涉及到计算机视觉时,我们通常会有大量的训练数据。例如,MNIST 和 Fashion MNIST,尽管它们已经是非常小型的图片数据集,但它们仍然包含了 70,000 张图片,这将是 70,000 个 HTTP 连接!这显然会造成大量的开销,稍后我们看如何处理这个问题。

    Building a CNN in JavaScript

    我们看在 keras 中定义的如下 针对手写数字数据集的 CNN 模型如何在 JavaScript 中定义:

    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(64, (3, 3),activation='relu', 
                               input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation=tf.nn.relu),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    我们分别来看看卷积层、池化层、全连接层是如何在 JavaScript 中定义的:

    我们首先将模型定义为 sequential

    model = tf.sequential();
    
    • 1

    第一个卷积层:

    model.add(tf.layers.conv2d({inputShape: [28, 28, 1],
                                kernelSize: 3,
                                filters: 64,
                                activation: 'relu'}));
    
    • 1
    • 2
    • 3
    • 4

    第一个池化层:

    model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
    
    • 1

    第一个全连接层:

    model.add(tf.layers.dense({units: 128, activation: 'relu'}));
    
    • 1

    因此完整的 JavaScript 定义为:

        model = tf.sequential();
            
        model.add(tf.layers.conv2d({inputShape: [28, 28, 1],
                                    kernelSize: 3,
                                    filters: 64,
                                    activation: 'relu'}));
            
        model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
            
        model.add(tf.layers.conv2d({kernelSize: 3,
                                    filters: 64,
                                    activation: 'relu'}));
            
        model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
            
        model.add(tf.layers.flatten());
            
        model.add(tf.layers.dense({units: 128, activation: 'relu'}));
            
        model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    编译模型时的语法:

    model.compile({optimizer: tf.train.adam(),
                   loss: 'categoricalCrossentropy',
                   metrics: ['accuracy']});
    
    • 1
    • 2
    • 3

    Using Callbacks for Visualization

    我们直接使用《在浏览器中运行 TensorFlow.js 来训练模型并给出预测结果(Iris 数据集)》的现成代码来进行演示,代码如下:

    <html>
    <head>head>
        <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest">script>
        
        <script lang="js">
        
            async function run(){
                const csvUrl = 'iris.csv';
                const trainingData = tf.data.csv(csvUrl, {
                    columnConfigs: {
                        species: {
                            isLabel: true
                        }
                    }
                });
                
                const convertedData = trainingData.map(({xs, ys}) => {
                    const labels = [
                        ys.species == 'setosa' ? 1 : 0,
                        ys.species == 'virginica' ? 1: 0,
                        ys.species == 'versicolor' ? 1 : 0
                    ]
                    return {xs: Object.values(xs), ys: Object.values(labels)};
                }).batch(10);
            
                const numOfFeatures = (await trainingData.columnNames()).length - 1;
            
                const model = tf.sequential();
                model.add(tf.layers.dense({inputShape: [numOfFeatures],
                                           activation: "sigmoid", units: 5}));
            
                model.add(tf.layers.dense({activation: "softmax", units: 3}));
            
                model.compile({loss: "categoricalCrossentropy",
                               optimizer: tf.train.adam(0.06)});
            
                await model.fitDataset(convertedData,
                                       {epochs:100,
                                        callbacks:{
                                            onEpochEnd: async(epoch, logs) =>{
                                                console.log("Epoch: " + epoch + " Loss: " + logs.loss);
                                        }
                                    }});
                const testVal = tf.tensor2d([4.4, 2.9, 1.4, 0.2], [1, 4]);
                alert(model.predict(testVal));
            
            }
            
            run();
            
        script>
        
    <body>
        <h1>Iris Classifierh1>
    body>
    html>
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57

    为了使用可视化工具 tfjs-vis,我们需要在代码中添加如下的 script 标签:

    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis">script>
    
    • 1

    并用 tfvis.show 定义一个回调,来在训练时进行可视化:

    const metrics = ['loss', 'accuracy'];
                
    const container = {name: 'Model Training', 
                       styles: {height: '640px'},
                       tab: 'Training Progress'};
                
    const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    将原有代码中的回调替换成 fitCallbacks,现在我们的完整代码为:

    <html>
    <head>head>
        <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest">script>
        
        <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis">script>
        
        <script lang="js">
        
            async function run(){
                const csvUrl = 'iris.csv';
                const trainingData = tf.data.csv(csvUrl, {
                    columnConfigs: {
                        species: {
                            isLabel: true
                        }
                    }
                });
                
                const convertedData = trainingData.map(({xs, ys}) => {
                    const labels = [
                        ys.species == 'setosa' ? 1 : 0,
                        ys.species == 'virginica' ? 1: 0,
                        ys.species == 'versicolor' ? 1 : 0
                    ]
                    return {xs: Object.values(xs), ys: Object.values(labels)};
                }).batch(10);
            
                const numOfFeatures = (await trainingData.columnNames()).length - 1;
            
                const model = tf.sequential();
                model.add(tf.layers.dense({inputShape: [numOfFeatures],
                                           activation: "sigmoid", units: 5}));
            
                model.add(tf.layers.dense({activation: "softmax", units: 3}));
            
                model.compile({loss: "categoricalCrossentropy",
                               optimizer: tf.train.adam(0.06)});
                
                const metrics = ['loss', 'accuracy'];
                
                const container = {name: 'Model Training', 
                                   styles: {height: '640px'},
                                   tab: 'Training Progress'};
                
                const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);
            
                await model.fitDataset(convertedData,
                                       {epochs: 50,
                                        callbacks: fitCallbacks});
                
    
            
            }
            
            run();
            
        script>
        
    <body>
        <h1>Iris Classifierh1>
    body>
    html>
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62

    运行之后,会有如下结果:

    在这里插入图片描述


    Training with the MNIST Dataset

    我们先如《在浏览器中运行 TensorFlow.js 来训练模型并给出预测结果(Iris 数据集)》一样新建一个项目,并在保存项目的本地路径复制一个 script.js 副本,将其命名为 data.js。下面列出的代码都会放在 data.js 中。当然,如果不想创建一个新文件,直接将下面的内容放在 index.html 中的

    <script lang="js">
    
    script>
    
    • 1
    • 2
    • 3

    也完全没有问题,但可能会使得 index.html 文件过于臃肿。

    在 TensorFlow.js 中,处理数据训练的一个特殊的方法是将所有的图像附加在一起,成为一个单一的图像,通常称为 sprite sheet,而不是逐个下载每个图像。这种技术通常在游戏开发中使用,游戏的图形被存储在一个文件中,而不是多个小文件,以提高文件存储效率。如果我们把训练用的所有图片都存储在一个文件中,我们只需要打开一个 HTTP 连接,就可以一次性下载所有图片。例如,MNIST 的 sprite sheet 如下图所示:

    在这里插入图片描述
    这幅图片的维度为 65000×784(28×28),也就是说,我们只需逐行读取该图片文件,就能得到一张张 28×28 像素的图片。

    我们可以在 JavaScript 中先将图像加载,然后定义一个画布(canvas),在从原始图像中提取出各个“线条”(行)后,在画布上画出这些“线条”。然后,画布上的字节可以被提取到一个数据集中用于训练。下面我们看具体流程:

    训练集测试集比例为 5:1

    const IMAGE_SIZE = 784;
    const NUM_CLASSES = 10;
    const NUM_DATASET_ELEMENTS = 65000;
            
    const TRAIN_TEST_RATIO = 5/6;
            
    const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
            
    const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    定义 canvas:

    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    
    • 1
    • 2
    • 3

    图片地址:

    img.src = "https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png";
    
    • 1

    一旦图像被加载,我们就可以设置一个 buffer 来容纳其中的字节。该图像是一个 PNG 文件,每个像素有 4 个字节,所以需要为 buffer 预留 65,000×768×4 个字节。我们不需要以逐个图像的方式来读取文件,而是可以分块(chunks)读取。通过指定 chunkSize,我们可以一次取五千张图片:

    img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;
                
        const datasetBytesBuffer = 
              new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
                
         const chunkSize = 5000;
         canvas.width = img.width;
         anvas.height = chunkSize;
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    下面我们通过一个 for 循环来将图片读入 buffer 中,因为图片为灰度图,所以 R\G\B 三个通道的值都是一样的,我们任意选择其中之一:

    for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
        const datasetBytesView = new Float32Array(
            datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
            IMAGE_SIZE * chunkSize);
                
        ctx.drawImage(
            img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize);
                    
        const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
                    
        for (let j = 0; j < imageData.data.length / 4; j++) {
            datasetBytesView[j] = imageData.data[j * 4] / 255;
        }
    }
                
    this.datasetImages = new Float32Array(datasetBytesBuffer);
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    和图片类似,标签也是被存储在一个单独的文件中。这是一个二进制文件,对标签进行了稀疏编码。每个标签由 10 个字节表示,其中一个字节的值为 01,代表某个类别。因此,除了逐行下载和解码图像的字节外,我们还需要对标签进行解码。我们使用 arrayBuffer 将标签解码成整数数组。

    const labelsRequest = fetch("https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8");
            
    const [imgResponse, labelsResponse] = 
           await Promise.all([imgRequest, labelsRequest]);
            
    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    然后我们就可以划分训练集和测试集:

    this.trainImages = 
    	this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = 
    	this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
            
    this.trainLabels = 
    	this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels = 
    	this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    和常规流程一样,我们也可以对数据集进行分批打包(batch):

    nextBatch(batchSize, data, index) {
    	const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    	const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);
    	
    	for (let i = 0; i < batchSize; i++) {
    		const idx = index();
    		
    		const image =
    			data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
    		batchImagesArray.set(image, i * IMAGE_SIZE);
    		
    		const label =
    			data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
    		batchLabelsArray.set(label, i * NUM_CLASSES);
    	}
    	
    	const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    	const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);
    	
    	return {xs, labels};
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    然后,训练数据可以使用下面这个批处理函数来返回所需批次大小的且打乱顺序的训练批次:

    nextTrainBatch(batchSize) {
    	return this.nextBatch(
    		batchSize, [this.trainImages, this.trainLabels], () => {
    			this.shuffledTrainIndex =
    				(this.shuffledTrainIndex + 1) % this.trainIndices.length;
    			return this.trainIndices[this.shuffledTrainIndex];
    		});
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    测试集数据的处理方式是完全一样的,我们在下面的完整代码中给出。

    整个 data.js 文件中代码如下:我们定义了一个类 MnistData 用来封装我们刚刚定义的所有函数方法。之后在 index.html 中,我们直接从 data.js 文件中导入该类,并进行实例化 const data = new MnistData();,就可以直接调用类中的方法。

    const IMAGE_SIZE = 784;
    const NUM_CLASSES = 10;
    const NUM_DATASET_ELEMENTS = 65000;
    
    const TRAIN_TEST_RATIO = 5 / 6;
    
    const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
    const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
    
    const MNIST_IMAGES_SPRITE_PATH =
        'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
    const MNIST_LABELS_PATH =
        'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';
    
    export class MnistData {
      constructor() {
        this.shuffledTrainIndex = 0;
        this.shuffledTestIndex = 0;
      }
    
      async load() {
        // Make a request for the MNIST sprited image.
        const img = new Image();
        const canvas = document.createElement('canvas');
        const ctx = canvas.getContext('2d');
        const imgRequest = new Promise((resolve, reject) => {
          img.crossOrigin = '';
          img.onload = () => {
            img.width = img.naturalWidth;
            img.height = img.naturalHeight;
    
            const datasetBytesBuffer =
                new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
    
            const chunkSize = 5000;
            canvas.width = img.width;
            canvas.height = chunkSize;
    
            for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
              const datasetBytesView = new Float32Array(
                  datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
                  IMAGE_SIZE * chunkSize);
              ctx.drawImage(
                  img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
                  chunkSize);
    
              const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
    
              for (let j = 0; j < imageData.data.length / 4; j++) {
                // All channels hold an equal value since the image is grayscale, so
                // just read the red channel.
                datasetBytesView[j] = imageData.data[j * 4] / 255;
              }
            }
            this.datasetImages = new Float32Array(datasetBytesBuffer);
    
            resolve();
          };
          img.src = MNIST_IMAGES_SPRITE_PATH;
        });
    
        const labelsRequest = fetch(MNIST_LABELS_PATH);
        const [imgResponse, labelsResponse] =
            await Promise.all([imgRequest, labelsRequest]);
    
        this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
    
        // Create shuffled indices into the train/test set for when we select a
        // random dataset element for training / validation.
        this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
        this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);
    
        // Slice the the images and labels into train and test sets.
        this.trainImages =
            this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
        this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
        this.trainLabels =
            this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
        this.testLabels =
            this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
      }
    
      nextTrainBatch(batchSize) {
        return this.nextBatch(
            batchSize, [this.trainImages, this.trainLabels], () => {
              this.shuffledTrainIndex =
                  (this.shuffledTrainIndex + 1) % this.trainIndices.length;
              return this.trainIndices[this.shuffledTrainIndex];
            });
      }
    
      nextTestBatch(batchSize) {
        return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
          this.shuffledTestIndex =
              (this.shuffledTestIndex + 1) % this.testIndices.length;
          return this.testIndices[this.shuffledTestIndex];
        });
      }
    
      nextBatch(batchSize, data, index) {
        const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
        const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);
    
        for (let i = 0; i < batchSize; i++) {
          const idx = index();
    
          const image =
              data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
          batchImagesArray.set(image, i * IMAGE_SIZE);
    
          const label =
              data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
          batchLabelsArray.set(label, i * NUM_CLASSES);
        }
    
        const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
        const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);
    
        return {xs, labels};
      }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121

    下面定义的回调以及训练代码,我们之后会直接封装到 index.html 中的 train 函数当中。

    还记得我们刚刚使用的可视化回调吗?我们这里再将它定义出来:

    const metrics = ['loss', 'val_loss', 'accuracy', 'val_accuracy'];
    const container = { name: 'Model Training', styles: { height: '640px' },
    					tab: 'Training Progress' };
    const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);
    
    • 1
    • 2
    • 3
    • 4

    调用函数生成训练、测试数据集:

    const [trainXs, trainYs] = tf.tidy(() => {
    	const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    	return [
    		d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
    		d.labels
    	];
    });
    
    const [testXs, testYs] = tf.tidy(() => {
    	const d = data.nextTestBatch(TEST_DATA_SIZE);
    	return [
    		d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
    		d.labels
    	];
    });
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    注意这里 tf.tidy 的使用。在 TensorFlow.js 中,它将帮助我们清理所有中间张量,除了那些函数返回的张量。在使用 TensorFlow.js 时,这对防止浏览器中的内存泄漏至关重要。

    现在万事俱备,我们就可以进行训练啦!

    return model.fit(trainXs, trainYs, {
    	batchSize: BATCH_SIZE,
    	validationData: [testXs, testYs],
    	epochs: 20,
    	shuffle: true,
    	callbacks: fitCallbacks
    });
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    大家这时可能还跑不出下面的训练过程,别着急,我们一会给出 index.html 文件的完整代码,大家到时直接运行即可。

    在这里插入图片描述


    Running Inference on Images in TensorFlow.js

    推断时我们需要一张测试图片,我们可以直接创建一个画布对象,然后让用户使用鼠标在画布上写出要判断的数字:

    rawImage = document.getElementById('canvasimg');
    ctx = canvas.getContext("2d");
    ctx.fillStyle = "black";
    ctx.fillRect(0,0,280,280);
    
    • 1
    • 2
    • 3
    • 4

    在用户通过下面的 draw 函数写好数字之后:

    function draw(e) {
    	if(e.buttons!=1) return;
    	ctx.beginPath();
    	ctx.lineWidth = 24;
    	ctx.lineCap = 'round';
    	ctx.strokeStyle = 'white';
    	ctx.moveTo(pos.x, pos.y);
    	setPosition(e);
    	ctx.lineTo(pos.x, pos.y);
    	ctx.stroke();
    	rawImage.src = canvas.toDataURL('image/png');
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    我们从画布上抓取像素,并处理成模型可以处理的输入张量:

    var raw = tf.browser.fromPixels(rawImage,1);
    
    var resized = tf.image.resizeBilinear(raw, [28,28]);
    
    var tensor = resized.expandDims(0);
    
    • 1
    • 2
    • 3
    • 4
    • 5

    之后我们就可以进行预测:

    var prediction = model.predict(tensor);
    var pIndex = tf.argMax(prediction, 1).dataSync();
    
    • 1
    • 2

    index.thml 文件完整代码:

    <html>
    <head>
        <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest">script>
        <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis">script>
    
    head>
    <body>
        <h1>Handwriting Classifier!h1>
        <canvas id="canvas" width="280" height="280" style="position:absolute;top:100;left:100;border:8px solid;">canvas>
        <img id="canvasimg" style="position:absolute;top:10%;left:52%;width=280;height=280;display:none;">
        <input type="button" value="classify" id="sb" size="48" style="position:absolute;top:400;left:100;">
        <input type="button" value="clear" id="cb" size="23" style="position:absolute;top:400;left:180;">
        <script src="data.js" type="module">
        script>
        
        
    body>
        
        <script type="module">
            
            import {MnistData} from './data.js';
            var canvas, ctx, saveButton, clearButton;
            var pos = {x:0, y:0};
            var rawImage;
            var model;
    	
            function getModel() {
    	       model = tf.sequential();
    
    	       model.add(tf.layers.conv2d({inputShape: [28, 28, 1], kernelSize: 3, filters: 8, activation: 'relu'}));
    	       model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
    	       model.add(tf.layers.conv2d({filters: 16, kernelSize: 3, activation: 'relu'}));
    	       model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
    	       model.add(tf.layers.flatten());
    	       model.add(tf.layers.dense({units: 128, activation: 'relu'}));
    	       model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
    
    	       model.compile({optimizer: tf.train.adam(), loss: 'categoricalCrossentropy', metrics: ['accuracy']});
    
    	       return model;
            }
    
            async function train(model, data) {
    	       const metrics = ['loss', 'val_loss', 'accuracy', 'val_accuracy'];
    	       const container = { name: 'Model Training', styles: { height: '640px' }, tab: 'Training Progress'};
    	       const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);
      
    	       const BATCH_SIZE = 512;
    	       const TRAIN_DATA_SIZE = 5500;
    	       const TEST_DATA_SIZE = 1000;
    
    	       const [trainXs, trainYs] = tf.tidy(() => {
    		      const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    		      return [
    			     d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
    			     d.labels
    		      ];
    	       });
    
    	       const [testXs, testYs] = tf.tidy(() => {
    		      const d = data.nextTestBatch(TEST_DATA_SIZE);
    		      return [
                    d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
    			    d.labels
    		      ];
    	       });
    
    	       return model.fit(trainXs, trainYs, {
    		      batchSize: BATCH_SIZE,
    		      validationData: [testXs, testYs],
    		      epochs: 20,
    		      shuffle: true,
    		      callbacks: fitCallbacks
    	       });
            }
    
            function setPosition(e){
    	       pos.x = e.clientX-100;
    	       pos.y = e.clientY-100;
            }
        
            function draw(e) {
    	       if(e.buttons!=1) return;
    	       ctx.beginPath();
    	       ctx.lineWidth = 24;
    	       ctx.lineCap = 'round';
    	       ctx.strokeStyle = 'white';
    	       ctx.moveTo(pos.x, pos.y);
    	       setPosition(e);
    	       ctx.lineTo(pos.x, pos.y);
    	       ctx.stroke();
    	       rawImage.src = canvas.toDataURL('image/png');
            }
        
            function erase() {
    	       ctx.fillStyle = "black";
    	       ctx.fillRect(0,0,280,280);
            }
        
            function save() {
    	       var raw = tf.browser.fromPixels(rawImage,1);
    	       var resized = tf.image.resizeBilinear(raw, [28,28]);
    	       var tensor = resized.expandDims(0);
               var prediction = model.predict(tensor);
               var pIndex = tf.argMax(prediction, 1).dataSync();
        
    	       alert(pIndex);
            }
        
            function init() {
    	       canvas = document.getElementById('canvas');
    	       rawImage = document.getElementById('canvasimg');
    	       ctx = canvas.getContext("2d");
    	       ctx.fillStyle = "black";
    	       ctx.fillRect(0,0,280,280);
    	       canvas.addEventListener("mousemove", draw);
    	       canvas.addEventListener("mousedown", setPosition);
    	       canvas.addEventListener("mouseenter", setPosition);
    	       saveButton = document.getElementById('sb');
    	       saveButton.addEventListener("click", save);
    	       clearButton = document.getElementById('cb');
    	       clearButton.addEventListener("click", erase);
            }
    
    
            async function run() {  
    	       const data = new MnistData();
    	       await data.load();
    	       const model = getModel();
    	       tfvis.show.modelSummary({name: 'Model Architecture'}, model);
    	       await train(model, data);
    	       init();
    	       alert("Training is done, try classifying your handwriting!");
            }
            
            run();
        
        script>
        
        
    
    html>
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142

    当运行之后,我们等模型训练完毕,右侧就会出现下面的界面:
    在这里插入图片描述

    我们直接用鼠标在黑色画布上随便写一个数字,然后点击 classify:

    在这里插入图片描述
    在这里插入图片描述


    References

    AI and Machine Learning for Coders by Laurence Moroney.

  • 相关阅读:
    OSPF协议详解及快速学会(OSPF协议与RIP协议对比,区域划分,数据包类型,OSPF的状态机,工作过程,基础配置,扩展配置,条件匹配)
    YoloV8改进策略:聚焦线性注意力重构YoloV8
    2023年9月 少儿编程 中国电子学会图形化编程等级考试Scratch编程一级真题解析(选择题)
    MegEngine Inference 卷积优化之 Im2col 和 winograd 优化
    HTML5期末考核大作业,电影网站——橙色国外电影 web期末作业设计网页
    关闭jupyter notebook报错
    阿里云K8S部署Go+Vue项目
    CentOS8结束生命周期后如何切换镜像源
    PHP基于thinkphp的网上图书管理系统#毕业设计
    Linux系统中如何查看磁盘情况
  • 原文地址:https://blog.csdn.net/myDarling_/article/details/128159555