auto detect tp_size
This commit is contained in:
committed by
Sung Ching Liu
parent
984be14147
commit
4caa59a087
@@ -703,8 +703,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"accelerator_config"
|
||||
] = self.cfg.accelerator_config
|
||||
|
||||
if self.cfg.tp_size is not None:
|
||||
training_arguments_kwargs["tp_size"] = self.cfg.tp_size
|
||||
if self.cfg.tensor_parallel:
|
||||
training_arguments_kwargs["tp_size"] = torch.cuda.device_count()
|
||||
|
||||
if self.cfg.kd_ce_alpha is not None:
|
||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||
|
||||
@@ -748,7 +748,7 @@ class AxolotlInputConfig(
|
||||
local_rank: Optional[int] = None
|
||||
ddp: Optional[bool] = None
|
||||
|
||||
tp_size: Optional[int] = None
|
||||
tensor_parallel: Optional[bool] = None
|
||||
|
||||
seed: Optional[int] = None
|
||||
ddp_timeout: Optional[int] = None
|
||||
@@ -1376,7 +1376,7 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
@classmethod
def check_fsdp_tp(cls, data):
    """Reject configs that combine FSDP with tensor parallelism.

    Tensor parallelism is enabled by either an explicit ``tp_size`` or the
    auto-detect ``tensor_parallel`` flag (which derives ``tp_size`` from
    ``torch.cuda.device_count()``), so FSDP must be rejected when *either*
    is set — checking only ``tensor_parallel`` would let an explicit
    ``tp_size`` slip past validation.

    Args:
        data: raw (pre-validation) config mapping.

    Returns:
        The unmodified ``data`` mapping when the combination is valid.

    Raises:
        ValueError: if ``fsdp`` is set together with ``tp_size`` or
            ``tensor_parallel``.
    """
    if data.get("fsdp") and (data.get("tp_size") or data.get("tensor_parallel")):
        raise ValueError("FSDP is not compatible with tensor parallelism")
    return data
|
||||
|
||||
|
||||
Reference in New Issue
Block a user