setup env vars for ray train for FSDP (#3130) [skip ci]
This commit is contained in:
@@ -17,6 +17,7 @@ from axolotl.integrations.base import PluginManager
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.trainer import prepare_optim_env
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||||
@@ -92,6 +93,7 @@ def ray_train_func(kwargs: dict):
|
|||||||
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
|
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
|
||||||
# also renormalize the config now that TorchTrainer has spawned distributed workers
|
# also renormalize the config now that TorchTrainer has spawned distributed workers
|
||||||
cfg = DictDefault(kwargs["cfg"])
|
cfg = DictDefault(kwargs["cfg"])
|
||||||
|
prepare_optim_env(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
|
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
|
||||||
|
|||||||
Reference in New Issue
Block a user