remove patch fix for phi (#664)
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
"""Module for models and model loading"""
|
"""Module for models and model loading"""
|
||||||
import importlib
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -176,20 +175,6 @@ def load_model(
|
|||||||
LOG.info("patching _expand_mask")
|
LOG.info("patching _expand_mask")
|
||||||
hijack_expand_mask()
|
hijack_expand_mask()
|
||||||
|
|
||||||
# special handling b/c remote MixFormers code doesn't have _no_split_modules set
|
|
||||||
if (
|
|
||||||
"MixFormerSequentialConfig" in model_config.__class__.__name__
|
|
||||||
and cfg.model_type == "AutoModelForCausalLM"
|
|
||||||
):
|
|
||||||
module_name = model_config.__class__.__module__.replace(
|
|
||||||
".configuration_mixformer_sequential", ".modeling_mixformer_sequential"
|
|
||||||
)
|
|
||||||
modeling_phi = importlib.import_module(module_name)
|
|
||||||
# pylint:disable=protected-access
|
|
||||||
modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = [
|
|
||||||
"ParallelBlock"
|
|
||||||
]
|
|
||||||
|
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
if cfg.model_revision:
|
if cfg.model_revision:
|
||||||
model_kwargs["revision"] = cfg.model_revision
|
model_kwargs["revision"] = cfg.model_revision
|
||||||
|
|||||||
Reference in New Issue
Block a user