diff --git a/examples/openllama-3b/relora.yml b/examples/openllama-3b/relora.yml
new file mode 100644
index 000000000..2d1e5a971
--- /dev/null
+++ b/examples/openllama-3b/relora.yml
@@ -0,0 +1,65 @@
+# ReLoRA merges from a local copy of the base weights, so base_model must point at a
+# local snapshot (e.g. a Hugging Face cache path) rather than a hub model id.
+base_model: /home/charles/.cache/huggingface/hub/models--openlm-research--open_llama_3b/snapshots/8fcddba529aef0eda7293cc9a4171a3994648d2e/
+base_model_config: openlm-research/open_llama_3b
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+push_dataset_to_hub:
+datasets:
+ - path: teknium/GPT4-LLM-Cleaned
+ type: alpaca
+prompt_format: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.005
+adapter: lora
+lora_model_dir:
+sequence_len: 512
+max_packed_sequence_len: 512
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.001
+lora_target_modules:
+lora_target_linear: true
+lora_fan_in_fan_out:
+relora_steps: 20
+relora_warmup_steps: 10
+wandb_project: relora
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./lora-out
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+torchdistx_path:
+lr_scheduler: cosine
+learning_rate: 0.0002
+train_on_inputs: false
+group_by_length: false
+bf16: false
+fp16: true
+tf32: false
+gradient_checkpointing: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+sdp_attention: true
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 10
+eval_steps: 20
+save_steps: 50
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
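The ReLoRA-specific keys here are `relora_steps` (merge the adapter into the base weights and reset it every N optimizer steps) and `relora_warmup_steps` (how many steps the learning rate takes to re-warm after each merge). As a quick sanity check when choosing these values, the hypothetical helper below (not part of axolotl) estimates how many merge/reset cycles a run will actually perform:

```python
def estimate_relora_restarts(
    num_samples: int,
    micro_batch_size: int,
    gradient_accumulation_steps: int,
    num_epochs: int,
    relora_steps: int,
    world_size: int = 1,
) -> int:
    """Approximate merge/reset cycles: one restart every `relora_steps` optimizer steps."""
    steps_per_epoch = num_samples // (
        micro_batch_size * gradient_accumulation_steps * world_size
    )
    return (steps_per_epoch * num_epochs) // relora_steps
```

With `relora_steps: 20` a merge happens every 20 optimizer steps, so the run's total optimizer step count should comfortably exceed `relora_steps`, or ReLoRA will never trigger.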
diff --git a/scripts/finetune.py b/scripts/finetune.py
index a7fee5ec8..d43377c79 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -371,8 +371,14 @@ def train(
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
+
+ if cfg.adapter == "lora" and cfg.relora_steps:
+ model = model.merge_and_unload()
+
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+    # trainer.save_model(cfg.output_dir)  # TODO: may be needed for deepspeed to work; revisit later
+
if __name__ == "__main__":
fire.Fire(train)
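The `merge_and_unload()` call added above folds every LoRA delta back into the frozen base weights before the final `save_pretrained`, so a ReLoRA run ends with a plain base-model checkpoint instead of an adapter. For reference, a minimal sketch of the standard LoRA merge for one linear layer (illustrative only, ignoring the `fan_in_fan_out` transpose handled in the new relora module below):

```python
import torch


def merged_linear_weight(
    base_weight: torch.Tensor,  # (out_features, in_features)
    lora_a: torch.Tensor,       # (r, in_features)
    lora_b: torch.Tensor,       # (out_features, r)
    lora_alpha: float,
    lora_r: int,
) -> torch.Tensor:
    """Return W + (alpha / r) * (B @ A), the dense weight after merging one adapter."""
    scaling = lora_alpha / lora_r
    return base_weight + scaling * (lora_b @ lora_a)
```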
diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py
new file mode 100644
index 000000000..cee90c48f
--- /dev/null
+++ b/src/axolotl/monkeypatch/relora.py
@@ -0,0 +1,281 @@
+# pylint: skip-file
+"""
+ReLoRA: periodically merge the LoRA adapter into the base weights, reset the
+adapter and optimizer state, and restart the learning-rate warmup.
+"""
+import glob
+import json
+import logging
+import os.path
+import shutil
+from pathlib import Path
+from typing import Dict, List, Sequence
+
+import bitsandbytes as bnb
+import peft
+import safetensors.torch as st
+import torch
+from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.optimizer import Optimizer
+from transformers import (
+ TrainerCallback,
+ TrainerControl,
+ TrainerState,
+ TrainingArguments,
+)
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.relora")
+
+
+def reset_optimizer(optimizer: torch.optim.Optimizer):
+    """Zero all optimizer state (moments and step counters) except bitsandbytes
+    quantization maps, so each ReLoRA cycle starts from a fresh optimizer."""
+ for group in optimizer.param_groups:
+ for param in group["params"]:
+ param_state = optimizer.state[param]
+ for key in param_state:
+ if "qmap" in key:
+ continue
+ elif key == "step" and isinstance(param_state[key], int):
+ param_state[key] = 0
+ else:
+ param_state[key] = torch.zeros_like(param_state[key])
+
+
+class ReLoRACallback(TrainerCallback):
+    """Merge the LoRA adapter into the base weights and reinitialize it every
+    `relora_steps` optimizer steps, resetting the optimizer state alongside."""
+
+ def __init__(self, cfg: DictDefault):
+ self.relora_steps = cfg.relora_steps
+ self.last_full_model = cfg.base_model
+ self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
+
+        assert os.path.exists(
+            self.last_full_model
+        ), "for ReLoRA, base_model must be a local path"
+
+ self.num_lora_restarts = 0
+ self.need_full_save = False
+
+ def on_step_begin(
+ self,
+ args: TrainingArguments,
+ state: TrainerState,
+ control: TrainerControl,
+ model: peft.LoraModel,
+ optimizer: torch.optim.Optimizer,
+ **_kwargs,
+ ):
+ if state.global_step > 0 and state.global_step % self.relora_steps == 0:
+ checkpoint_folder = os.path.join(
+ args.output_dir,
+ f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+ )
+
+ with torch.no_grad():
+ merge_and_save(
+ model,
+ self.last_full_model,
+ checkpoint_folder,
+ reinit=True,
+ quantized=self.quantised,
+ )
+ reset_optimizer(optimizer)
+
+ self.last_full_model = checkpoint_folder
+ self.num_lora_restarts += 1
+
+ return control
+
+ def on_save(
+ self,
+ args: TrainingArguments,
+ state: TrainerState,
+ control: TrainerControl,
+ model: peft.LoraModel,
+ **kwargs,
+ ):
+ checkpoint_folder = os.path.join(
+ args.output_dir,
+ f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+ )
+ if state.global_step >= self.relora_steps:
+ if self.quantised and self.last_full_model != checkpoint_folder:
+ LOG.info(f"moving last full parameter save to {checkpoint_folder}")
+ chunks = glob.glob(
+ f"{self.last_full_model}/model*.safetensors"
+ ) + glob.glob(f"{self.last_full_model}/model*.index.json")
+ for path in chunks:
+ shutil.move(path, checkpoint_folder)
+ self.last_full_model = checkpoint_folder
+ else:
+                model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
+
+ return control
+
+ def on_log(
+ self,
+ _args: TrainingArguments,
+ _state: TrainerState,
+ control: TrainerControl,
+ logs: Dict[str, float],
+ **_kwargs,
+ ):
+ logs["num_lora_restarts"] = self.num_lora_restarts
+ return control
+
+
+class ReLoRAScheduler(LRScheduler):
+    """Wrap an inner LR scheduler and re-warm the learning rate from
+    `min_lr_scale` over `warmup_steps` steps after each ReLoRA restart."""
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ inner_schedule: LRScheduler,
+ relora_steps: int,
+ warmup_steps: int,
+ min_lr_scale: float = 0.001,
+ ) -> None:
+ self.inner_schedule = inner_schedule
+ self.relora_steps = relora_steps
+ self.warmup_steps = warmup_steps
+ self.min_lr_scale = min_lr_scale
+ super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
+
+ def get_lr(self) -> float:
+ self.inner_schedule.last_epoch = self.last_epoch
+
+ original = self.inner_schedule.get_lr()
+ step = self.last_epoch
+ if step < self.relora_steps:
+ scale = 1
+ else:
+ cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
+ scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
+ if isinstance(original, Sequence):
+ return [lr * scale for lr in original]
+ else:
+ return original * scale
+
+
+def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
+    """Map each `<module>.weight` tensor name to the checkpoint shard file that holds it."""
+    model_name = "model.safetensors"
+    if not os.path.exists(str(Path(path) / model_name)):
+        model_name = "pytorch_model.bin"
+
+    index_path = str(Path(path) / f"{model_name}.index.json")
+    if os.path.exists(index_path):
+        with open(index_path, "r") as index_file:
+            data = json.load(index_file)
+        return data["weight_map"]
+    return {key + ".weight": model_name for key in keys}
+
+
+def lora_delta_weight(layer: peft.tuners.lora.LoraLayer) -> torch.Tensor:
+    """Compute the dense update (scaled lora_B @ lora_A) contributed by a LoRA layer."""
+    if isinstance(
+        layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)
+    ):
+ adapter = layer.active_adapter
+ return (
+ peft.utils.transpose(
+ layer.lora_B[adapter].weight @ layer.lora_A[adapter].weight,
+ getattr(layer, "fan_in_fan_out", False),
+ )
+ * layer.scaling[adapter]
+ )
+ else:
+ return layer.get_delta_weight()
+
+
+def merge_and_save(
+    model: peft.LoraModel,
+    model_src: str,
+    model_dst: str,
+    reinit: bool = False,
+    quantized: bool = False,
+):
+    """Merge LoRA deltas into the base weights, optionally reinitializing the adapter.
+
+    Unquantized models are merged in place; for quantized models the merged
+    full-precision weights are streamed shard-by-shard from `model_src` into
+    safetensors files under `model_dst`.
+    """
+ key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
+
+ if not quantized:
+ for key in key_list:
+ try:
+ _parent, target, _target_name = peft.utils._get_submodules(
+ model.model, key
+ )
+ except AttributeError:
+ continue
+
+ if isinstance(target, peft.tuners.lora.LoraLayer):
+ update = target.get_delta_weight(target.active_adapter).detach()
+ target.weight.data += update
+
+ if reinit:
+ for adapter_name in target.lora_A:
+ target.reset_lora_parameters(adapter_name)
+ for adapter_name in target.lora_embedding_A:
+ target.reset_lora_parameters(adapter_name)
+ return
+
+ os.makedirs(model_dst, exist_ok=True)
+ shard_paths = sharded_paths(model_src, key_list)
+
+ unique_shards = list(set(shard_paths.values()))
+ for shard_path in unique_shards:
+ out_tensors = {}
+ if shard_path.endswith(".safetensors"):
+ in_tensors = st.load_file(str(Path(model_src) / shard_path))
+ else:
+ in_tensors = torch.load(Path(model_src) / shard_path)
+ if "state_dict" in in_tensors:
+ in_tensors = in_tensors["state_dict"]
+
+ for key in key_list:
+ if shard_paths[key + ".weight"] != shard_path:
+ continue
+
+ try:
+ _parent, target, _target_name = peft.utils._get_submodules(
+ model.model, key
+ )
+ except AttributeError:
+ continue
+
+ if isinstance(target, peft.tuners.lora.LoraLayer):
+ orig_weight = in_tensors[key + ".weight"]
+ old_dev = target.weight.device
+
+ update = lora_delta_weight(target).detach()
+ new_weight = (orig_weight.to(old_dev) + update.to(old_dev)).cpu()
+ out_tensors[key + ".weight"] = new_weight
+
+ if reinit:
+ for adapter_name in target.lora_A:
+ target.reset_lora_parameters(adapter_name)
+ for adapter_name in target.lora_embedding_A:
+ target.reset_lora_parameters(adapter_name)
+
+ old_dev = target.weight.device
+ if isinstance(target, peft.tuners.lora.Linear4bit):
+ target.weight = bnb.nn.Params4bit(
+ new_weight,
+ requires_grad=False,
+ compress_statistics=target.weight.compress_statistics,
+ quant_type=target.weight.quant_type,
+ ).to(old_dev)
+ elif isinstance(target, peft.tuners.lora.Linear8bitLt):
+ target.weight = bnb.nn.Int8Params(
+ new_weight, requires_grad=False
+ ).to(old_dev)
+ else:
+ target.weight.data = new_weight.to(old_dev)
+
+ for key in in_tensors:
+ if key not in out_tensors:
+ out_tensors[key] = in_tensors[key]
+ del in_tensors
+
+ out_shard_name = shard_path
+ if out_shard_name.startswith("pytorch_model"):
+ out_shard_name = (
+ out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
+ + ".safetensors"
+ )
+ st.save_file(out_tensors, str(Path(model_dst) / out_shard_name))
+ del out_tensors
+ torch.cuda.empty_cache()
+
+ if len(unique_shards) > 1:
+ with open(str(Path(model_dst, "model.safetensors.index.json")), "w") as fd:
+ json.dump({"metadata": {}, "weight_map": shard_paths}, fd)
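To picture the jagged learning-rate profile `ReLoRAScheduler` produces, the standalone sketch below mirrors its `get_lr` logic: the inner schedule is untouched until the first restart, then the LR re-warms linearly from `min_lr_scale` over `warmup_steps` steps after every merge.

```python
def relora_lr_scale(
    step: int,
    relora_steps: int,
    warmup_steps: int,
    min_lr_scale: float = 0.001,
) -> float:
    """Multiplier applied to the inner scheduler's LR at a given optimizer step."""
    if step < relora_steps:
        return 1.0
    cycle_t = min(1.0, (step % relora_steps) / warmup_steps)
    return cycle_t * (1.0 - min_lr_scale) + min_lr_scale


# With relora_steps=20 and warmup_steps=10 (as in the example config), the scale
# drops to ~0.001 right after each merge (steps 20, 40, ...) and climbs back to
# 1.0 ten steps later.
```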
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index f06762b6b..f0a93e5a5 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -33,7 +33,9 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
)
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
- kwargs["model"].save_pretrained(peft_model_path)
+        kwargs["model"].save_pretrained(
+            peft_model_path, safe_serialization=args.save_safetensors
+        )
return control
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 25d0b1e82..269c7696c 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -21,6 +21,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names
+from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
PrintGPUStatsCallback,
SaveBetterTransformerModelCallback,
@@ -556,6 +557,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
callbacks = []
callbacks.append(PrintGPUStatsCallback(cfg))
+
+ if cfg.relora_steps:
+ assert cfg.adapter in (
+ "lora",
+ "qlora",
+ ), "Adapter must be lora or qlora to use ReLoRA"
+ relora_steps = int(cfg.relora_steps)
+ relora_warmup_steps = int(cfg.relora_warmup_steps)
+ callbacks.append(ReLoRACallback(cfg))
+
+ (optimizer, lr_scheduler) = trainer_kwargs["optimizers"]
+ trainer_kwargs["optimizers"] = (
+ optimizer,
+ ReLoRAScheduler(optimizer, lr_scheduler, relora_steps, relora_warmup_steps),
+ )
+
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 97d70c4c8..8babdc262 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -61,6 +61,9 @@ def validate_config(cfg):
if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
+ if cfg.relora_steps and cfg.adapter not in ("lora", "qlora"):
+ raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
+
if cfg.trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."