diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 63e34293e..3287c0ee9 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1,5 +1,4 @@ """Module for models and model loading""" -import importlib import logging import math import os @@ -176,20 +175,6 @@ def load_model( LOG.info("patching _expand_mask") hijack_expand_mask() - # special handling b/c remote MixFormers code doesn't have _no_split_modules set - if ( - "MixFormerSequentialConfig" in model_config.__class__.__name__ - and cfg.model_type == "AutoModelForCausalLM" - ): - module_name = model_config.__class__.__module__.replace( - ".configuration_mixformer_sequential", ".modeling_mixformer_sequential" - ) - modeling_phi = importlib.import_module(module_name) - # pylint:disable=protected-access - modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = [ - "ParallelBlock" - ] - model_kwargs = {} if cfg.model_revision: model_kwargs["revision"] = cfg.model_revision