diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 5e766de37..8d33c0b84 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -17,6 +17,7 @@ from axolotl.integrations.base import PluginManager from axolotl.train import train from axolotl.utils.config import normalize_config, resolve_dtype from axolotl.utils.dict import DictDefault +from axolotl.utils.trainer import prepare_optim_env def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): @@ -92,6 +93,7 @@ def ray_train_func(kwargs: dict): # cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict) # also renormalize the config now that TorchTrainer has spawned distributed workers cfg = DictDefault(kwargs["cfg"]) + prepare_optim_env(cfg) normalize_config(cfg) # now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype