add support for alpaca reflect training (#2)
This commit is contained in:
@@ -37,9 +37,9 @@ from axolotl.prompt_tokenizers import (
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
LLAMA_DEFAULT_PAD_TOKEN,
|
||||
GPTeacherPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy, AlpacaReflectionPTStrategy,
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
|
||||
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter, ReflectAlpacaPrompter
|
||||
|
||||
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
@@ -395,6 +395,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
)
|
||||
trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
|
||||
|
||||
# TODO: add an on_save callback to sync checkpoints to GCP/AWS in the background
|
||||
if cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
cfg.early_stopping_patience,
|
||||
@@ -540,6 +541,15 @@ def train(
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
datasets.append(ds_wrapper)
|
||||
elif d.type == "reflection":
|
||||
ds_strategy = AlpacaReflectionPTStrategy(
|
||||
ReflectAlpacaPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
datasets.append(ds_wrapper)
|
||||
elif d.type == "sharegpt":
|
||||
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
|
||||
Reference in New Issue
Block a user