diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 12346b8a2..e2b9b703a 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -703,6 +703,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "accelerator_config"
             ] = self.cfg.accelerator_config

+        if self.cfg.tp_size is not None:
+            training_arguments_kwargs["tp_size"] = self.cfg.tp_size
+
         if self.cfg.kd_ce_alpha is not None:
             training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
         if self.cfg.kd_alpha is not None:
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index c7803b8cc..c5821b8f1 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -748,6 +748,8 @@ class AxolotlInputConfig(
    local_rank: Optional[int] = None
    ddp: Optional[bool] = None

+    tp_size: Optional[int] = None
+
    seed: Optional[int] = None
    ddp_timeout: Optional[int] = None
    ddp_bucket_cap_mb: Optional[int] = None
@@ -1371,6 +1373,13 @@ class AxolotlInputConfig(
        )
        return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp_tp(cls, data):
+        if data.get("fsdp") and data.get("tp_size"):
+            raise ValueError("FSDP is not compatible with tensor parallelism")
+        return data
+
    @model_validator(mode="after")
    def check_fft_possible_bad_config(self):
        if (