keras:callback 专属定制!来实现一个自己的 callback 吧!
创始人
2024-02-17 13:26:23
0

我们在一个深度模型的训练中经常会用到回调函数来对训练过程进行监测,使得训练过程更加智能化。

例如,我们经常使用的早停机制:

from tensorflow.keras.callbacks import EarlyStoppingearly_stop = EarlyStopping(monitor='val_loss', mode='min', patience=10, restore_best_weights=True, verbose=1)

通过监测验证误差的变化趋势,我们可以在验证误差不再增长的时候提前结束训练。

另一个与 EarlyStopping 常常配合使用的是 ReduceLROnPlateau,当指定的训练误差或者验证误差在指定的轮次以内不再增长的时候,我们将学习率根据设置的衰减系数 factor 自动降低:

from keras.callbacks import ReduceLROnPlateaulearning_rate_reduction = ReduceLROnPlateau(monitor='val_acc', patience=3, verbose=1, factor=0.5, min_lr=0.00001)

keras 的回调 API 包含许多不同功能用途的回调函数,通常这些回调就可以满足我们的需求了。但如果我们想要更加精细的控制训练过程,可能需要写一个自己的回调。

我们接下来就实现一个,可以在训练过程中自由控制训练轮次和学习率的回调。这个回调的功能主要是:

  • 在指定轮次之后询问使用者,是否继续训练,如果继续训练,键入继续训练的轮次,并选择保持或者改变当前学习率
  • 如果验证误差增加,则自动调整学习率,且模型加载当前最优的权重
  • 训练结束后,直接让模型加载最优权重

我们需要定义一个类,这个类继承 keras.callbacks.Callback,然后做一些初始化:

class My_ASK(keras.callbacks.Callback):def __init__(self, model, epochs, ask_epoch, dwell=True, factor=.4):super(My_ASK, self).__init__()self.model = model"""模型在训练 ask_epoch 之后,会让使用者选择是暂停训练还是继续训练,如果继续训练,则直接输入一个整数,表明继续训练的轮次,且会给我们修改学习率的机会"""self.ask_epoch = ask_epochself.epochs = epochsself.ask = True # 将 ask 设为 True 才会有上面 ask_epoch 描述的询问self.lowest_vloss = np.infself.lowest_loss = np.infself.best_weights = self.model.get_weights() # 最优权重初始化为模型的初始权重self.best_epoch = 1self.vlist = [] # 存储验证误差变化的列表self.tlist = [] # 存储训练误差变化的列表self.dwell = dwellself.factor = factor # 学习率衰减系数

通常一个回调中的方法有 on_train_begin, on_train_end, on_epoch_end, on_epoch_begin 等,它们并不是需要全部定义,我们可以根据自己的实际需求进行选择。我们定义的这个类就只使用了 on_train_begin, on_train_end, on_epoch_end 三种方法。我们来看看这三种方法都具体做了些什么。

训练开始时,会给我们报告一些参数设置的情况,提示我们模型的训练流程,同时启动计时器。

    def on_train_begin(self, logs = None):if self.ask_epoch == 0:print('You set ask_epoch = 0, ask_epoch will be set to 1', flush = True)self.ask_epoch = 1if self.ask_epoch >= self.epochs: # 如果设置的 ask_epoch 比 epochs 还大,那就没有意义了print('ask_epoch >= epochs, will train for ', epochs, ' epochs', flush=True)self.ask = Falseif self.epochs == 1:self.ask = Falseelse:print(f'Training will proceed until epoch {ask_epoch} then you will be asked to')print('enter H to halt training or enter an integer for how many more epochs to run then be asked again')if self.dwell:print('\n Learning rate will be automatically adjusted during training')self.start_time = time.time() # 开始计时

训练结束后,模型会加载最优权重,并返回训练的总时间。

    def on_train_end(self, logs=None):print(f'Loading model with weights from epoch {self.best_epoch}')self.model.set_weights(self.best_weights)train_duration = time.time() - self.start_timehours = train_duration // 3600minutes = (train_duration - hours * 3600) // 60seconds = train_duration - hours * 3600 - minutes * 60print(f'Training using {str(hours)} hours, {minutes:4.1f} minutes, {seconds:4.2f} seconds')

可以看到,训练开始和训练结束的方法内容非常简单,如果不考虑可读性,那么省略不写也不会有太大影响。重点是下面的 on_epoch_end 方法。注释以及代码打印的内容已经很详细了,一行一行看下去肯定是没有问题的,这里不再过多解释。

    def on_epoch_end(self, epoch, logs=None):val_loss = logs.get('val_loss')loss = logs.get('loss')if epoch > 0:delta_v = self.lowest_vloss - val_loss # 该轮次的验证损失和最低验证损失的差值vimprov = (delta_v / self.lowest_vloss) * 100 # percentage of improvement,当然也有可能是负数,表示误差增高了self.vlist.append(vimprov)delta_t = self.lowest_loss - losstimprov = (delta_t / self.lowest_loss) * 100self.tlist.append(timprov)else:vimprov = 0.0timprov = 0.0if val_loss < self.lowest_vloss:self.lowest_vloss = val_loss # 更新最低验证误差self.best_weights = self.model.get_weights() # 以及相应的权重self.best_epoch = epoch + 1print(f'\n Validation loss of {val_loss:7.4f} is {vimprov:7.4f} % below lowest loss, saving weights from epoch {str(epoch + 1):3s} as best weights')else:vimprov = abs(vimprov)print(f'\n Validation loss of {val_loss:7.4f} is {vimprov:7.4f} % above lowest loss of {self.lowest_vloss:7.4f}. Keeping weights from epoch {str(self.best_epoch)} as best weights')if self.dwell:lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))new_lr = lr * self.factorprint(f'\n Learning rate was automatically adjusted from {lr:8.6f} to {new_lr:8.6f}, model weights set to best weights')tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)self.model.set_weights(self.best_weights) # 在新的学习率基础上,看模型在最优权重上表现如何if loss < self.lowest_loss:self.lowest_loss = lossif self.ask:if epoch + 1 == self.ask_epoch:print('\n Enter H to end training or an integer for the number of additional epochs to run then ask again')ans = input()if ans == 'H' or ans == 'h' or ans == '0': # 放弃训练self.model.stop_training = Trueelse:self.ask_epoch += int(ans) # 在第 ask_epoch+ans 轮次再次询问if self.ask_epoch > self.epochs:print('\n Your specification exceeds ', self.epochs, ' cannot train for ', self.ask_epoch, flush =True)else:print(f'\n You entered {ans}. Training will continue to epoch {self.ask_epoch}')if self.dwell == False:lr=float(tf.keras.backend.get_value(self.model.optimizer.lr)) print(f'\n Current LR is  {lr:8.6f}  hit enter to keep  this LR or enter a new LR')ans = input(' ')if ans == '':print(f'\n Keeping current LR of {lr:7.5f}')else:new_lr = float(ans)tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)print(f'\n Changing LR to {ans}')

事实上,这个回调实现的功能与 keras 本身含有的回调可能有相似部分,但重点在于理解一个 callback 的自定义过程。

最后,我们实例化这个回调,并添加到回调列表中。

epochs = 50
ask_epoch = 10
ask = My_ASK(model, epochs, ask_epoch)
callbacks = [ask]

相关内容

热门资讯

原创 戴... 最近,关于前国脚戴琳的欠薪丑闻无疑是引发了球迷的持续关注,从10月25日,媒体人李平康率先爆料,晒出...
思想政治工作条例最新修订内容,... 思想政治工作条例最新修订内容,思想政治工作条例全文下载 思想政治工作条例最新修订,全文下载与深度解读...
CBA潜力赛为何打成“老将赛”... 计时钟归零,双方教练握手致意,观众开始退场,CBA联赛的正赛宣告结束。然而球场并未就此沉寂,替补席上...
“手术钻头断裂遗留患者体内”,... 12月21日,湖南祁阳市卫生健康局发布情况通报称,近日,有媒体报道祁阳市中医医院发生骨科手术钻头断裂...
代驾纠纷 代驾时撞伤行人、车辆发生故障…… 这些都和车主无关,应由代驾赔偿? 观点: 使用代驾服务并非将所有...
公司股东与妻子分居期间出轨女下... 近日据报道,宁夏永宁县人民法院一审查明公司股东李某乙在与妻子李某甲分居期间,与公司女员工马某某存在不...
动物学家、律师和创作者,Thi... 12月21日,以“一起·了不起”为主题的2025 ThinkPad黑FUN礼在京举办。活动现场,律师...
徐奇渊:扩内需与对外政策紧密相... 近日,中国海关总署发布了一组数据令人关注:2025年前11个月,我国货物贸易顺差达到1.08万亿美元...
46岁上海独居女子不幸离世,官... 居住在上海虹口区46岁的蒋女士因突发脑溢血于今年10月入院,远亲吴先生与其公司共同垫付了医药费,但她...