From b57238ececd776b98404d58a92f28f6c8087f49f Mon Sep 17 00:00:00 2001
From: Charles Goddard
Date: Mon, 24 Jul 2023 09:53:27 -0700
Subject: [PATCH] Experimental ReLoRA (+qlora) implementation

---
 examples/openllama-3b/relora.yml  |  65 +++++++
 scripts/finetune.py               |   6 +
 src/axolotl/monkeypatch/relora.py | 281 ++++++++++++++++++++++++++++++
 src/axolotl/utils/callbacks.py    |   4 +-
 src/axolotl/utils/trainer.py      |  17 ++
 src/axolotl/utils/validation.py   |   3 +
 6 files changed, 375 insertions(+), 1 deletion(-)
 create mode 100644 examples/openllama-3b/relora.yml
 create mode 100644 src/axolotl/monkeypatch/relora.py

diff --git a/examples/openllama-3b/relora.yml b/examples/openllama-3b/relora.yml
new file mode 100644
index 000000000..2d1e5a971
--- /dev/null
+++ b/examples/openllama-3b/relora.yml
@@ -0,0 +1,65 @@
+base_model: /home/charles/.cache/huggingface/hub/models--openlm-research--open_llama_3b/snapshots/8fcddba529aef0eda7293cc9a4171a3994648d2e/
+base_model_config: openlm-research/open_llama_3b
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+push_dataset_to_hub:
+datasets:
+  - path: teknium/GPT4-LLM-Cleaned
+    type: alpaca
+prompt_format: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.005
+adapter: lora
+lora_model_dir:
+sequence_len: 512
+max_packed_sequence_len: 512
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.001
+lora_target_modules:
+lora_target_linear: true
+lora_fan_in_fan_out:
+relora_steps: 20
+relora_warmup_steps: 10
+wandb_project: relora
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./lora-out
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+torchdistx_path:
+lr_scheduler: cosine
+learning_rate: 0.0002
+train_on_inputs: false
+group_by_length: false
+bf16: false
+fp16: true
+tf32: false
+gradient_checkpointing: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+sdp_attention: true
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 10
+eval_steps: 20
+save_steps: 50
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/scripts/finetune.py b/scripts/finetune.py
index a7fee5ec8..d43377c79 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -371,8 +371,14 @@ def train(
     elif cfg.local_rank == 0:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
+
+        if cfg.adapter == "lora" and cfg.relora_steps:
+            model = model.merge_and_unload()
+
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+        # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
+

 if __name__ == "__main__":
     fire.Fire(train)
diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py
new file mode 100644
index 000000000..cee90c48f
--- /dev/null
+++ b/src/axolotl/monkeypatch/relora.py
@@ -0,0 +1,281 @@
+# pylint: skip-file
+import glob
+import json
+import logging
+import os.path
+import shutil
+from pathlib import Path
+from typing import Dict, List, Sequence
+
+import bitsandbytes as bnb
+import peft
+import safetensors.torch as st
+import torch
+from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.optimizer import Optimizer
+from transformers import (
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+    TrainingArguments,
+)
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.relora")
+
+
+def reset_optimizer(optimizer: torch.optim.Optimizer):
+    # Zero out all optimizer state (moments, steps) after a LoRA merge/restart,
+    # skipping bitsandbytes quantization maps ("qmap") which must be preserved.
+    for group in optimizer.param_groups:
+        for param in group["params"]:
+            param_state = optimizer.state[param]
+            for key in param_state:
+                if "qmap" in key:
+                    continue
+                elif key == "step" and isinstance(param_state[key], int):
+                    param_state[key] = 0
+                else:
+                    param_state[key] = torch.zeros_like(param_state[key])
+
+
+class ReLoRACallback(TrainerCallback):
+    def __init__(self, cfg: DictDefault):
+        self.relora_steps = cfg.relora_steps
+        self.last_full_model = cfg.base_model
+        self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
+
+        assert os.path.exists(
+            self.last_full_model
+        ), "for ReLORA base_model must be a local path"
+
+        self.num_lora_restarts = 0
+        self.need_full_save = False
+
+    def on_step_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        model: peft.LoraModel,
+        optimizer: torch.optim.Optimizer,
+        **_kwargs,
+    ):
+        # Every relora_steps steps: merge adapters into the full weights,
+        # save the merged model, re-init the LoRA params, and reset optimizer state.
+        if state.global_step > 0 and state.global_step % self.relora_steps == 0:
+            checkpoint_folder = os.path.join(
+                args.output_dir,
+                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+            )
+
+            with torch.no_grad():
+                merge_and_save(
+                    model,
+                    self.last_full_model,
+                    checkpoint_folder,
+                    reinit=True,
+                    quantized=self.quantised,
+                )
+                reset_optimizer(optimizer)
+
+            self.last_full_model = checkpoint_folder
+            self.num_lora_restarts += 1
+
+        return control
+
+    def on_save(
+        self,
+        args: TrainerArguments if False else TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        model: peft.LoraModel,
+        **kwargs,
+    ):
+        checkpoint_folder = os.path.join(
+            args.output_dir,
+            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+        )
+        if state.global_step >= self.relora_steps:
+            if self.quantised and self.last_full_model != checkpoint_folder:
+                LOG.info(f"moving last full parameter save to {checkpoint_folder}")
+                chunks = glob.glob(
+                    f"{self.last_full_model}/model*.safetensors"
+                ) + glob.glob(f"{self.last_full_model}/model*.index.json")
+                for path in chunks:
+                    shutil.move(path, checkpoint_folder)
+                self.last_full_model = checkpoint_folder
+            else:
+                # NOTE(review): save_pretrained takes `safe_serialization`, not
+                # `save_safetensors` (which would be silently ignored via **kwargs).
+                model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
+
+        return control
+
+    def on_log(
+        self,
+        _args: TrainingArguments,
+        _state: TrainerState,
+        control: TrainerControl,
+        logs: Dict[str, float],
+        **_kwargs,
+    ):
+        logs["num_lora_restarts"] = self.num_lora_restarts
+        return control
+
+
+class ReLoRAScheduler(LRScheduler):
+    """Wraps another scheduler to apply per-restart warmup with a floor of min_lr_scale."""
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        inner_schedule: LRScheduler,
+        relora_steps: int,
+        warmup_steps: int,
+        min_lr_scale: float = 0.001,
+    ) -> None:
+        self.inner_schedule = inner_schedule
+        self.relora_steps = relora_steps
+        self.warmup_steps = warmup_steps
+        self.min_lr_scale = min_lr_scale
+        super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
+
+    def get_lr(self) -> float:
+        self.inner_schedule.last_epoch = self.last_epoch
+
+        original = self.inner_schedule.get_lr()
+        step = self.last_epoch
+        if step < self.relora_steps:
+            scale = 1
+        else:
+            cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
+            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
+
+        if isinstance(original, Sequence):
+            return [lr * scale for lr in original]
+        else:
+            return original * scale
+
+
+def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
+    model_name = "model.safetensors"
+    if not os.path.exists(str(Path(path) / model_name)):
+        model_name = "pytorch_model.bin"
+
+    index_path = str(Path(path) / f"{model_name}.index.json")
+    if os.path.exists(index_path):
+        data = json.load(open(index_path, "r"))
+        return data["weight_map"]
+    return {key + ".weight": model_name for key in keys}
+
+
+def lora_delta_weight(layer: peft.tuners.lora.LoraLayer) -> torch.Tensor:
+    if isinstance(layer, peft.tuners.lora.Linear8bitLt) or isinstance(
+        layer, peft.tuners.lora.Linear4bit
+    ):
+        adapter = layer.active_adapter
+        return (
+            peft.utils.transpose(
+                layer.lora_B[adapter].weight @ layer.lora_A[adapter].weight,
+                getattr(layer, "fan_in_fan_out", False),
+            )
+            * layer.scaling[adapter]
+        )
+    else:
+        return layer.get_delta_weight(layer.active_adapter)
+
+
+def merge_and_save(
+    model: peft.LoraModel,
+    model_src: str,
+    model_dst: str,
+    reinit: bool = False,
+    quantized: bool = False,
+):
+    key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
+
+    if not quantized:
+        for key in key_list:
+            try:
+                _parent, target, _target_name = peft.utils._get_submodules(
+                    model.model, key
+                )
+            except AttributeError:
+                continue
+
+            if isinstance(target, peft.tuners.lora.LoraLayer):
+                update = target.get_delta_weight(target.active_adapter).detach()
+                target.weight.data += update
+
+                if reinit:
+                    for adapter_name in target.lora_A:
+                        target.reset_lora_parameters(adapter_name)
+                    for adapter_name in target.lora_embedding_A:
+                        target.reset_lora_parameters(adapter_name)
+        return
+
+    os.makedirs(model_dst, exist_ok=True)
+    shard_paths = sharded_paths(model_src, key_list)
+
+    unique_shards = list(set(shard_paths.values()))
+    for shard_path in unique_shards:
+        out_tensors = {}
+        if shard_path.endswith(".safetensors"):
+            in_tensors = st.load_file(str(Path(model_src) / shard_path))
+        else:
+            in_tensors = torch.load(Path(model_src) / shard_path)
+            if "state_dict" in in_tensors:
+                in_tensors = in_tensors["state_dict"]
+
+        for key in key_list:
+            if shard_paths[key + ".weight"] != shard_path:
+                continue
+
+            try:
+                _parent, target, _target_name = peft.utils._get_submodules(
+                    model.model, key
+                )
+            except AttributeError:
+                continue
+
+            if isinstance(target, peft.tuners.lora.LoraLayer):
+                orig_weight = in_tensors[key + ".weight"]
+                old_dev = target.weight.device
+
+                update = lora_delta_weight(target).detach()
+                new_weight = (orig_weight.to(old_dev) + update.to(old_dev)).cpu()
+                out_tensors[key + ".weight"] = new_weight
+
+                if reinit:
+                    for adapter_name in target.lora_A:
+                        target.reset_lora_parameters(adapter_name)
+                    for adapter_name in target.lora_embedding_A:
+                        target.reset_lora_parameters(adapter_name)
+
+                old_dev = target.weight.device
+                if isinstance(target, peft.tuners.lora.Linear4bit):
+                    target.weight = bnb.nn.Params4bit(
+                        new_weight,
+                        requires_grad=False,
+                        compress_statistics=target.weight.compress_statistics,
+                        quant_type=target.weight.quant_type,
+                    ).to(old_dev)
+                elif isinstance(target, peft.tuners.lora.Linear8bitLt):
+                    target.weight = bnb.nn.Int8Params(
+                        new_weight, requires_grad=False
+                    ).to(old_dev)
+                else:
+                    target.weight.data = new_weight.to(old_dev)
+
+        for key in in_tensors:
+            if key not in out_tensors:
+                out_tensors[key] = in_tensors[key]
+        del in_tensors
+
+        out_shard_name = shard_path
+        if out_shard_name.startswith("pytorch_model"):
+            # removesuffix, not rstrip: rstrip strips a character set and would
+            # mangle shard names like "...-00002.bin".
+            out_shard_name = (
+                out_shard_name.replace("pytorch_model", "model").removesuffix(".bin")
+                + ".safetensors"
+            )
+        st.save_file(out_tensors, str(Path(model_dst) / out_shard_name))
+        del out_tensors
+        torch.cuda.empty_cache()
+
+    if len(unique_shards) > 1:
+        with open(str(Path(model_dst, "model.safetensors.index.json")), "w") as fd:
+            json.dump({"metadata": {}, "weight_map": shard_paths}, fd)
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index f06762b6b..f0a93e5a5 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -33,7 +33,9 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-
         )
         peft_model_path = os.path.join(checkpoint_folder, "adapter_model")

-        kwargs["model"].save_pretrained(peft_model_path)
+        kwargs["model"].save_pretrained(
+            peft_model_path, safe_serialization=args.save_safetensors
+        )

         return control
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 25d0b1e82..269c7696c 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -21,6 +21,7 @@
 from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names

+from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
     PrintGPUStatsCallback,
     SaveBetterTransformerModelCallback,
@@ -556,6 +557,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     callbacks = []
     callbacks.append(PrintGPUStatsCallback(cfg))
+
+    if cfg.relora_steps:
+        assert cfg.adapter in (
+            "lora",
+            "qlora",
+        ), "Adapter must be lora or qlora to use ReLoRA"
+        relora_steps = int(cfg.relora_steps)
+        relora_warmup_steps = int(cfg.relora_warmup_steps)
+        callbacks.append(ReLoRACallback(cfg))
+
+        (optimizer, lr_scheduler) = trainer_kwargs["optimizers"]
+        trainer_kwargs["optimizers"] = (
+            optimizer,
+            ReLoRAScheduler(optimizer, lr_scheduler, relora_steps, relora_warmup_steps),
+        )
+
     # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 97d70c4c8..8babdc262 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -61,6 +61,9 @@ def validate_config(cfg):
     if not cfg.load_in_8bit and cfg.adapter == "lora":
         LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")

+    if cfg.relora_steps and cfg.adapter not in ("lora", "qlora"):
+        raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
+
     if cfg.trust_remote_code:
         LOG.warning(
             "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."