diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 8553339b9..755d60908 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -543,10 +543,18 @@ def setup_fsdp_envs(cfg): ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap +def setup_tp_envs(): + os.environ["ACCELERATE_USE_TP"] = "true" + + def prepare_optim_env(cfg): if not check_cuda_p2p_ib_support(): if os.getenv("NCCL_P2P_DISABLE") is None: os.environ["NCCL_P2P_DISABLE"] = "1" + + if cfg.tp_size > 1: + setup_tp_envs() + if cfg.fsdp: setup_fsdp_envs(cfg) elif cfg.deepspeed: