Feat(config): Add hub_strategy (#386)

This commit is contained in:
NanoCode012
2023-08-14 20:12:55 +09:00
committed by GitHub
parent 63fdb5a7fb
commit 73a0b6ead5
2 changed files with 6 additions and 0 deletions

View File

@@ -364,6 +364,9 @@ dataset_prepared_path: data/last_run_prepared
push_dataset_to_hub: # repo path push_dataset_to_hub: # repo path
# push checkpoints to hub # push checkpoints to hub
hub_model_id: # repo path to push finetuned model hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub` # required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean hf_use_auth_token: # boolean

View File

@@ -440,6 +440,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["push_to_hub"] = True training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True training_arguments_kwargs["hub_private_repo"] = True
if cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors: if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors