support llama-adapter zero init attention

Wing Lian
2023-05-01 10:42:21 -04:00
parent 55baef0e03
commit 2255bb7f4f
2 changed files with 54 additions and 25 deletions
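This commit adds a second PEFT adapter type alongside LoRA: "llama-adapter", backed by peft's AdaptionPromptConfig, which implements the zero-init attention scheme from the LLaMA-Adapter paper (learnable adaption prompts whose attention contribution is gated to start at zero, so training begins from the unmodified base model). Because the config returned by load_model is no longer necessarily a LoraConfig, train() now uses the more general name peft_config.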


@@ -146,8 +146,8 @@ def train(
         cfg.bf16 = False

     # Load the model and tokenizer
-    logging.info("loading model, tokenizer, and lora_config...")
-    model, tokenizer, lora_config = load_model(
+    logging.info("loading model, tokenizer, and peft_config...")
+    model, tokenizer, peft_config = load_model(
         cfg.base_model,
         cfg.base_model_config,
         cfg.model_type,
@@ -186,9 +186,9 @@ def train(
         model = torch.compile(model)

     # go ahead and presave, so we have the adapter config available to inspect
-    if lora_config:
+    if peft_config:
         logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
-        lora_config.save_pretrained(cfg.output_dir)
+        peft_config.save_pretrained(cfg.output_dir)

     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
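The pre-save above works for either adapter type because every peft config exposes save_pretrained, which writes an adapter_config.json describing the adapter. A minimal sketch (hyperparameters and output path are illustrative):

from peft import AdaptionPromptConfig

peft_config = AdaptionPromptConfig(
    adapter_layers=30,   # L: number of decoder layers that receive adaption prompts
    adapter_len=10,      # K: learnable prompt tokens per layer
    task_type="CAUSAL_LM",
)
peft_config.save_pretrained("out")  # writes out/adapter_config.json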


@@ -195,11 +195,41 @@ def load_adapter(model, cfg, adapter):
         return model, None
     if adapter == "lora":
         return load_lora(model, cfg)
-    # TODO support Llama-Adapter once merged into peft https://github.com/huggingface/peft/pulls
+    if adapter == "llama-adapter":
+        return load_llama_adapter(model, cfg)

     raise NotImplementedError(f"{adapter} peft adapter not available")


+def load_llama_adapter(model, cfg):
+    # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    from peft import (
+        AdaptionPromptConfig,
+        get_peft_model,
+        PeftModel,
+    )
+
+    peft_config = AdaptionPromptConfig(
+        adapter_layers=cfg.peft_adapter.layers,  # layers (L)
+        adapter_len=cfg.peft_adapter.len,  # prompt length (K)
+        task_type="CAUSAL_LM",
+    )
+
+    if cfg.peft_model_dir:
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.peft_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = get_peft_model(model, peft_config)
+
+    model.print_trainable_parameters()
+
+    return model, peft_config
+
+
 def load_lora(model, cfg):
     # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
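For reference, the new code path above can be reproduced outside axolotl with peft alone. A minimal sketch, assuming a LLaMA checkpoint name and illustrative L/K values:

import torch
from transformers import AutoModelForCausalLM
from peft import AdaptionPromptConfig, get_peft_model

# model name is an assumption for illustration; any LLaMA-family causal LM works
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b", torch_dtype=torch.float16
)
peft_config = AdaptionPromptConfig(
    adapter_layers=30,   # L: top decoder layers that get adaption prompts
    adapter_len=10,      # K: prompt tokens prepended to each layer's keys/values
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the prompts and zero-init gates train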
@@ -211,27 +241,26 @@ def load_lora(model, cfg):
     lora_config = None

-    if cfg.adapter == "lora":
-        lora_config = LoraConfig(
-            r=cfg.lora_r,
-            lora_alpha=cfg.lora_alpha,
-            target_modules=cfg.lora_target_modules,
-            lora_dropout=cfg.lora_dropout,
-            fan_in_fan_out=cfg.lora_fan_in_fan_out,
-            bias="none",
-            task_type="CAUSAL_LM",
-        )
+    lora_config = LoraConfig(
+        r=cfg.lora_r,
+        lora_alpha=cfg.lora_alpha,
+        target_modules=cfg.lora_target_modules,
+        lora_dropout=cfg.lora_dropout,
+        fan_in_fan_out=cfg.lora_fan_in_fan_out,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )

-        if cfg.lora_model_dir:
-            model = PeftModel.from_pretrained(
-                model,
-                cfg.lora_model_dir,
-                device_map=cfg.device_map,
-                torch_dtype=torch.float16,
-            )
-        else:
-            model = get_peft_model(model, lora_config)
+    if cfg.lora_model_dir:
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.lora_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = get_peft_model(model, lora_config)

-        model.print_trainable_parameters()
+    model.print_trainable_parameters()

     return model, lora_config
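Either adapter can later be reloaded from a directory written by save_pretrained, mirroring the from_pretrained branches above. A sketch with an illustrative path and model name:

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b", torch_dtype=torch.float16
)
# "out" is the directory pre-saved during training; path is illustrative
model = PeftModel.from_pretrained(base, "out", torch_dtype=torch.float16)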