小模型+参数量少+单卡跑不需要服务器,尝试了一下ray tune不是很适合。。而且很难用。。
再三尝试后,决定使用optuna,选择的原因:
我本来是调参的。。结果却调了很多调参的工具老半天,所以分享一些零碎的经验和踩过的坑,查看本文之前最好已经对optuna(或者其他调参工具)的使用方式有一个基本的了解喔,不要太指望这个写的很碎的教程能帮你从0起步…
# trainable params
parameters = dict(lr=[.01,.001],batch_size = [100,1000],shuffle = [True,False])
#创建可传递给product函数的可迭代列表
param_values = [v for v in parameters.values()]
#把各个列表进行组合,得到一系列组合的参数
#*号是告诉乘积函数把列表中每个值作为参数,而不是把列表本身当做参数来对待
for lr,batch_size,shuffle in product(*param_values):comment = f'batch_size={batch_size}lr={lr}shuffle={shuffle}'#这里写你调陈的主程序即可print(comment)
optimize函数与suggest_float的一个demoimport optuna
def objective(trial):x = trial.suggest_float("x", 0, 10)return x**2
study = optuna.create_study()
study.optimize(objective, n_trials=3,show_progress_bar=True)
optimize参数:
objective: objecticve函数,就是包装一下training的过程,具体参考其他文档n_trials: objecticve函数执行的次数,每次执行都会抽取一个x,抽取规则是suggest_floatshow_progress_bar:多输出一点tuning的进展信息,默认是False,其实设置为True也不会有什么有价值的信息,就像tqdm一样会告诉你现在进行到第几个,还剩几个。suggest_float函数
官方文档,值得参考:
含义:从0和10中抽取一个float数返回给x,当然如果想返回一个int,使用suggest_int
optimize会执行objective函数n_trials次,按照官方的写法,是不是每次执行都会重新抽取执行各种random程序: objective function,写成Objective class,因此 objective = Objective(params)objective函数(不然只能传一个trial),二是self.attr的值是不会变的self.attr的值是不会变的,刚好解决了我需要的一切问题import optuna
import numpy as np
class Objective:def __init__(self, min_x, max_x):# Hold this implementation specific arguments as the fields of the class.self.min_x = min_xself.max_x = max_x# 注意这里的值不会变喔self.test_randn = np.random.randn(7)# 这个trial是必须的(也是唯一的?)def __call__(self, trial):# Calculate an objective value by using the extra arguments.x = trial.suggest_float("x", self.min_x, self.max_x)print(self.test_randn)return (x - 2) ** 2# Execute an optimization by using an `Objective` instance.
# 调用100次Objective function,self.test_randn是不会变的
study = optuna.create_study()
study.optimize(Objective(-100, 100), n_trials=100)
Objective class 大概这样:class Objective:# 传递dataset以及opt,后者是一个dict,存放了各种不需要tune的参数def __init__(self, dataset, opt):# Hold this implementation specific arguments as the fields of the class.self.dataset = datasetself.opt = opt# Hold the data split!!self.shuffled_indices = save_data_idx(dataset,opt)def __call__(self, trial):# Calculate an objective value by using the extra arguments.# 需要tune的参数config = {'learning_rate': trial.suggest_categorical('learning_rate', [5e-2, 1e-2, 5e-3]),'lr_for_pi': trial.suggest_categorical('lr_for_pi', [1e-2, 5e-2, 1e-3])}print("idx check: ",self.shuffled_indices[0:5])# 每次split出来的data都是一致的train_loader, val_loader, test_loader = get_data_loader(self.dataset, self.shuffled_indices, self.opt)model = MLP(self.opt.N_gaussians).to(device) performance = trainer(train_loader, val_loader, model, config, self.opt, device)return performance
grid search,做了一些必要的修改,其实感觉还是有点笨重trial.suggest_categorical,而不是什么int或者float,后面的list存放你想尝试的几个数据,比如[5e-2, 1e-2, 5e-3]就是我想尝试的3个数据 config = {'learning_rate': trial.suggest_categorical('learning_rate', [5e-2, 1e-2, 5e-3]),'lr_for_pi': trial.suggest_categorical('lr_for_pi', [1e-2, 5e-2, 1e-3])}
study时,加上参数sampler,并且选取GridSampler# 里面所有的组合被cover之后会自动stop
sampler = optuna.samplers.GridSampler(search_space={'learning_rate': [5e-2, 1e-2, 5e-3], # 注意这里和config里保持一致'lr_for_pi': [1e-2, 5e-2, 1e-3] # 注意这里和config里保持一致})
study = optuna.create_study(study_name=study_name,direction='minimize',storage=storage_name,load_if_exists=True,sampler=sampler,pruner=pruner)
study.optimize(Objective(dataset), n_trials=100,show_progress_bar=True)
sampler 里面的搜索空间search_space和上面的config保持一致GridSampler的官方文档非常值得一读:n_trials==100,只要搜索完了搜索空间search_space里的全部组合,就会自动停止,比如这里只需要搜索9个参数组合,那么执行9次之后就会自动停止config不是suggest_categorical,也可以进行网格搜索,那么依然会等cover全部组合之后自动停止,因此这个时候的试探次可能不止9次optuna.pruners.MedianPruner,这个的剪枝策略不一定最好但是足够通用,具体可以参考官方文档,optuna.pruners.NopPruner():pruner = optuna.pruners.NopPruner()
study = optuna.create_study(study_name=study_name,direction='minimize',storage=storage_name,load_if_exists=True,sampler=sampler,pruner=pruner)
study_name指定了数据库文件的名字,如果不指定会默认生成一个,但是注意这个名字的命名规则不允许有空格喔optuna.create_study(study_name=study_name,direction='minimize',storage=storage_name,load_if_exists=True,sampler=sampler,pruner=pruner)日后希望项目结束可以放上全部代码。希望大家也能留下自己的optuna使用经验。
下一篇:2023-03-04 反思