ensure that the hftrainer deepspeed config is set before the trainer class is ever init'ed (#1850) [skip ci]
This commit is contained in:
@@ -399,12 +399,15 @@ def setup_torch_compile_env(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def setup_deepspeed_env(cfg, stage=None):
|
def setup_deepspeed_env(cfg, stage=None):
|
||||||
|
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
|
||||||
|
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
||||||
if stage:
|
if stage:
|
||||||
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
||||||
if stage == 3:
|
if stage == 3:
|
||||||
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
||||||
|
HfTrainerDeepSpeedConfig(cfg.deepspeed)
|
||||||
|
|
||||||
|
|
||||||
def setup_fsdp_envs(cfg):
|
def setup_fsdp_envs(cfg):
|
||||||
|
|||||||
Reference in New Issue
Block a user