diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index c7d7932da..7ce4f1948 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -29,6 +29,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process from axolotl.utils.models import load_tokenizer from axolotl.utils.tokenization import check_dataset_labels +from axolotl.utils.trainer import prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @@ -296,6 +297,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs): validate_config(cfg) + prepare_optim_env(cfg) + normalize_config(cfg) setup_wandb_env_vars(cfg) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 6d09a4559..469f6d886 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -267,12 +267,14 @@ def setup_fsdp_envs(cfg): ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap -def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): +def prepare_optim_env(cfg): if cfg.fsdp: setup_fsdp_envs(cfg) elif cfg.deepspeed: os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + +def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) trainer_builder.train_dataset = train_dataset trainer_builder.eval_dataset = eval_dataset