diff --git a/scripts/finetune.py b/scripts/finetune.py
index cd8c6f650..cf740e00e 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -146,8 +146,8 @@ def train(
         cfg.bf16 = False
 
     # Load the model and tokenizer
-    logging.info("loading model, tokenizer, and lora_config...")
-    model, tokenizer, lora_config = load_model(
+    logging.info("loading model, tokenizer, and peft_config...")
+    model, tokenizer, peft_config = load_model(
         cfg.base_model,
         cfg.base_model_config,
         cfg.model_type,
@@ -186,9 +186,9 @@ def train(
         model = torch.compile(model)
 
     # go ahead and presave, so we have the adapter config available to inspect
-    if lora_config:
+    if peft_config:
         logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
-        lora_config.save_pretrained(cfg.output_dir)
+        peft_config.save_pretrained(cfg.output_dir)
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index bd73fc76a..8b2b7ad6a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -195,11 +195,41 @@ def load_adapter(model, cfg, adapter):
         return model, None
     if adapter == "lora":
         return load_lora(model, cfg)
-    # TODO support Llama-Adapter once merged into peft https://github.com/huggingface/peft/pulls
+    if adapter == "llama-adapter":
+        return load_llama_adapter(model, cfg)
 
     raise NotImplementedError(f"{adapter} peft adapter not available")
 
 
+def load_llama_adapter(model, cfg):
+    # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    from peft import (
+        AdaptionPromptConfig,
+        get_peft_model,
+        PeftModel,
+    )
+
+    peft_config = AdaptionPromptConfig(
+        adapter_layers=cfg.peft_adapter.layers,  # layers (L)
+        adapter_len=cfg.peft_adapter.len,  # prompt length (K)
+        task_type="CAUSAL_LM",
+    )
+
+    if cfg.lora_model_dir:
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.lora_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = get_peft_model(model, peft_config)
+
+    model.print_trainable_parameters()
+
+    return model, peft_config
+
+
 def load_lora(model, cfg):
     # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
@@ -211,27 +241,26 @@ def load_lora(model, cfg):
 
     lora_config = None
 
-    if cfg.adapter == "lora":
-        lora_config = LoraConfig(
-            r=cfg.lora_r,
-            lora_alpha=cfg.lora_alpha,
-            target_modules=cfg.lora_target_modules,
-            lora_dropout=cfg.lora_dropout,
-            fan_in_fan_out=cfg.lora_fan_in_fan_out,
-            bias="none",
-            task_type="CAUSAL_LM",
+    lora_config = LoraConfig(
+        r=cfg.lora_r,
+        lora_alpha=cfg.lora_alpha,
+        target_modules=cfg.lora_target_modules,
+        lora_dropout=cfg.lora_dropout,
+        fan_in_fan_out=cfg.lora_fan_in_fan_out,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+
+    if cfg.lora_model_dir:
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.lora_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
         )
+    else:
+        model = get_peft_model(model, lora_config)
-        if cfg.lora_model_dir:
-            model = PeftModel.from_pretrained(
-                model,
-                cfg.lora_model_dir,
-                device_map=cfg.device_map,
-                torch_dtype=torch.float16,
-            )
-        else:
-            model = get_peft_model(model, lora_config)
-
-        model.print_trainable_parameters()
+    model.print_trainable_parameters()
 
     return model, lora_config
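
Note on the new code path (not part of the diff itself): load_llama_adapter() wires axolotl's config keys straight into peft's AdaptionPromptConfig, so a config that sets adapter: llama-adapter together with peft_adapter.layers (L, how many of the topmost decoder layers receive adaption prompts) and peft_adapter.len (K, how many learnable prompt tokens each of those layers gets) drives the whole thing. The sketch below exercises the same peft API directly; it assumes peft>=0.3.0 (where AdaptionPromptConfig landed) and a Llama-architecture checkpoint, and the model id and the layer/length values are illustrative, not taken from this diff.

import torch
from transformers import AutoModelForCausalLM
from peft import AdaptionPromptConfig, get_peft_model

# Illustrative base model; peft's adaption-prompt implementation targets Llama-family architectures.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    torch_dtype=torch.float16,
)

# Mirrors what load_llama_adapter() builds from cfg.peft_adapter.layers / cfg.peft_adapter.len.
peft_config = AdaptionPromptConfig(
    adapter_layers=30,  # L: topmost decoder layers that get adaption prompts
    adapter_len=10,     # K: learnable prompt tokens prepended in each of those layers
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the prompts and their zero-initialized attention gates are trainable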