ensure that the hftrainer deepspeed config is set before the trainer class is ever init'ed (#1850) [skip ci]

This commit is contained in:
Wing Lian
2024-08-22 13:10:40 -04:00
committed by GitHub
parent de4ea2d1f2
commit 2f8037fee6

View File

@@ -399,12 +399,15 @@ def setup_torch_compile_env(cfg):
def setup_deepspeed_env(cfg, stage=None):
    """Prime the process environment for an accelerate + DeepSpeed run.

    Exports the ACCELERATE_* variables that accelerate reads when it
    launches with DeepSpeed, and eagerly constructs an
    ``HfTrainerDeepSpeedConfig`` so the HF trainer's DeepSpeed config is
    registered before any Trainer class is instantiated.

    Args:
        cfg: config object; ``cfg.deepspeed`` is the path to the
            DeepSpeed JSON config file.
        stage: optional ZeRO stage number; when truthy it is exported,
            and stage 3 additionally enables zero3 init.
    """
    # Imported lazily so transformers is only touched when deepspeed is used.
    from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

    env = os.environ
    env["ACCELERATE_USE_DEEPSPEED"] = "true"
    env["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
    if stage:
        env["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
        if stage == 3:
            env["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
    # Instantiating this registers the DeepSpeed config with the HF trainer
    # integration as a side effect; the instance itself is not needed.
    HfTrainerDeepSpeedConfig(cfg.deepspeed)
def setup_fsdp_envs(cfg):