only patch in CP > 1 case
This commit is contained in:
@@ -76,9 +76,6 @@ class PatchManager:
|
||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||
|
||||
def _apply_transformers_patches(self):
|
||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
||||
patch_prepare_context_parallel_inputs,
|
||||
)
|
||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||
patch_evaluation_loop,
|
||||
patch_maybe_log_save_evaluate,
|
||||
@@ -86,7 +83,13 @@ class PatchManager:
|
||||
|
||||
patch_evaluation_loop()
|
||||
patch_maybe_log_save_evaluate()
|
||||
patch_prepare_context_parallel_inputs()
|
||||
|
||||
if self.cfg.context_parallel_size > 1:
|
||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
||||
patch_prepare_context_parallel_inputs,
|
||||
)
|
||||
|
||||
patch_prepare_context_parallel_inputs()
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
|
||||
Reference in New Issue
Block a user