From 2f8037fee6cdee318df216049d0923455d80dad6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 22 Aug 2024 13:10:40 -0400 Subject: [PATCH] ensure that the hftrainer deepspeed config is set before the trainer class is ever init'ed (#1850) [skip ci] --- src/axolotl/utils/trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 26796f2e5..99c10c655 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -399,12 +399,15 @@ def setup_torch_compile_env(cfg): def setup_deepspeed_env(cfg, stage=None): + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed if stage: os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage) if stage == 3: os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true" + HfTrainerDeepSpeedConfig(cfg.deepspeed) def setup_fsdp_envs(cfg):