官方文档:https://nni.readthedocs.io/zh/stable/
官方文档中的命令行自动调参写的并不十分明白,为避免后来人踩坑,做此纪录。
总结一下,按三步走即可,十分方便。
假设源文件train.py
中的超参数由如下代码获得
args = args.parse_args()
a = args.P_a
b = args.P_b
c = args.P_c
那么你日常训练model的命令为:
python train.py --P_a 0.1 --P_b 1 --P_c niubi
利用NNI进行自动调参时,需要修改源码,如下示例。注意get_next_parameter()
返回字典类型!!!
import nni
args= nni.get_next_parameter()
a = args["P_a"] # 字典类型!!!
b = args["P_b"]
c = args["P_c"]
然后在评估模型时,用nni.report_intermediate_result()
提交当前结果
用nni.report_final_result()
提交最终(一般为最优)结果
nni.report_intermediate_result(score)
nni.report_final_result(best_score)
参考该说明设置需要调整超参数的type
和value
search_space:
P_a:
_type: uniform
_value: [ 0, 1 ]
P_a:
_type: quniform
_value: [ 0, 5, 1]
P_c:
_type: choice
_value: ["niubi","niubi_plus"]
trial_command: python train.py
trial_code_directory: .
trial_concurrency: 1
max_trial_number: 100
tuner:
name: TPE
class_args:
optimize_mode: maximize
training_service:
platform: local
运行nnictl create
,端口可自行设置
nnictl create --config config.yaml --port 8080
打开网页http://127.0.0.1:8080
(或远程服务器主机地址:8080)即可看到运行情况
如果运行Failed
,可以参考下图看看命令行输出的报错信息,快速定位错误原因。
再次感谢microsoft的NNI团队为炼丹事业所做出的巨大贡献!