diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index 6161d7e37..952a55961 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -6,7 +6,7 @@ from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     InstructionPromptTokenizingStrategy,
 )
-from axolotl.prompters import AlpacaPrompter, PromptStyle
+from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
 
 
 def load(tokenizer, cfg):
@@ -103,3 +103,12 @@ def load_camel_ai(tokenizer, cfg):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+
+
+def load_no_prompt(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        UnpromptedPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 5cf3107f3..e9ec641a6 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -124,6 +124,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.max_grad_norm:
         training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
 
+    if cfg.push_to_hub_model_id:
+        training_arguments_kwargs["push_to_hub_model_id"] = cfg.push_to_hub_model_id
+        training_arguments_kwargs["push_to_hub"] = True
+
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size