From 4caa59a087b9a7faf12bf6503d95cff43b62fcd2 Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Thu, 20 Feb 2025 11:36:55 -0500
Subject: [PATCH] auto detect tp_size

---
 src/axolotl/core/trainer_builder.py                      | 4 ++--
 src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index e2b9b703a..a56c2db39 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -703,8 +703,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "accelerator_config"
             ] = self.cfg.accelerator_config
 
-        if self.cfg.tp_size is not None:
-            training_arguments_kwargs["tp_size"] = self.cfg.tp_size
+        if self.cfg.tensor_parallel:
+            training_arguments_kwargs["tp_size"] = torch.cuda.device_count()
 
         if self.cfg.kd_ce_alpha is not None:
             training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index c5821b8f1..840991e4e 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -748,7 +748,7 @@ class AxolotlInputConfig(
     local_rank: Optional[int] = None
     ddp: Optional[bool] = None
 
-    tp_size: Optional[int] = None
+    tensor_parallel: Optional[bool] = None
 
     seed: Optional[int] = None
     ddp_timeout: Optional[int] = None
@@ -1376,7 +1376,7 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_fsdp_tp(cls, data):
-        if data.get("fsdp") and data.get("tp_size"):
+        if data.get("fsdp") and data.get("tensor_parallel"):
             raise ValueError("FSDP is not compatible with tensor parallelism")
         return data
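
Rough sketch (not part of the patch) of the behaviour this change introduces: the integer
tp_size config option is replaced by a boolean tensor_parallel option, and when it is set the
tensor-parallel degree passed to the trainer is auto-detected from the number of visible CUDA
devices. The helper name resolve_tp_size below is hypothetical and exists only for illustration;
it assumes a single-node run where every visible GPU participates in tensor parallelism,
mirroring the torch.cuda.device_count() call in the patch.

    # Illustrative only: hypothetical helper mirroring the patched trainer_builder logic.
    from typing import Optional

    import torch

    def resolve_tp_size(tensor_parallel: Optional[bool]) -> Optional[int]:
        # When tensor_parallel is enabled in the axolotl config, the TP degree is
        # auto-detected as the number of CUDA devices visible to this process
        # (e.g. 4 with CUDA_VISIBLE_DEVICES=0,1,2,3); otherwise no tp_size is passed.
        if not tensor_parallel:
            return None
        return torch.cuda.device_count()

With this in place, a config such as tensor_parallel: true replaces the previous tp_size: 4
style setting, and the existing validator still rejects combining it with FSDP.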