only patch in CP > 1 case

This commit is contained in:
Dan Saunders
2025-09-24 13:36:14 -04:00
parent 08124a7c92
commit b9a3bfee5a

View File

@@ -76,9 +76,6 @@ class PatchManager:
self._apply_tiled_mlp(self.cfg.model_config_type)
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
patch_prepare_context_parallel_inputs,
)
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
patch_evaluation_loop,
patch_maybe_log_save_evaluate,
@@ -86,7 +83,13 @@ class PatchManager:
patch_evaluation_loop()
patch_maybe_log_save_evaluate()
patch_prepare_context_parallel_inputs()
if self.cfg.context_parallel_size > 1:
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
patch_prepare_context_parallel_inputs,
)
patch_prepare_context_parallel_inputs()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""