From 5b6690ac25e61b76fcc06132b11b3d61464c958a Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 9 May 2023 01:44:12 +0900 Subject: [PATCH 1/2] Fix scheduler condition --- src/axolotl/utils/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 6535c2a7e..715c32555 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -104,7 +104,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): report_to="wandb" if cfg.use_wandb else None, run_name=cfg.wandb_run_id if cfg.use_wandb else None, optim=cfg.optimizer if cfg.optimizer else None, - lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine", + lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler in ("one_cycle", "log_sweep") else "cosine", weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, **training_arguments_kwargs, ) From 36aaea02b99200a009d70943f24c8673f8e0711e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 9 May 2023 02:01:08 +0900 Subject: [PATCH 2/2] Default optimizer to adamw_hf; fall back to cosine when lr_scheduler is unset --- src/axolotl/utils/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 715c32555..e774a533e 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -103,8 +103,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): group_by_length=cfg.group_by_length, report_to="wandb" if cfg.use_wandb else None, run_name=cfg.wandb_run_id if cfg.use_wandb else None, - optim=cfg.optimizer if cfg.optimizer else None, - lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler in ("one_cycle", "log_sweep") else "cosine", + optim=cfg.optimizer if cfg.optimizer else "adamw_hf", + lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine", 
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, **training_arguments_kwargs, )