Make phi training work with LoRAs (#588)

* validation for phi LoRAs

* fix model config class check

* update readme for phi training
This commit is contained in:
Wing Lian
2023-09-15 20:51:55 -04:00
committed by GitHub
parent be75668400
commit 62eaee7649
4 changed files with 114 additions and 5 deletions

View File

@@ -1,6 +1,5 @@
"""Module for models and model loading"""
import importlib
import logging
import math
import os
@@ -155,11 +154,26 @@ def load_model(
LOG.info("patching _expand_mask")
hijack_expand_mask()
model_config = load_model_config(cfg)
# special handling b/c remote MixFormers code doesn't have _no_split_modules set
if (
"MixFormerSequentialConfig" in model_config.__class__.__name__
and cfg.model_type == "AutoModelForCausalLM"
):
module_name = model_config.__class__.__module__.replace(
".configuration_mixformer_sequential", ".modeling_mixformer_sequential"
)
modeling_phi = importlib.import_module(module_name)
# pylint:disable=protected-access
modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = [
"ParallelBlock"
]
model_kwargs = {}
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.gptq:
model_config = load_model_config(cfg)
if not hasattr(model_config, "quantization_config"):
LOG.warning("model config does not contain quantization_config information")
else: