Experimental ReLoRA (+qlora) implementation

This commit is contained in:
Charles Goddard
2023-07-24 09:53:27 -07:00
committed by Wing Lian
parent 918f1b0dfb
commit b57238ecec
6 changed files with 375 additions and 1 deletions

View File

@@ -0,0 +1,65 @@
base_model: /home/charles/.cache/huggingface/hub/models--openlm-research--open_llama_3b/snapshots/8fcddba529aef0eda7293cc9a4171a3994648d2e/
base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
prompt_format: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.005
adapter: lora
lora_model_dir:
sequence_len: 512
max_packed_sequence_len: 512
lora_r: 8
lora_alpha: 16
lora_dropout: 0.001
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
relora_steps: 20
relora_warmup_steps: 10
wandb_project: relora
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-out
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
sdp_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 20
save_steps: 50
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -371,8 +371,14 @@ def train(
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
if cfg.adapter == "lora" and cfg.relora_steps:
model = model.merge_and_unload()
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(train) fire.Fire(train)

View File

@@ -0,0 +1,281 @@
# pylint: skip-file
import glob
import json
import logging
import os.path
import shutil
from pathlib import Path
from typing import Dict, List, Sequence
import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.relora")
def reset_optimizer(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
param_state = optimizer.state[param]
for key in param_state:
if "qmap" in key:
continue
elif key == "step" and isinstance(param_state[key], int):
param_state[key] = 0
else:
param_state[key] = torch.zeros_like(param_state[key])
class ReLoRACallback(TrainerCallback):
def __init__(self, cfg: DictDefault):
self.relora_steps = cfg.relora_steps
self.last_full_model = cfg.base_model
self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
assert os.path.exists(
self.last_full_model
), "for ReLORA base_model must be a local path"
self.num_lora_restarts = 0
self.need_full_save = False
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
optimizer: torch.optim.Optimizer,
**_kwargs,
):
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
with torch.no_grad():
merge_and_save(
model,
self.last_full_model,
checkpoint_folder,
reinit=True,
quantized=self.quantised,
)
reset_optimizer(optimizer)
self.last_full_model = checkpoint_folder
self.num_lora_restarts += 1
return control
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
if state.global_step >= self.relora_steps:
if self.quantised and self.last_full_model != checkpoint_folder:
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
chunks = glob.glob(
f"{self.last_full_model}/model*.safetensors"
) + glob.glob(f"{self.last_full_model}/model*.index.json")
for path in chunks:
shutil.move(path, checkpoint_folder)
self.last_full_model = checkpoint_folder
else:
model.model.save_pretrained(checkpoint_folder, save_safetensors=True)
return control
def on_log(
self,
_args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
logs: Dict[str, float],
**_kwargs,
):
logs["num_lora_restarts"] = self.num_lora_restarts
return control
class ReLoRAScheduler(LRScheduler):
def __init__(
self,
optimizer: Optimizer,
inner_schedule: LRScheduler,
relora_steps: int,
warmup_steps: int,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.relora_steps = relora_steps
self.warmup_steps = warmup_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps:
scale = 1
else:
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
return [lr * scale for lr in original]
else:
return original * scale
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
model_name = "model.safetensors"
if not os.path.exists(str(Path(path) / model_name)):
model_name = "pytorch_model.bin"
index_path = str(Path(path) / f"{model_name}.index.json")
if os.path.exists(index_path):
data = json.load(open(index_path, "r"))
return data["weight_map"]
return {key + ".weight": model_name for key in keys}
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer) -> torch.Tensor:
if isinstance(layer, peft.tuners.lora.Linear8bitLt) or isinstance(
layer, peft.tuners.lora.Linear4bit
):
adapter = layer.active_adapter
return (
peft.utils.transpose(
layer.lora_B[adapter].weight @ layer.lora_A[adapter].weight,
getattr(layer, "fan_in_fan_out", False),
)
* layer.scaling[adapter]
)
else:
return layer.get_delta_weight()
def merge_and_save(
model: peft.LoraModel,
model_src: str,
model_dst: str,
reinit: bool = False,
quantized: bool = False,
):
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
if not quantized:
for key in key_list:
try:
_parent, target, _target_name = peft.utils._get_submodules(
model.model, key
)
except AttributeError:
continue
if isinstance(target, peft.tuners.lora.LoraLayer):
update = target.get_delta_weight(target.active_adapter).detach()
target.weight.data += update
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
return
os.makedirs(model_dst, exist_ok=True)
shard_paths = sharded_paths(model_src, key_list)
unique_shards = list(set(shard_paths.values()))
for shard_path in unique_shards:
out_tensors = {}
if shard_path.endswith(".safetensors"):
in_tensors = st.load_file(str(Path(model_src) / shard_path))
else:
in_tensors = torch.load(Path(model_src) / shard_path)
if "state_dict" in in_tensors:
in_tensors = in_tensors["state_dict"]
for key in key_list:
if shard_paths[key + ".weight"] != shard_path:
continue
try:
_parent, target, _target_name = peft.utils._get_submodules(
model.model, key
)
except AttributeError:
continue
if isinstance(target, peft.tuners.lora.LoraLayer):
orig_weight = in_tensors[key + ".weight"]
old_dev = target.weight.device
update = lora_delta_weight(target).detach()
new_weight = (orig_weight.to(old_dev) + update.to(old_dev)).cpu()
out_tensors[key + ".weight"] = new_weight
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
old_dev = target.weight.device
if isinstance(target, peft.tuners.lora.Linear4bit):
target.weight = bnb.nn.Params4bit(
new_weight,
requires_grad=False,
compress_statistics=target.weight.compress_statistics,
quant_type=target.weight.quant_type,
).to(old_dev)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight = bnb.nn.Int8Params(
new_weight, requires_grad=False
).to(old_dev)
else:
target.weight.data = new_weight.to(old_dev)
for key in in_tensors:
if key not in out_tensors:
out_tensors[key] = in_tensors[key]
del in_tensors
out_shard_name = shard_path
if out_shard_name.startswith("pytorch_model"):
out_shard_name = (
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
+ ".safetensors"
)
st.save_file(out_tensors, str(Path(model_dst) / out_shard_name))
del out_tensors
torch.cuda.empty_cache()
if len(unique_shards) > 1:
with open(str(Path(model_dst, "model.safetensors.index.json")), "w") as fd:
json.dump({"metadata": {}, "weight_map": shard_paths}, fd)

View File

@@ -33,7 +33,9 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
) )
peft_model_path = os.path.join(checkpoint_folder, "adapter_model") peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
kwargs["model"].save_pretrained(peft_model_path) kwargs["model"].save_pretrained(
peft_model_path, save_safetensors=args.save_safetensors
)
return control return control

View File

@@ -21,6 +21,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
PrintGPUStatsCallback, PrintGPUStatsCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
@@ -556,6 +557,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
callbacks = [] callbacks = []
callbacks.append(PrintGPUStatsCallback(cfg)) callbacks.append(PrintGPUStatsCallback(cfg))
if cfg.relora_steps:
assert cfg.adapter in (
"lora",
"qlora",
), "Adapter must be lora or qlora to use ReLoRA"
relora_steps = int(cfg.relora_steps)
relora_warmup_steps = int(cfg.relora_warmup_steps)
callbacks.append(ReLoRACallback(cfg))
(optimizer, lr_scheduler) = trainer_kwargs["optimizers"]
trainer_kwargs["optimizers"] = (
optimizer,
ReLoRAScheduler(optimizer, lr_scheduler, relora_steps, relora_warmup_steps),
)
# TODO on_save callback to sync checkpoints to GCP/AWS in background # TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience: if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback( early_stop_cb = EarlyStoppingCallback(

View File

@@ -61,6 +61,9 @@ def validate_config(cfg):
if not cfg.load_in_8bit and cfg.adapter == "lora": if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.relora_steps and cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
if cfg.trust_remote_code: if cfg.trust_remote_code:
LOG.warning( LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."