From 1ef6c196f7d1cffb2010accd4f0ef716ddab405a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 16 Sep 2025 14:52:29 -0400 Subject: [PATCH] setup env vars for ray train for FSDP (#3130) [skip ci] --- src/axolotl/cli/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 5e766de37..8d33c0b84 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -17,6 +17,7 @@ from axolotl.integrations.base import PluginManager from axolotl.train import train from axolotl.utils.config import normalize_config, resolve_dtype from axolotl.utils.dict import DictDefault +from axolotl.utils.trainer import prepare_optim_env def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): @@ -92,6 +93,7 @@ def ray_train_func(kwargs: dict): # cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict) # also renormalize the config now that TorchTrainer has spawned distributed workers cfg = DictDefault(kwargs["cfg"]) + prepare_optim_env(cfg) normalize_config(cfg) # now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype