• keras:callback 专属定制!来实现一个自己的 callback 吧!


    我们在一个深度模型的训练中经常会用到回调函数来对训练过程进行监测,使得训练过程更加智能化。

    例如,我们经常使用的早停机制:

    from tensorflow.keras.callbacks import EarlyStopping
    
    early_stop = EarlyStopping(monitor='val_loss', 
                               mode='min', 
                               patience=10, 
                               restore_best_weights=True, 
                               verbose=1)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    通过监测验证误差的变化趋势,我们可以在验证误差不再增长的时候提前结束训练。

    另一个与 EarlyStopping 常常配合使用的是 ReduceLROnPlateau,当指定的训练误差或者验证误差在指定的轮次以内不再增长的时候,我们将学习率根据设置的衰减系数 factor 自动降低:

    from keras.callbacks import ReduceLROnPlateau
    
    learning_rate_reduction = ReduceLROnPlateau(monitor='val_acc', 
                                                patience=3, 
                                                verbose=1, 
                                                factor=0.5, 
                                                min_lr=0.00001)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    keras 的回调 API 包含许多不同功能用途的回调函数,通常这些回调就可以满足我们的需求了。但如果我们想要更加精细的控制训练过程,可能需要写一个自己的回调。

    我们接下来就实现一个,可以在训练过程中自由控制训练轮次和学习率的回调。这个回调的功能主要是:

    • 在指定轮次之后询问使用者,是否继续训练,如果继续训练,键入继续训练的轮次,并选择保持或者改变当前学习率
    • 如果验证误差增加,则自动调整学习率,且模型加载当前最优的权重
    • 训练结束后,直接让模型加载最优权重

    我们需要定义一个类,这个类继承 keras.callbacks.Callback,然后做一些初始化:

    class My_ASK(keras.callbacks.Callback):
        def __init__(self, model, epochs, ask_epoch, dwell=True, factor=.4):
            super(My_ASK, self).__init__()
            self.model = model
            
            """
    			模型在训练 ask_epoch 之后,会让使用者选择是暂停训练还是继续训练,
    			如果继续训练,则直接输入一个整数,表明继续训练的轮次,且会给我们
    			修改学习率的机会
    		"""
            
            self.ask_epoch = ask_epoch
            
            self.epochs = epochs
            self.ask = True # 将 ask 设为 True 才会有上面 ask_epoch 描述的询问
            self.lowest_vloss = np.inf
            self.lowest_loss = np.inf
            self.best_weights = self.model.get_weights() # 最优权重初始化为模型的初始权重
            self.best_epoch = 1
            self.vlist = [] # 存储验证误差变化的列表
            self.tlist = [] # 存储训练误差变化的列表
            self.dwell = dwell
            self.factor = factor # 学习率衰减系数
            
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    通常一个回调中的方法有 on_train_begin, on_train_end, on_epoch_end, on_epoch_begin 等,它们并不是需要全部定义,我们可以根据自己的实际需求进行选择。我们定义的这个类就只使用了 on_train_begin, on_train_end, on_epoch_end 三种方法。我们来看看这三种方法都具体做了些什么。

    训练开始时,会给我们报告一些参数设置的情况,提示我们模型的训练流程,同时启动计时器。

        def on_train_begin(self, logs = None):
            if self.ask_epoch == 0:
                print('You set ask_epoch = 0, ask_epoch will be set to 1', flush = True)
                self.ask_epoch = 1
            if self.ask_epoch >= self.epochs: # 如果设置的 ask_epoch 比 epochs 还大,那就没有意义了
                print('ask_epoch >= epochs, will train for ', epochs, ' epochs', flush=True)
                self.ask = False
            if self.epochs == 1:
                self.ask = False
            else:
                 
                print(f'Training will proceed until epoch {ask_epoch} then you will be asked to')
                print('enter H to halt training or enter an integer for how many more epochs to run then be asked again')
                
                if self.dwell:
                    print('\n Learning rate will be automatically adjusted during training')
                    
            self.start_time = time.time() # 开始计时
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    训练结束后,模型会加载最优权重,并返回训练的总时间。

        def on_train_end(self, logs=None):
            print(f'Loading model with weights from epoch {self.best_epoch}')
            
            self.model.set_weights(self.best_weights)
            train_duration = time.time() - self.start_time
            hours = train_duration // 3600
            minutes = (train_duration - hours * 3600) // 60
            seconds = train_duration - hours * 3600 - minutes * 60
    
            print(f'Training using {str(hours)} hours, {minutes:4.1f} minutes, {seconds:4.2f} seconds')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    可以看到,训练开始和训练结束的方法内容非常简单,如果不考虑可读性,那么省略不写也不会有太大影响。重点是下面的 on_epoch_end 方法。注释以及代码打印的内容已经很详细了,一行一行看下去肯定是没有问题的,这里不再过多解释。

        def on_epoch_end(self, epoch, logs=None):
            val_loss = logs.get('val_loss')
            loss = logs.get('loss')
            if epoch > 0:
                delta_v = self.lowest_vloss - val_loss # 该轮次的验证损失和最低验证损失的差值
                vimprov = (delta_v / self.lowest_vloss) * 100 # percentage of improvement,当然也有可能是负数,表示误差增高了
                self.vlist.append(vimprov)
                
                delta_t = self.lowest_loss - loss
                timprov = (delta_t / self.lowest_loss) * 100
                self.tlist.append(timprov)
            else:
                vimprov = 0.0
                timprov = 0.0
            
            if val_loss < self.lowest_vloss:
                self.lowest_vloss = val_loss # 更新最低验证误差
                self.best_weights = self.model.get_weights() # 以及相应的权重
                self.best_epoch = epoch + 1
                print(f'\n Validation loss of {val_loss:7.4f} is {vimprov:7.4f} % below lowest loss, saving weights from epoch {str(epoch + 1):3s} as best weights')
            else:
                vimprov = abs(vimprov)
                print(f'\n Validation loss of {val_loss:7.4f} is {vimprov:7.4f} % above lowest loss of {self.lowest_vloss:7.4f}. Keeping weights from epoch {str(self.best_epoch)} as best weights')
                
                if self.dwell:
                    lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
                    new_lr = lr * self.factor
                    print(f'\n Learning rate was automatically adjusted from {lr:8.6f} to {new_lr:8.6f}, model weights set to best weights')
                    
                    tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
                    self.model.set_weights(self.best_weights) # 在新的学习率基础上,看模型在最优权重上表现如何
            
            if loss < self.lowest_loss:
                self.lowest_loss = loss
                
            if self.ask:
                if epoch + 1 == self.ask_epoch:
                    print('\n Enter H to end training or an integer for the number of additional epochs to run then ask again')
                    ans = input()
                    
                    if ans == 'H' or ans == 'h' or ans == '0': # 放弃训练
                        self.model.stop_training = True
                    else:
                        self.ask_epoch += int(ans) # 在第 ask_epoch+ans 轮次再次询问
                        if self.ask_epoch > self.epochs:
                            print('\n Your specification exceeds ', self.epochs, ' cannot train for ', self.ask_epoch, flush =True)
                        else:
                            print(f'\n You entered {ans}. Training will continue to epoch {self.ask_epoch}')
                            
                            if self.dwell == False:
                                lr=float(tf.keras.backend.get_value(self.model.optimizer.lr)) 
                                print(f'\n Current LR is  {lr:8.6f}  hit enter to keep  this LR or enter a new LR')
                                
                                ans = input(' ')
                                if ans == '':
                                    print(f'\n Keeping current LR of {lr:7.5f}')
                                    
                                else:
                                    new_lr = float(ans)
                                    tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
                                    print(f'\n Changing LR to {ans}')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61

    事实上,这个回调实现的功能与 keras 本身含有的回调可能有相似部分,但重点在于理解一个 callback 的自定义过程。

    最后,我们实例化这个回调,并添加到回调列表中。

    epochs = 50
    ask_epoch = 10
    ask = My_ASK(model, epochs, ask_epoch)
    callbacks = [ask]
    
    • 1
    • 2
    • 3
    • 4
  • 相关阅读:
    子组件监听父组件消息,随之变化与不变化
    redis爆满导致数据丢失
    代码随想录二刷 day01 | 704. 二分查找 27. 移除元素 26.删除有序数组中的重复项 80. 删除有序数组中的重复项 II
    【ADB】借助ADB模拟滑动屏幕,并进行循环
    【Java】线程池源码解析
    选低代码开发的OA系统,对低效办公说“漏”
    【信号去噪】基于Sage-Husa自适应卡尔曼滤波器实现海浪磁场噪声抑制及海浪磁场噪声的产生附matlab代码
    C#:实现鸡尾酒定向冒泡排序算法(附完整源码)
    Netty笔记
    博客自动化测试
  • 原文地址:https://blog.csdn.net/myDarling_/article/details/128056116