Ensure that the HfTrainerDeepSpeedConfig is set before the Trainer class is ever initialized (#1850) [skip ci]
This commit is contained in:
@@ -399,12 +399,15 @@ def setup_torch_compile_env(cfg):
|
||||
|
||||
|
||||
def setup_deepspeed_env(cfg, stage=None):
    """Prepare the process environment for DeepSpeed training.

    Exports the Accelerate environment variables that enable DeepSpeed and
    point it at the config file, then constructs an ``HfTrainerDeepSpeedConfig``
    for its side effect — presumably so the DeepSpeed config is registered
    before the HF Trainer class is initialized (see commit message); the
    instance itself is discarded.

    Args:
        cfg: config object; ``cfg.deepspeed`` is the path to the DeepSpeed
            config file. TODO confirm it is always a string path.
        stage: optional ZeRO stage; when 3, ZeRO-3 init is also enabled.
    """
    from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

    env_updates = {
        "ACCELERATE_USE_DEEPSPEED": "true",
        "ACCELERATE_DEEPSPEED_CONFIG_FILE": cfg.deepspeed,
    }
    if stage:
        env_updates["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
        if stage == 3:
            env_updates["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
    os.environ.update(env_updates)

    # Instantiated purely for the constructor's side effect; the return value
    # is intentionally unused.
    HfTrainerDeepSpeedConfig(cfg.deepspeed)
|
||||
|
||||
|
||||
def setup_fsdp_envs(cfg):
|
||||
|
||||
Reference in New Issue
Block a user