parser
,解析器,可以让人轻松编写用户友好的命令行接口sys.argv
识别解析用户给出的参数,方面这些参数后续 .py
文件中的操作parser
是 python
直接支持使用的库,使用import parser
即可导入使用
parser
的人越来越少了,为什么呢?argparser
import argparser
argparser
主要是通过实例化一个 ArgumentParser
类来做各种操作的:parser = argparse.ArgumentParser(
prog='ProgramName',
description='What the program does',
epilog='Text at the bottom of help')
prog
:描述项目名description
:描述项目作用epilog
:在参数帮助信息之后显示的文本f(12,b=5)
,前面就是位置参数,后面是选项参数-f
和长选项 --foo
,他们表示相同的含义,只不过约定俗成,短选项参数为一个 -
后接一个字母,据说特殊情况也可以接多个。但比如 -abc
,一般我们就认为它使用了三个短选项参数,即 -a -b -c
的简写-
开头,则识别为选项参数,否则其他都识别为位置参数。parser.add_argument('filename') # 位置参数
parser.add_argument('-c', '--count') # 选项参数,其中 -c 和 --count 是一个含义
parser.add_argument('-v', '--verbose',
action='store_true')
parser.add_argument('--foo', help='foo help') # 只提供长选项参数
.sh
中,提供参数:OUTPUT_DIR=${1:-"./llama-2-7b-oscar-ft"}
export HF_DATASETS_CACHE=".cache/huggingface_cache/datasets"
export TRANSFORMERS_CACHE=".cache/models/"
# random port between 30000 and 50000
port=$(( RANDOM % (50000 - 30000 + 1 ) + 30000 ))
accelerate launch --main_process_port ${port} --config_file configs/deepspeed_train_config.yaml \
run_llmmt.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--oscar_data_path oscar-corpus/OSCAR-2301 \
--oscar_data_lang en,ru,cs,zh,is,de \
--interleave_probs "0.17,0.22,0.14,0.19,0.08,0.2" \
--streaming \
--max_steps 600000 \
--do_train \
..... 太长省略
然后在 .py
文件中,直接调用
args = parser.parse_args()
※ 然后就可以随意使用其中的参数啦
print(args.seed)
print(args.do_train)
--foo
和位置参数 bar
['BAR']
,默认作为位置参数,所以有 bar='BAR'
['BAR', '--foo', 'FOO']
,第一个为位置参数,所以有 bar='BAR'
第二个选项参数 --foo
设置为 FOO
.split()
也是同理的,但这个更接近于 cmd / sh
的格式action
action='store_const'
,那么使用 const=xxx
来设置存储的常数。这样调用 --foo
的话,最终foo=42
,不然 foo=None
action='store_true'
,那么如果有提供该参数,该参数值变为 true
,没提供该参数的话该参数值为 false
;store_false
同理action='apend'
的话,那么设定该参数是一个列表,我就可以提供多次该参数值,比如这里 --foo 1 --foo 2
,那么 foo=['1','2']
的列表了action='append_const'
的话,相当于是 const
和 append
的一个混合,同理。action='count'
的话,会返回调用该参数的次数nargs
可以设定该参数的接受参数个数('--foo', nargs=2)
,表示 --foo
后面需要接受俩参数,即 --foo a b
nargs='?'
,表示可以接受1个或0个。0个的时候会调用 default
的值nargs='*'
,表示可以接受任意数量个参数。nargs='+'
,表示可以接受1个或更多参数。type
设定接受参数的数据类型,例子有:choices=[...]
设定,该参数值是给定列表中的一个选项。否则报错。required=True
表示该参数必须提供,否则报错。help='str...'
,表示该参数的作用介绍%(var)s
就可以显示变量的值,比如 %(prog)s
或者 %(default)s
HfArgumentParser
是 HF
使用了 ArgumentParser
,为了更契合 HF 中的一些方法,做的一个工具from transformers import HfArgumentParser
sys.argv
,观察是通过 .json
传递参数(加载对应的json_file),还是通过 cmd / sh
里提供配置参数(解析成 dataclass)model_args, data_args, training_args
三个变量中去from utils.arguments import ModelArguments, DataTrainingArguments
from transformers import (
HfArgumentParser,
TrainingArguments,
default_data_collator,
)
from transformers import HfArgumentParser
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
tokenizer = load_tokenizer(data_args, model_args, training_args, logger)
train_datasets, eval_datasets, test_datasets = preprocess_cpo_data(train_raw_data, valid_raw_data, test_raw_data, pairs, tokenizer, shots_eval_dict, data_args, training_args, model_args)
model = load_model(data_args, model_args, training_args, tokenizer, logger)
trainer = CPOTrainer(
model,
args=training_args,
beta=model_args.cpo_beta,
train_dataset=train_datasets,
eval_dataset=eval_datasets,
tokenizer=tokenizer,
max_prompt_length=data_args.max_source_length,
max_length=data_args.max_source_length+data_args.max_new_tokens,
callbacks=[SavePeftModelCallback] if model_args.use_peft else None,
)
# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_state()
if model_args.use_peft:
if torch.distributed.get_rank() == 0:
model.save_pretrained(training_args.output_dir)
else:
trainer.save_model() # Saves the tokenizer too for easy upload