• TensorFlow - Custom callbacks


    Last updated: 2022-09-20, 15:05


    Introduction

    Callbacks are a powerful tool to customize the behavior of a Keras model during training, evaluation, or inference. For example, tf.keras.callbacks.TensorBoard visualizes training progress and results with TensorBoard, while tf.keras.callbacks.ModelCheckpoint periodically saves the model during training.

    This guide explains what a callback is, what it can do, and how to build your own.

    Setup

    import tensorflow as tf
    from tensorflow import keras
    

    Overview of Keras callbacks

    All callbacks subclass the keras.callbacks.Callback class and override a set of methods that are called automatically at various stages of training, testing, and predicting.

    You can pass a list of callbacks (via the callbacks keyword argument) to any of the following model methods; a short usage sketch follows the list:

    • keras.Model.fit()
    • keras.Model.evaluate()
    • keras.Model.predict()

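    For example, a minimal sketch of passing the two built-in callbacks mentioned above to fit. The log directory and checkpoint path are arbitrary placeholders, not values from the original text:

    callbacks = [
        # Visualize training progress in TensorBoard (log_dir is an illustrative path)
        tf.keras.callbacks.TensorBoard(log_dir="./logs"),
        # Periodically save the model (filepath is an illustrative path)
        tf.keras.callbacks.ModelCheckpoint(filepath="./ckpt/model-{epoch:02d}.h5"),
    ]
    # model.fit(x_train, y_train, batch_size=128, epochs=2, callbacks=callbacks)
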
    An overview of callback methods

    Global methods

    • on_(train|test|predict)_begin(self, logs=None)

    Called at the beginning of fit/evaluate/predict.

    • on_(train|test|predict)_end(self, logs=None)

    Called at the end of fit/evaluate/predict.

    Batch-level methods

    • on_(train|test|predict)_batch_begin(self, batch, logs=None)

    Called right before processing a batch during training/testing/predicting.

    • on_(train|test|predict)_batch_end(self, batch, logs=None)

    Called at the end of a batch during training/testing/predicting; logs is a dict containing the metrics results.

    Epoch-level methods (training only)

    • on_epoch_begin(self, epoch, logs=None)

    Called at the beginning of an epoch during training.

    • on_epoch_end(self, epoch, logs=None)

    Called at the end of an epoch during training.

    A basic example

    Let's go through a concrete example to see what callbacks can do. First, define a simple Keras Sequential model:

    def get_model():
        model = keras.Sequential()
        model.add(keras.layers.Dense(1, input_dim=784))
        model.compile(
            optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
            loss="mean_squared_error",
            metrics=["mean_absolute_error"],
        )
        return model
    

    Then, load the MNIST dataset:

    # Load the MNIST dataset and preprocess it
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
    x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
    
    # Limit the data to the first 1000 samples
    x_train = x_train[:1000]
    y_train = y_train[:1000]
    x_test = x_test[:1000]
    y_test = y_test[:1000]
    

    Now, define a simple custom callback that logs:

    • when fit/evaluate/predict begins and ends
    • when each epoch begins and ends
    • when each training batch begins and ends
    • when each evaluation (test) batch begins and ends
    • when each prediction batch begins and ends

    class CustomCallback(keras.callbacks.Callback):
        def on_train_begin(self, logs=None):
            keys = list(logs.keys())
            print("Starting training; got log keys: {}".format(keys))
    
        def on_train_end(self, logs=None):
            keys = list(logs.keys())
            print("Stop training; got log keys: {}".format(keys))
    
        def on_epoch_begin(self, epoch, logs=None):
            keys = list(logs.keys())
            print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
    
        def on_epoch_end(self, epoch, logs=None):
            keys = list(logs.keys())
            print("End epoch {} of training; got log keys: {}".format(epoch, keys))
    
        def on_test_begin(self, logs=None):
            keys = list(logs.keys())
            print("Start testing; got log keys: {}".format(keys))
    
        def on_test_end(self, logs=None):
            keys = list(logs.keys())
            print("Stop testing; got log keys: {}".format(keys))
    
        def on_predict_begin(self, logs=None):
            keys = list(logs.keys())
            print("Start predicting; got log keys: {}".format(keys))
    
        def on_predict_end(self, logs=None):
            keys = list(logs.keys())
            print("Stop predicting; got log keys: {}".format(keys))
    
        def on_train_batch_begin(self, batch, logs=None):
            keys = list(logs.keys())
            print("...Training: start of batch {}; got log keys: {}".format(batch, keys))
    
        def on_train_batch_end(self, batch, logs=None):
            keys = list(logs.keys())
            print("...Training: end of batch {}; got log keys: {}".format(batch, keys))
    
        def on_test_batch_begin(self, batch, logs=None):
            keys = list(logs.keys())
            print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))
    
        def on_test_batch_end(self, batch, logs=None):
            keys = list(logs.keys())
            print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))
    
        def on_predict_batch_begin(self, batch, logs=None):
            keys = list(logs.keys())
            print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))
    
        def on_predict_batch_end(self, batch, logs=None):
            keys = list(logs.keys())
            print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
    

    Try out this callback:

    model = get_model()
    model.fit(
        x_train,
        y_train,
        batch_size=128,
        epochs=1,
        verbose=0,
        validation_split=0.5,
        callbacks=[CustomCallback()],
    )
    
    res = model.evaluate(
        x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
    )
    
    res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
    
    Starting training; got log keys: []
    Start epoch 0 of training; got log keys: []
    ...Training: start of batch 0; got log keys: []
    ...Training: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
    ...Training: start of batch 1; got log keys: []
    ...Training: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
    ...Training: start of batch 2; got log keys: []
    ...Training: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
    ...Training: start of batch 3; got log keys: []
    ...Training: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
    Start testing; got log keys: []
    ...Evaluating: start of batch 0; got log keys: []
    ...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
    ...Evaluating: start of batch 1; got log keys: []
    ...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
    ...Evaluating: start of batch 2; got log keys: []
    ...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
    ...Evaluating: start of batch 3; got log keys: []
    ...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
    Stop testing; got log keys: ['loss', 'mean_absolute_error']
    End epoch 0 of training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
    Stop training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
    Start testing; got log keys: []
    ...Evaluating: start of batch 0; got log keys: []
    ...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
    ...Evaluating: start of batch 1; got log keys: []
    ...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
    ...Evaluating: start of batch 2; got log keys: []
    ...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
    ...Evaluating: start of batch 3; got log keys: []
    ...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
    ...Evaluating: start of batch 4; got log keys: []
    ...Evaluating: end of batch 4; got log keys: ['loss', 'mean_absolute_error']
    ...Evaluating: start of batch 5; got log keys: []
    ...Evaluating: end of batch 5; got log keys: ['loss', 'mean_absolute_error']
    ...Evaluating: start of batch 6; got log keys: []
    ...Evaluating: end of batch 6; got log keys: ['loss', 'mean_absolute_error']
    ...Evaluating: start of batch 7; got log keys: []
    ...Evaluating: end of batch 7; got log keys: ['loss', 'mean_absolute_error']
    Stop testing; got log keys: ['loss', 'mean_absolute_error']
    Start predicting; got log keys: []
    ...Predicting: start of batch 0; got log keys: []
    ...Predicting: end of batch 0; got log keys: ['outputs']
    1/8 [==>...........................] - ETA: 0s...Predicting: start of batch 1; got log keys: []
    ...Predicting: end of batch 1; got log keys: ['outputs']
    ...Predicting: start of batch 2; got log keys: []
    ...Predicting: end of batch 2; got log keys: ['outputs']
    ...Predicting: start of batch 3; got log keys: []
    ...Predicting: end of batch 3; got log keys: ['outputs']
    ...Predicting: start of batch 4; got log keys: []
    ...Predicting: end of batch 4; got log keys: ['outputs']
    ...Predicting: start of batch 5; got log keys: []
    ...Predicting: end of batch 5; got log keys: ['outputs']
    WARNING:tensorflow:Callback method `on_predict_batch_end` is slow compared to the batch time (batch time: 0.0004s vs `on_predict_batch_end` time: 0.0005s). Check your callbacks.
    ...Predicting: start of batch 6; got log keys: []
    ...Predicting: end of batch 6; got log keys: ['outputs']
    ...Predicting: start of batch 7; got log keys: []
    ...Predicting: end of batch 7; got log keys: ['outputs']
    Stop predicting; got log keys: []
    8/8 [==============================] - 0s 1ms/step
    

    Usage of the logs dict

    The logs dict contains the loss value, plus all metrics at the end of a batch or epoch. For the example above, that means the loss and mean_absolute_error.

    class LossAndErrorPrintingCallback(keras.callbacks.Callback):
        def on_train_batch_end(self, batch, logs=None):
            print(
                "Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
            )
    
        def on_test_batch_end(self, batch, logs=None):
            print(
                "Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
            )
    
        def on_epoch_end(self, epoch, logs=None):
            print(
                "The average loss for epoch {} is {:7.2f} "
                "and mean absolute error is {:7.2f}.".format(
                    epoch, logs["loss"], logs["mean_absolute_error"]
                )
            )
    
    
    model = get_model()
    model.fit(
        x_train,
        y_train,
        batch_size=128,
        epochs=2,
        verbose=0,
        callbacks=[LossAndErrorPrintingCallback()],
    )
    
    res = model.evaluate(
        x_test,
        y_test,
        batch_size=128,
        verbose=0,
        callbacks=[LossAndErrorPrintingCallback()],
    )
    
    Up to batch 0, the average loss is   31.56.
    Up to batch 1, the average loss is  453.13.
    Up to batch 2, the average loss is  309.70.
    Up to batch 3, the average loss is  234.35.
    Up to batch 4, the average loss is  189.27.
    Up to batch 5, the average loss is  158.79.
    Up to batch 6, the average loss is  137.01.
    Up to batch 7, the average loss is  123.39.
    The average loss for epoch 0 is  123.39 and mean absolute error is    6.04.
    Up to batch 0, the average loss is    5.41.
    Up to batch 1, the average loss is    5.40.
    Up to batch 2, the average loss is    5.18.
    Up to batch 3, the average loss is    5.04.
    Up to batch 4, the average loss is    4.88.
    Up to batch 5, the average loss is    4.73.
    Up to batch 6, the average loss is    4.60.
    Up to batch 7, the average loss is    4.75.
    The average loss for epoch 1 is    4.75 and mean absolute error is    1.76.
    Up to batch 0, the average loss is    5.27.
    Up to batch 1, the average loss is    4.90.
    Up to batch 2, the average loss is    4.91.
    Up to batch 3, the average loss is    4.80.
    Up to batch 4, the average loss is    4.90.
    Up to batch 5, the average loss is    4.97.
    Up to batch 6, the average loss is    4.90.
    Up to batch 7, the average loss is    4.85.
    

    Usage of the self.model attribute

    In addition to receiving log information when one of their methods is called, callbacks can access the model associated with the current round of training/evaluation/inference via self.model.

    There is a lot you can do with self.model in a callback, for example:

    • Set self.model.stop_training = True to immediately interrupt training.
    • Mutate hyperparameters of the optimizer (available as self.model.optimizer), such as self.model.optimizer.learning_rate.
    • Save the model at regular intervals.
    • Record the output of model.predict() on a few test samples at the end of each epoch, as a sanity check during training (see the sketch after this list).
    • Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
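
    For instance, here is a minimal sketch of the prediction-logging idea from the list above. The callback name, the three-sample slice, and the reuse of x_test are illustrative assumptions rather than part of the original guide:

    class PredictionLogger(keras.callbacks.Callback):
        """Logs model.predict() on a small fixed set of samples after each epoch (sketch)."""

        def __init__(self, samples):
            super().__init__()
            self.samples = samples  # a small, fixed batch of inputs to monitor

        def on_epoch_end(self, epoch, logs=None):
            # self.model is the model that fit() attached this callback to
            preds = self.model.predict(self.samples, verbose=0)
            print("Epoch {}: predictions on monitored samples: {}".format(
                epoch, preds.flatten()))


    # Usage sketch:
    # model.fit(x_train, y_train, epochs=2, verbose=0,
    #           callbacks=[PredictionLogger(x_test[:3])])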

    The sections below walk through a few small examples.

    Examples of Keras callback applications

    Early stopping at minimum loss

    This first example creates a Callback that stops training once the loss reaches its minimum, by setting the boolean attribute self.model.stop_training. A patience argument specifies how many epochs to wait after the local minimum has been reached before actually stopping.

    tf.keras.callbacks.EarlyStopping provides a complete and general implementation of this behavior. Here is a simple version:

    import numpy as np
    
    
    class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
        """Stop training when the loss is at its min, i.e. when the loss stops decreasing.

        Arguments:
            patience: Number of epochs to wait after the minimum has been hit.
                If there is no improvement after this many epochs, training stops.
        """

        def __init__(self, patience=0):
            super(EarlyStoppingAtMinLoss, self).__init__()
            self.patience = patience
            # best_weights stores the weights at which the minimum loss occurs
            self.best_weights = None

        def on_train_begin(self, logs=None):
            # The number of epochs waited since the loss was last at a minimum
            self.wait = 0
            # The epoch at which training stopped
            self.stopped_epoch = 0
            # Initialize the best loss as infinity
            self.best = np.inf

        def on_epoch_end(self, epoch, logs=None):
            current = logs.get("loss")
            if np.less(current, self.best):  # the current loss is lower
                self.best = current
                self.wait = 0
                # The current result is better, so record the best weights
                self.best_weights = self.model.get_weights()
            else:
                self.wait += 1
                if self.wait >= self.patience:
                    self.stopped_epoch = epoch
                    self.model.stop_training = True
                    print("Restoring model weights from the end of the best epoch.")
                    self.model.set_weights(self.best_weights)

        def on_train_end(self, logs=None):
            if self.stopped_epoch > 0:
                print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
    
    
    model = get_model()
    model.fit(
        x_train,
        y_train,
        batch_size=64,
        steps_per_epoch=5,
        epochs=30,
        verbose=0,
        callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
    )
    
    Up to batch 0, the average loss is   26.43.
    Up to batch 1, the average loss is  456.55.
    Up to batch 2, the average loss is  310.41.
    Up to batch 3, the average loss is  235.11.
    Up to batch 4, the average loss is  189.38.
    The average loss for epoch 0 is  189.38 and mean absolute error is    8.06.
    Up to batch 0, the average loss is    6.38.
    Up to batch 1, the average loss is    7.21.
    Up to batch 2, the average loss is    6.38.
    Up to batch 3, the average loss is    6.18.
    Up to batch 4, the average loss is    5.99.
    The average loss for epoch 1 is    5.99 and mean absolute error is    1.97.
    Up to batch 0, the average loss is    5.09.
    Up to batch 1, the average loss is    5.96.
    Up to batch 2, the average loss is    5.45.
    Up to batch 3, the average loss is    5.63.
    Up to batch 4, the average loss is    6.12.
    The average loss for epoch 2 is    6.12 and mean absolute error is    2.05.
    Restoring model weights from the end of the best epoch.
    Epoch 00003: early stopping
    
    
    
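    For comparison, a roughly equivalent setup using the built-in EarlyStopping callback (monitoring the training loss here, since the example above trains without validation data) might look like this:

    model.fit(
        x_train,
        y_train,
        batch_size=64,
        epochs=30,
        verbose=0,
        callbacks=[
            tf.keras.callbacks.EarlyStopping(
                monitor="loss",             # watch the training loss
                patience=2,                 # epochs to wait without improvement
                restore_best_weights=True,  # roll back to the best weights seen
            )
        ],
    )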

    Learning rate scheduling

    The next example uses a custom callback to dynamically change the learning rate of the optimizer during training.

    See tf.keras.callbacks.LearningRateScheduler for a more general implementation.

    class CustomLearningRateScheduler(keras.callbacks.Callback):
        """Learning rate scheduler which sets the learning rate according to a schedule.

        Arguments:
            schedule: a function that takes an epoch index (integer, 0-based) and the
                current learning rate as inputs and returns a new learning rate.
        """

        def __init__(self, schedule):
            super(CustomLearningRateScheduler, self).__init__()
            self.schedule = schedule

        def on_epoch_begin(self, epoch, logs=None):
            if not hasattr(self.model.optimizer, "learning_rate"):
                raise ValueError('Optimizer must have a "learning_rate" attribute.')
            # Get the current learning rate from the model's optimizer
            lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
            # Call the schedule function to get the scheduled learning rate
            scheduled_lr = self.schedule(epoch, lr)
            # Set the new learning rate on the optimizer before the epoch starts
            tf.keras.backend.set_value(self.model.optimizer.learning_rate, scheduled_lr)
            print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))
    
    
    LR_SCHEDULE = [
        # (epoch to start, learning rate) tuples
        (3, 0.05),
        (6, 0.01),
        (9, 0.005),
        (12, 0.001),
    ]
    
    
    def lr_schedule(epoch, lr):
        """Helper function to retrieve the scheduled learning rate based on epoch."""
        if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
            return lr
        for i in range(len(LR_SCHEDULE)):
            if epoch == LR_SCHEDULE[i][0]:
                return LR_SCHEDULE[i][1]
        return lr
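    
    
    # A quick trace of the helper with LR_SCHEDULE above:
    #   lr_schedule(2, 0.1)    -> 0.1   (epoch 2 is before the first boundary at 3)
    #   lr_schedule(3, 0.1)    -> 0.05  (epoch 3 matches the first boundary)
    #   lr_schedule(20, 0.001) -> 0.001 (epoch 20 is past the last boundary at 12)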
    
    
    model = get_model()
    model.fit(
        x_train,
        y_train,
        batch_size=64,
        steps_per_epoch=5,
        epochs=15,
        verbose=0,
        callbacks=[
            LossAndErrorPrintingCallback(),
            CustomLearningRateScheduler(lr_schedule),
        ],
    )
    
    Epoch 00000: Learning rate is 0.1000.
    Up to batch 0, the average loss is   22.56.
    Up to batch 1, the average loss is  461.56.
    Up to batch 2, the average loss is  316.79.
    Up to batch 3, the average loss is  240.73.
    Up to batch 4, the average loss is  193.91.
    The average loss for epoch 0 is  193.91 and mean absolute error is    8.27.
    
    Epoch 00001: Learning rate is 0.1000.
    Up to batch 0, the average loss is    7.08.
    Up to batch 1, the average loss is    6.19.
    Up to batch 2, the average loss is    5.66.
    Up to batch 3, the average loss is    5.80.
    Up to batch 4, the average loss is    5.68.
    The average loss for epoch 1 is    5.68 and mean absolute error is    1.98.
    
    Epoch 00002: Learning rate is 0.1000.
    Up to batch 0, the average loss is    4.53.
    Up to batch 1, the average loss is    4.83.
    Up to batch 2, the average loss is    5.58.
    Up to batch 3, the average loss is    5.21.
    Up to batch 4, the average loss is    4.94.
    The average loss for epoch 2 is    4.94 and mean absolute error is    1.74.
    
    Epoch 00003: Learning rate is 0.0500.
    Up to batch 0, the average loss is    4.56.
    Up to batch 1, the average loss is    3.98.
    Up to batch 2, the average loss is    3.46.
    Up to batch 3, the average loss is    3.46.
    Up to batch 4, the average loss is    3.60.
    The average loss for epoch 3 is    3.60 and mean absolute error is    1.52.
    
    Epoch 00004: Learning rate is 0.0500.
    Up to batch 0, the average loss is    4.21.
    Up to batch 1, the average loss is    3.87.
    Up to batch 2, the average loss is    3.64.
    Up to batch 3, the average loss is    3.88.
    Up to batch 4, the average loss is    3.95.
    The average loss for epoch 4 is    3.95 and mean absolute error is    1.59.
    
    Epoch 00005: Learning rate is 0.0500.
    Up to batch 0, the average loss is    3.51.
    Up to batch 1, the average loss is    3.24.
    Up to batch 2, the average loss is    3.24.
    Up to batch 3, the average loss is    3.71.
    Up to batch 4, the average loss is    3.92.
    The average loss for epoch 5 is    3.92 and mean absolute error is    1.54.
    
    Epoch 00006: Learning rate is 0.0100.
    Up to batch 0, the average loss is    4.53.
    Up to batch 1, the average loss is    3.79.
    Up to batch 2, the average loss is    3.44.
    Up to batch 3, the average loss is    3.76.
    Up to batch 4, the average loss is    3.53.
    The average loss for epoch 6 is    3.53 and mean absolute error is    1.43.
    
    Epoch 00007: Learning rate is 0.0100.
    Up to batch 0, the average loss is    2.59.
    Up to batch 1, the average loss is    2.86.
    Up to batch 2, the average loss is    3.10.
    Up to batch 3, the average loss is    3.18.
    Up to batch 4, the average loss is    3.02.
    The average loss for epoch 7 is    3.02 and mean absolute error is    1.35.
    
    Epoch 00008: Learning rate is 0.0100.
    Up to batch 0, the average loss is    3.65.
    Up to batch 1, the average loss is    3.06.
    Up to batch 2, the average loss is    3.39.
    Up to batch 3, the average loss is    3.73.
    Up to batch 4, the average loss is    3.75.
    The average loss for epoch 8 is    3.75 and mean absolute error is    1.52.
    
    Epoch 00009: Learning rate is 0.0050.
    Up to batch 0, the average loss is    3.25.
    Up to batch 1, the average loss is    3.25.
    Up to batch 2, the average loss is    3.08.
    Up to batch 3, the average loss is    3.10.
    Up to batch 4, the average loss is    3.13.
    The average loss for epoch 9 is    3.13 and mean absolute error is    1.37.
    
    Epoch 00010: Learning rate is 0.0050.
    Up to batch 0, the average loss is    3.79.
    Up to batch 1, the average loss is    3.78.
    Up to batch 2, the average loss is    3.80.
    Up to batch 3, the average loss is    4.04.
    Up to batch 4, the average loss is    3.75.
    The average loss for epoch 10 is    3.75 and mean absolute error is    1.49.
    
    Epoch 00011: Learning rate is 0.0050.
    Up to batch 0, the average loss is    3.70.
    Up to batch 1, the average loss is    3.28.
    Up to batch 2, the average loss is    3.46.
    Up to batch 3, the average loss is    3.21.
    Up to batch 4, the average loss is    3.10.
    The average loss for epoch 11 is    3.10 and mean absolute error is    1.39.
    
    Epoch 00012: Learning rate is 0.0010.
    Up to batch 0, the average loss is    2.31.
    Up to batch 1, the average loss is    2.39.
    Up to batch 2, the average loss is    2.69.
    Up to batch 3, the average loss is    2.76.
    Up to batch 4, the average loss is    2.87.
    The average loss for epoch 12 is    2.87 and mean absolute error is    1.32.
    
    Epoch 00013: Learning rate is 0.0010.
    Up to batch 0, the average loss is    2.27.
    Up to batch 1, the average loss is    3.57.
    Up to batch 2, the average loss is    3.77.
    Up to batch 3, the average loss is    3.76.
    Up to batch 4, the average loss is    3.64.
    The average loss for epoch 13 is    3.64 and mean absolute error is    1.44.
    
    Epoch 00014: Learning rate is 0.0010.
    Up to batch 0, the average loss is    3.00.
    Up to batch 1, the average loss is    3.14.
    Up to batch 2, the average loss is    3.01.
    Up to batch 3, the average loss is    3.16.
    Up to batch 4, the average loss is    3.25.
    The average loss for epoch 14 is    3.25 and mean absolute error is    1.41.
    
    
    
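    Note that the same lr_schedule helper can be passed directly to the built-in callback, which accepts a function with this (epoch, lr) signature. A minimal sketch:

    model.fit(
        x_train,
        y_train,
        batch_size=64,
        epochs=15,
        verbose=0,
        callbacks=[tf.keras.callbacks.LearningRateScheduler(lr_schedule)],
    )
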

    References

    • https://www.tensorflow.org/guide/keras/custom_callback
    • https://keras.io/guides/writing_your_own_callbacks/