From 38773d661fa43164abc7ff7a87e6162c8d7c2326 Mon Sep 17 00:00:00 2001 From: sunny Date: Wed, 30 Oct 2024 11:04:50 -0400 Subject: [PATCH] fix _save_checkpoint signature for newer transformers and gate multipack patch on remote code --- src/axolotl/core/trainer_builder.py | 4 ++-- src/axolotl/monkeypatch/multipack.py | 9 +++++---- src/axolotl/utils/models.py | 9 +++++++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index d125f838d..e7f34f571 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -895,13 +895,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer): for key, value in metrics.items(): self._stored_metrics[train_eval][key].append(value) - def _save_checkpoint(self, model, trial, metrics=None): + def _save_checkpoint(self, model, trial): # make sure the checkpoint dir exists, since trainer is flakey checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) os.makedirs(output_dir, exist_ok=True) - return super()._save_checkpoint(model, trial, metrics=metrics) + return super()._save_checkpoint(model, trial) class AxolotlMambaTrainer(AxolotlTrainer): diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index b2ca1a9ab..788b75254 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -28,16 +28,17 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ # def patch_for_multipack(model_type, model_name=None, is_remote_code=False): -def patch_for_multipack(model_type, model_name=None): +def patch_for_multipack(model_type, model_name=None, has_remote_code=False): if model_type == "gemmoe": patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") elif model_type == "deepseek_v2": patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek") # elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code: 
elif hasattr(transformers, "modeling_flash_attention_utils"): - transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) + if not has_remote_code: + transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) if model_type == "mixtral" and is_deepspeed_zero3_enabled(): patch_mixtral_moe_forward_zero3() return diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 89ca00c9c..64653800f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -393,11 +393,16 @@ class ModelLoader: self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES and self.cfg.flash_attention and self.cfg.sample_packing - ): + ): + has_remote_code = ( + "auto_map" in self.model_config + and self.model_type in self.model_config["auto_map"] + ) + patch_for_multipack( self.cfg.model_config_type, model_name=self.cfg.base_model, - # is_remote_code=self.cfg.trust_remote_code, + has_remote_code=has_remote_code, ) if self.cfg.is_llama_derived_model: