Did you know that you can navigate the posts by swiping left and right?

NNI 自动调参工具

27 Jun 2022 . category: tech .

官方文档:https://nni.readthedocs.io/zh/stable/

官方文档中的命令行自动调参写的并不十分明白,为避免后来人踩坑,做此纪录。 总结一下,按三步走即可,十分方便。

step.1 修改源代码

假设源文件train.py中的超参数由如下代码获得

args = args.parse_args()
a = args.P_a
b = args.P_b
c = args.P_c

那么你日常训练model的命令为:

python train.py --P_a 0.1 --P_b 1 --P_c niubi

利用NNI进行自动调参时,需要修改源码,如下示例。注意get_next_parameter()返回字典类型!!!

import nni
args= nni.get_next_parameter()
a = args["P_a"] # 字典类型!!!
b = args["P_b"]
c = args["P_c"]

然后在评估模型时,用nni.report_intermediate_result()提交当前结果 用nni.report_final_result()提交最终(一般为最优)结果

nni.report_intermediate_result(score)
nni.report_final_result(best_score)

step.2 创建config.yaml

参考该说明设置需要调整超参数的typevalue

search_space:
  P_a:
    _type: uniform
    _value: [ 0, 1 ]
  P_a:
    _type: quniform
    _value: [ 0, 5, 1]
  P_c:
    _type: choice
    _value: ["niubi","niubi_plus"]

trial_command: python train.py
trial_code_directory: .

trial_concurrency: 1
max_trial_number: 100

tuner:
  name: TPE
  class_args:
    optimize_mode: maximize

training_service:
  platform: local

step.3 运行NNICTL命令

运行nnictl create,端口可自行设置

nnictl create --config config.yaml --port 8080

打开网页http://127.0.0.1:8080(或远程服务器主机地址:8080)即可看到运行情况

后记

如果运行Failed,可以参考下图看看命令行输出的报错信息,快速定位错误原因。 再次感谢microsoft的NNI团队为炼丹事业所做出的巨大贡献!