remove patch fix for phi (#664)

This commit is contained in:
Wing Lian
2023-10-02 21:07:41 -04:00
committed by GitHub
parent e50a64e85e
commit f34648c8b9

View File

@@ -1,5 +1,4 @@
"""Module for models and model loading"""
import importlib
import logging
import math
import os
@@ -176,20 +175,6 @@ def load_model(
LOG.info("patching _expand_mask")
hijack_expand_mask()
# special handling b/c remote MixFormers code doesn't have _no_split_modules set
if (
"MixFormerSequentialConfig" in model_config.__class__.__name__
and cfg.model_type == "AutoModelForCausalLM"
):
module_name = model_config.__class__.__module__.replace(
".configuration_mixformer_sequential", ".modeling_mixformer_sequential"
)
modeling_phi = importlib.import_module(module_name)
# pylint:disable=protected-access
modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = [
"ParallelBlock"
]
model_kwargs = {}
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision