Merge branch 'OpenAccess-AI-Collective:main' into logging_enhancement

This commit is contained in:
The Objective Dad
2023-07-15 06:16:04 -05:00
committed by GitHub
2 changed files with 9 additions and 0 deletions

View File

@@ -305,6 +305,8 @@ base_model_ignore_patterns:
# if the base_model repo on hf hub doesn't include configuration .json files, # if the base_model repo on hf hub doesn't include configuration .json files,
# you can set that here, or leave this empty to default to base_model # you can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf base_model_config: ./llama-7b-hf
# optionally specify a model revision to load from the huggingface hub
model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer # Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model # than the one defined in the base model
tokenizer_config: tokenizer_config:
@@ -411,6 +413,9 @@ logging_steps:
save_steps: save_steps:
eval_steps: eval_steps:
# save model as safetensors (requires the safetensors package)
save_safetensors:
# whether to mask out or include the human's prompt from the training labels # whether to mask out or include the human's prompt from the training labels
train_on_inputs: false train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet) # don't use this, leads to wonky training (according to someone on the internet)

View File

@@ -183,6 +183,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.hub_model_id: if cfg.hub_model_id:
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
training_args = AxolotlTrainingArguments( training_args = AxolotlTrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size, per_device_train_batch_size=cfg.micro_batch_size,