From b9a3bfee5a4b8badc88b51b8d953c306ae46184b Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 24 Sep 2025 13:36:14 -0400 Subject: [PATCH] only patch in CP > 1 case --- src/axolotl/loaders/patch_manager.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index e2e8f3e68..1e46f5c34 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -76,9 +76,6 @@ class PatchManager: self._apply_tiled_mlp(self.cfg.model_config_type) def _apply_transformers_patches(self): - from axolotl.monkeypatch.transformers.trainer_context_parallel import ( - patch_prepare_context_parallel_inputs, - ) from axolotl.monkeypatch.transformers.trainer_loss_calc import ( patch_evaluation_loop, patch_maybe_log_save_evaluate, @@ -86,7 +83,13 @@ class PatchManager: patch_evaluation_loop() patch_maybe_log_save_evaluate() - patch_prepare_context_parallel_inputs() + + if self.cfg.context_parallel_size > 1: + from axolotl.monkeypatch.transformers.trainer_context_parallel import ( + patch_prepare_context_parallel_inputs, + ) + + patch_prepare_context_parallel_inputs() def apply_post_model_load_patches(self, model: PreTrainedModel): """Apply patches that require the model instance."""