From 71b7ea3c056f15123b56fef3151b4044c80078b4 Mon Sep 17 00:00:00 2001 From: kallewoof Date: Wed, 29 Nov 2023 22:36:35 +0900 Subject: [PATCH] Determine FSDP/deepspeed settings on device select. (#883) * Determine FSDP/deepspeed settings on device select. Without this, the OS env check for accelerate will fail. * rename and move env setup call * chore: lint --------- Co-authored-by: Karl-Johan Alm Co-authored-by: Wing Lian --- src/axolotl/cli/__init__.py | 3 +++ src/axolotl/utils/trainer.py | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index c7d7932da..7ce4f1948 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -29,6 +29,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process from axolotl.utils.models import load_tokenizer from axolotl.utils.tokenization import check_dataset_labels +from axolotl.utils.trainer import prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @@ -296,6 +297,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs): validate_config(cfg) + prepare_optim_env(cfg) + normalize_config(cfg) setup_wandb_env_vars(cfg) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 6d09a4559..469f6d886 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -267,12 +267,14 @@ def setup_fsdp_envs(cfg): ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap -def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): +def prepare_optim_env(cfg): if cfg.fsdp: setup_fsdp_envs(cfg) elif cfg.deepspeed: os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + +def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) trainer_builder.train_dataset = train_dataset trainer_builder.eval_dataset = eval_dataset