only patch in CP > 1 case
This commit is contained in:
@@ -76,9 +76,6 @@ class PatchManager:
|
|||||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||||
|
|
||||||
def _apply_transformers_patches(self):
|
def _apply_transformers_patches(self):
|
||||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
|
||||||
patch_prepare_context_parallel_inputs,
|
|
||||||
)
|
|
||||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||||
patch_evaluation_loop,
|
patch_evaluation_loop,
|
||||||
patch_maybe_log_save_evaluate,
|
patch_maybe_log_save_evaluate,
|
||||||
@@ -86,7 +83,13 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_evaluation_loop()
|
patch_evaluation_loop()
|
||||||
patch_maybe_log_save_evaluate()
|
patch_maybe_log_save_evaluate()
|
||||||
patch_prepare_context_parallel_inputs()
|
|
||||||
|
if self.cfg.context_parallel_size > 1:
|
||||||
|
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
||||||
|
patch_prepare_context_parallel_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_prepare_context_parallel_inputs()
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
|
|||||||
Reference in New Issue
Block a user