make phi training work with LoRAs (#588)

* validation for phi LoRAs

* fix model config class check

* update readme for phi training
Wing Lian
2023-09-15 20:51:55 -04:00
committed by GitHub
parent be75668400
commit 62eaee7649
4 changed files with 114 additions and 5 deletions

View File

@@ -75,6 +75,7 @@ def normalize_config(cfg):
        cfg.torch_dtype = torch.float32

    model_config = load_model_config(cfg)
    cfg.model_config_type = model_config.model_type
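    # e.g. the remote phi/MixFormer config reports model_type "mixformer-sequential",
    # which is the value validate_config() keys off below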

    # figure out if the model is llama
    cfg.is_llama_derived_model = (
@@ -237,6 +238,21 @@ def validate_config(cfg):
            raise ValueError(
                "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
            )

    if cfg.model_type == "MixFormerSequentialForCausalLM" and cfg.adapter is not None:
        LOG.warning("Use AutoModelForCausalLM for phi/MixFormer models with qLoRA")

    if cfg.model_config_type == "mixformer-sequential":
        if cfg.sample_packing:
            if cfg.adapter is not None:
                LOG.warning(
                    "phi/MixFormer models are not currently compatible with LoRA and sample_packing"
                )
            if cfg.model_type == "AutoModelForCausalLM":
                raise ValueError(
                    "`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
                )

    # TODO
    # MPT 7b
    # https://github.com/facebookresearch/bitsandbytes/issues/25
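
For reference, the new phi checks above reduce to a small decision table. The following is a standalone sketch of that logic with illustrative names (not axolotl's validate_config API), useful for seeing which model_type / adapter / sample_packing combinations only warn versus fail outright:

# standalone sketch of the phi/MixFormer combinations validated above; not project code
def phi_lora_issues(model_type, model_config_type, adapter, sample_packing):
    warnings, errors = [], []
    if model_type == "MixFormerSequentialForCausalLM" and adapter is not None:
        warnings.append("prefer AutoModelForCausalLM for phi/MixFormer with qLoRA")
    if model_config_type == "mixformer-sequential" and sample_packing:
        if adapter is not None:
            warnings.append("LoRA + sample_packing is not currently supported for phi")
        if model_type == "AutoModelForCausalLM":
            errors.append("model_type: MixFormerSequentialForCausalLM required for sample_packing")
    return warnings, errors

# e.g. a full fine-tune with sample_packing and the MixFormer model_type passes cleanly
print(phi_lora_issues("MixFormerSequentialForCausalLM", "mixformer-sequential", None, True))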

View File

@@ -1,6 +1,5 @@
"""Module for models and model loading"""
import importlib
import logging
import math
import os
@@ -155,11 +154,26 @@ def load_model(
LOG.info("patching _expand_mask")
hijack_expand_mask()
model_config = load_model_config(cfg)
# special handling b/c remote MixFormers code doesn't have _no_split_modules set
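    # _no_split_modules lists module classes that accelerate must keep on a single
    # device when from_pretrained builds a device_map; only the config class is in
    # scope here, so derive the modeling module's name from it and patch the model
    # class directly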
    if (
        "MixFormerSequentialConfig" in model_config.__class__.__name__
        and cfg.model_type == "AutoModelForCausalLM"
    ):
        module_name = model_config.__class__.__module__.replace(
            ".configuration_mixformer_sequential", ".modeling_mixformer_sequential"
        )
        modeling_phi = importlib.import_module(module_name)
        # pylint:disable=protected-access
        modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = [
            "ParallelBlock"
        ]

    model_kwargs = {}
    if cfg.model_revision:
        model_kwargs["revision"] = cfg.model_revision

    if cfg.gptq:
        model_config = load_model_config(cfg)
        if not hasattr(model_config, "quantization_config"):
            LOG.warning("model config does not contain quantization_config information")
        else:
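
As a closing note, here is a minimal sketch (not part of this commit) of what the _no_split_modules patch enables once load_model reaches from_pretrained; the checkpoint id is illustrative and assumes a remote MixFormer-based phi model:

import torch
from transformers import AutoModelForCausalLM

# illustrative phi/MixFormer checkpoint; trust_remote_code pulls the remote modeling code
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1_5",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",  # with _no_split_modules patched, a ParallelBlock is never split across devices
)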