enable tensor parallelism (TP) via tp_size

Sunny Liu
2025-02-18 16:20:24 -05:00
committed by Sung Ching Liu
parent 75cbd15301
commit dbdf97e828
2 changed files with 12 additions and 0 deletions


@@ -703,6 +703,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "accelerator_config"
             ] = self.cfg.accelerator_config
+        if self.cfg.tp_size is not None:
+            training_arguments_kwargs["tp_size"] = self.cfg.tp_size
+
         if self.cfg.kd_ce_alpha is not None:
             training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
         if self.cfg.kd_alpha is not None:
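
For context, a minimal sketch (not part of the commit) of how the forwarded kwarg reaches Hugging Face transformers. It assumes a transformers version whose TrainingArguments accepts tp_size (added for tensor parallelism in recent releases); the cfg object and build_training_arguments helper here are hypothetical stand-ins for axolotl's parsed config and builder method:

# Minimal sketch: forwarding an optional config value into TrainingArguments.
# Assumes a recent transformers release where TrainingArguments accepts a
# `tp_size` argument for tensor parallelism; `cfg` is a hypothetical stand-in
# for axolotl's parsed config object.
from transformers import TrainingArguments


def build_training_arguments(cfg) -> TrainingArguments:
    kwargs = {"output_dir": cfg.output_dir}
    # Only pass tp_size when the user actually set it, so transformers
    # versions without the argument are not broken by an unexpected kwarg.
    if cfg.tp_size is not None:
        kwargs["tp_size"] = cfg.tp_size
    return TrainingArguments(**kwargs)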


@@ -748,6 +748,8 @@ class AxolotlInputConfig(
     local_rank: Optional[int] = None
     ddp: Optional[bool] = None
+    tp_size: Optional[int] = None
+
     seed: Optional[int] = None
     ddp_timeout: Optional[int] = None
     ddp_bucket_cap_mb: Optional[int] = None
@@ -1371,6 +1373,13 @@ class AxolotlInputConfig(
         )
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp_tp(cls, data):
+        if data.get("fsdp") and data.get("tp_size"):
+            raise ValueError("FSDP is not compatible with tensor parallelism")
+        return data
+
     @model_validator(mode="after")
     def check_fft_possible_bad_config(self):
         if (
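
To illustrate the new check, a minimal self-contained sketch (assuming pydantic v2 and a reduced stand-in model, ReducedInputConfig, carrying only the two relevant fields rather than the full AxolotlInputConfig) of how the before-mode validator rejects the FSDP plus tensor-parallel combination:

# Minimal sketch of the new validator on a reduced stand-in model.
# Assumes pydantic v2; only the fields relevant to the check are included.
from typing import List, Optional

from pydantic import BaseModel, model_validator


class ReducedInputConfig(BaseModel):
    fsdp: Optional[List[str]] = None
    tp_size: Optional[int] = None

    @model_validator(mode="before")
    @classmethod
    def check_fsdp_tp(cls, data):
        # Runs on the raw input dict before field parsing, so the
        # incompatible combination is rejected as early as possible.
        if data.get("fsdp") and data.get("tp_size"):
            raise ValueError("FSDP is not compatible with tensor parallelism")
        return data


ReducedInputConfig(tp_size=2)                         # ok: TP alone
ReducedInputConfig(fsdp=["full_shard"])               # ok: FSDP alone
# ReducedInputConfig(fsdp=["full_shard"], tp_size=2)  # raises ValueError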