setup env vars for ray train for FSDP (#3130) [skip ci]

This commit is contained in:
Wing Lian
2025-09-16 14:52:29 -04:00
committed by GitHub
parent 58d67bf98d
commit 1ef6c196f7

View File

@@ -17,6 +17,7 @@ from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import prepare_optim_env
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
@@ -92,6 +93,7 @@ def ray_train_func(kwargs: dict):
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
# also renormalize the config now that TorchTrainer has spawned distributed workers
cfg = DictDefault(kwargs["cfg"])
prepare_optim_env(cfg)
normalize_config(cfg)
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype