Merge branch 'main' into strip-peft-device-map

Wing Lian authored 2023-06-12 08:25:54 -04:00, committed by GitHub
30 changed files with 269 additions and 604 deletions

@@ -77,15 +77,9 @@ def load_tokenizer(
 def load_model(
-    base_model,
-    base_model_config,
-    model_type,
-    tokenizer,
-    cfg,
-    adapter="lora",
-    inference=False,
+    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
 ):
-    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
     Load a model from a base model and a model type.
     """
@@ -98,7 +92,7 @@ def load_model(
     )
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference is False:
+        if cfg.device not in ["mps", "cpu"] and not cfg.inference:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
 
             logging.info("patching with flash attention")
@@ -305,7 +299,9 @@ def load_model(
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
-        model = prepare_model_for_kbit_training(model)
+        model = prepare_model_for_kbit_training(
+            model, use_gradient_checkpointing=cfg.gradient_checkpointing
+        )
 
     model, lora_config = load_adapter(model, cfg, adapter)
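
For context, peft's `prepare_model_for_kbit_training` freezes and casts parameters for k-bit fine-tuning and enables gradient checkpointing by default; threading `cfg.gradient_checkpointing` through means a run that disables checkpointing no longer gets it forced on. A minimal standalone sketch (the tiny model id is purely illustrative):

    from peft import prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # illustrative

    # use_gradient_checkpointing defaults to True; passing the config flag keeps the
    # training run's choice authoritative instead of the library default.
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)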
@@ -436,6 +432,7 @@ def load_lora(model, cfg):
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
+            is_trainable=not cfg.inference,
         )
     else:
         model = get_peft_model(model, lora_config)
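
The new `is_trainable` argument matters because `PeftModel.from_pretrained` loads adapters frozen (inference mode) by default; passing `is_trainable=not cfg.inference` keeps the LoRA weights trainable when resuming training from `lora_model_dir`. A hedged sketch of the difference (paths and model id are illustrative):

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # illustrative

    # Loads the saved adapter with its weights left trainable, so training can resume.
    # With the default is_trainable=False the adapter parameters stay frozen.
    model = PeftModel.from_pretrained(
        base,
        "path/to/saved-lora-adapter",  # stand-in for cfg.lora_model_dir
        is_trainable=True,             # i.e. not cfg.inference during a training run
    )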

@@ -57,6 +57,11 @@ def validate_config(cfg):
     if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
         raise ValueError("FSDP is not supported for falcon models")
 
+    if (
+        cfg.base_model and "mpt" in cfg.base_model.lower()
+    ) and cfg.gradient_checkpointing:
+        raise ValueError("gradient_checkpointing is not supported for MPT models")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
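
For illustration, the new guard could be exercised with a small test along these lines; the import paths for DictDefault and validate_config are assumptions, not part of this diff:

    import pytest

    from axolotl.utils.dict import DictDefault            # assumed import path
    from axolotl.utils.validation import validate_config  # assumed import path


    def test_mpt_gradient_checkpointing_rejected():
        # An MPT base model combined with gradient checkpointing should now fail fast.
        cfg = DictDefault(
            {
                "base_model": "mosaicml/mpt-7b",
                "gradient_checkpointing": True,
            }
        )
        with pytest.raises(ValueError, match="gradient_checkpointing"):
            validate_config(cfg)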