From 612aabd8c468b6f1aeda80fdec5ec4a4bc3ae159 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 27 Jun 2023 15:40:25 -0400
Subject: [PATCH] push intermediate model checkpoints to hub

---
 src/axolotl/prompt_strategies/alpaca_chat.py | 11 ++++++++++-
 src/axolotl/utils/trainer.py                 |  4 ++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index 6161d7e37..952a55961 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -6,7 +6,7 @@ from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     InstructionPromptTokenizingStrategy,
 )
-from axolotl.prompters import AlpacaPrompter, PromptStyle
+from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
 
 
 def load(tokenizer, cfg):
@@ -103,3 +103,12 @@ def load_camel_ai(tokenizer, cfg):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+
+
+def load_no_prompt(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        UnpromptedPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 5cf3107f3..e9ec641a6 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -124,6 +124,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.max_grad_norm:
         training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
 
+    if cfg.push_to_hub_model_id:
+        training_arguments_kwargs["push_to_hub_model_id"] = cfg.push_to_hub_model_id
+        training_arguments_kwargs["push_to_hub"] = True
+
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size