Compare commits

...

10 Commits

Author            SHA1          Message                           Date
Sung Ching Liu    f68aedd1f8    Update __init__.py                2025-02-26 00:21:16 -05:00
Sunny Liu         3dd5c6f8ec    nit                               2025-02-26 00:21:16 -05:00
Sunny Liu         4caa59a087    auto detect tp_size               2025-02-26 00:21:16 -05:00
Sunny Liu         984be14147    add tp_size in config doc         2025-02-26 00:21:16 -05:00
Sunny Liu         64adbf1a15    tp plan not needed                2025-02-26 00:21:16 -05:00
Sunny Liu         438b623031    prepare accelerate envs for tp    2025-02-26 00:21:16 -05:00
Sunny Liu         a74efcecbe    skip move to device               2025-02-26 00:21:16 -05:00
Sunny Liu         d663652216    del device_map for tp             2025-02-26 00:21:16 -05:00
Sunny Liu         dbd43aa18f    set tp_plan                       2025-02-26 00:21:16 -05:00
Sunny Liu         dbdf97e828    enabe tp thru tp_size             2025-02-26 00:21:16 -05:00
5 changed files with 19 additions and 0 deletions

View File

@@ -78,6 +78,9 @@ tf32: true # require >=ampere
bfloat16: true # require >=ampere
float16: true
# Use Tensor parallel
tensor_parallel: true # require multi-GPU
# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
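For context, a minimal config using the new option might look like the sketch below; the base model and output path are placeholders, and `tp_size` itself is never set by hand because it is auto-detected from the GPU count (see the trainer change below).

```yaml
# Hypothetical minimal multi-GPU config (placeholder values)
base_model: facebook/opt-125m   # placeholder model
output_dir: ./outputs/tp-run

bfloat16: true        # require >=ampere
tensor_parallel: true # require multi-GPU; tp_size is derived from the device count
```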

View File

@@ -703,6 +703,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                "accelerator_config"
            ] = self.cfg.accelerator_config
        if self.cfg.tensor_parallel:
            training_arguments_kwargs["tp_size"] = torch.cuda.device_count()
        if self.cfg.kd_ce_alpha is not None:
            training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
        if self.cfg.kd_alpha is not None:
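Taken out of context, the new branch amounts to the sketch below, assuming a transformers release whose `TrainingArguments` accepts `tp_size` and a multi-GPU host (the output directory is a placeholder):

```python
import torch
from transformers import TrainingArguments

# Mirrors the hunk above: tp_size is auto-detected as the number of visible
# CUDA devices (cf. the "auto detect tp_size" commit) instead of being a
# user-facing knob.
tp_size = torch.cuda.device_count()  # assumes >= 2 GPUs are visible

training_args = TrainingArguments(
    output_dir="./outputs/tp-run",  # placeholder
    tp_size=tp_size,  # assumption: a transformers version that supports tp_size
)
```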

View File

@@ -748,6 +748,8 @@ class AxolotlInputConfig(
    local_rank: Optional[int] = None
    ddp: Optional[bool] = None
    tensor_parallel: Optional[bool] = None
    seed: Optional[int] = None
    ddp_timeout: Optional[int] = None
    ddp_bucket_cap_mb: Optional[int] = None

@@ -1371,6 +1373,13 @@ class AxolotlInputConfig(
        )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_fsdp_tp(cls, data):
        if data.get("fsdp") and data.get("tensor_parallel"):
            raise ValueError("FSDP with tensor parallelism is not supported yet.")
        return data

    @model_validator(mode="after")
    def check_fft_possible_bad_config(self):
        if (
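The new `check_fsdp_tp` validator can be exercised in isolation; below is a minimal sketch with a stripped-down pydantic v2 model (`MiniConfig` is hypothetical, standing in for `AxolotlInputConfig`):

```python
from typing import Any, Optional

from pydantic import BaseModel, model_validator


class MiniConfig(BaseModel):
    """Hypothetical stand-in exposing only the two interacting fields."""

    fsdp: Optional[list] = None
    tensor_parallel: Optional[bool] = None

    @model_validator(mode="before")
    @classmethod
    def check_fsdp_tp(cls, data: Any):
        # Same rule as the new validator: FSDP and TP are mutually exclusive.
        if data.get("fsdp") and data.get("tensor_parallel"):
            raise ValueError("FSDP with tensor parallelism is not supported yet.")
        return data


MiniConfig(tensor_parallel=True)  # passes
# MiniConfig(fsdp=["full_shard"], tensor_parallel=True)  # raises ValueError
```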

View File

@@ -762,6 +762,9 @@ class ModelLoader:
            return hf_ds_cfg

        skip_move_to_device = False
        if self.cfg.tensor_parallel:
            del self.model_kwargs["device_map"]

        if (  # pylint: disable=condition-evals-to-constant)
            (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
            and not qlora_fsdp
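In isolation, the loader change drops accelerate's `device_map` dispatch when tensor parallelism is on, since TP shards and places the model itself (which is also why the move-to-device step is skipped); a sketch under that assumption, with a placeholder checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM

model_kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16}
tensor_parallel = True  # stands in for self.cfg.tensor_parallel

if tensor_parallel:
    # device_map placement would conflict with tensor-parallel sharding,
    # so remove it before from_pretrained (mirrors the `del` in the hunk).
    model_kwargs.pop("device_map", None)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder checkpoint
    **model_kwargs,
)
```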

View File

@@ -547,6 +547,7 @@ def prepare_optim_env(cfg):
    if not check_cuda_p2p_ib_support():
        if os.getenv("NCCL_P2P_DISABLE") is None:
            os.environ["NCCL_P2P_DISABLE"] = "1"

    if cfg.fsdp:
        setup_fsdp_envs(cfg)
    elif cfg.deepspeed:
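The surrounding `prepare_optim_env` follows a set-only-if-unset pattern for distributed environment variables; a condensed sketch of that pattern (the helper name is hypothetical):

```python
import os


def _setdefault_env(name: str, value: str) -> None:
    # Same pattern as the hunk: respect variables the user already exported,
    # and only fill in a default otherwise.
    if os.getenv(name) is None:
        os.environ[name] = value


_setdefault_env("NCCL_P2P_DISABLE", "1")  # applied when P2P/IB support is absent
```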