add support for alpaca reflect training (#2)

Author: Wing Lian
Date:   2023-04-18 08:34:05 -04:00
Committed by: GitHub
Parent: 34af1b465f
Commit: 81de0efc18

4 changed files with 147 additions and 2 deletions


@@ -37,9 +37,9 @@ from axolotl.prompt_tokenizers import (
     ShareGPTPromptTokenizingStrategy,
     LLAMA_DEFAULT_PAD_TOKEN,
     GPTeacherPromptTokenizingStrategy,
-    OpenAssistantPromptTokenizingStrategy,
+    OpenAssistantPromptTokenizingStrategy, AlpacaReflectionPTStrategy,
 )
-from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
+from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter, ReflectAlpacaPrompter
 
 logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
 DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
@@ -395,6 +395,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
 
+    # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
             cfg.early_stopping_patience,
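
The TODO above calls for an on_save hook that pushes fresh checkpoints to remote storage without blocking training. A minimal sketch of one possible shape, using the stock transformers TrainerCallback API; the CheckpointSyncCallback class and sync_to_remote helper are hypothetical names, not part of this commit:

    import os
    import threading

    from transformers import TrainerCallback


    def sync_to_remote(path):
        """Hypothetical upload helper; swap in boto3 / google-cloud-storage here."""
        ...


    class CheckpointSyncCallback(TrainerCallback):
        """Uploads each new checkpoint in a background thread so training never blocks."""

        def on_save(self, args, state, control, **kwargs):
            # The HF Trainer writes checkpoints to {output_dir}/checkpoint-{global_step}
            ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            # Fire-and-forget: the daemon thread does the upload off the training loop
            threading.Thread(target=sync_to_remote, args=(ckpt_dir,), daemon=True).start()
            return control

Such a callback would be passed to the Trainer alongside the early-stopping callback below, via the callbacks keyword.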
@@ -540,6 +541,15 @@ def train(
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
+        elif d.type == "reflection":
+            ds_strategy = AlpacaReflectionPTStrategy(
+                ReflectAlpacaPrompter(),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            datasets.append(ds_wrapper)
         elif d.type == "sharegpt":
             ds_strategy = ShareGPTPromptTokenizingStrategy(
                 ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
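
A dataset entry whose type is "reflection" now routes through the new branch above. A minimal sketch of the equivalent direct construction; the tokenizer checkpoint and config values are illustrative, and the row schema the strategy expects is defined in axolotl.prompt_tokenizers, outside this diff:

    from transformers import AutoTokenizer

    from axolotl.prompt_tokenizers import AlpacaReflectionPTStrategy
    from axolotl.prompters import ReflectAlpacaPrompter

    # Illustrative base model; any tokenizer the run is configured with works here.
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

    # Mirrors the call in train(): prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len.
    strategy = AlpacaReflectionPTStrategy(
        ReflectAlpacaPrompter(),
        tokenizer,
        False,  # train_on_inputs: mask prompt tokens out of the labels
        2048,   # sequence_len: maximum tokenized length per example
    )
    # The strategy is then wrapped exactly as in the branch above:
    # ds_wrapper = TokenizedPromptDataset(strategy, ds["train"])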