diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 85101cd3c..b2ca1a9ab 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -27,12 +27,14 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ ] -def patch_for_multipack(model_type, model_name=None, is_remote_code=False): +# def patch_for_multipack(model_type, model_name=None, is_remote_code=False): +def patch_for_multipack(model_type, model_name=None): if model_type == "gemmoe": patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") elif model_type == "deepseek_v2": patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek") - elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code: + # elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code: + elif hasattr(transformers, "modeling_flash_attention_utils"): transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 30f2904a6..f3386cccf 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -393,8 +393,7 @@ class ModelLoader: self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES and self.cfg.flash_attention and self.cfg.sample_packing - ): - LOG.info(f"Model_config_type: {self.cfg.model_config_type}") + ): patch_for_multipack( self.cfg.model_config_type, model_name=self.cfg.base_model,