diff --git a/README.md b/README.md
index 4ec79729f..a81ac8b50 100644
--- a/README.md
+++ b/README.md
@@ -493,6 +493,12 @@ lora_modules_to_save:
 lora_out_dir:
 lora_fan_in_fan_out: false
 
+# ReLoRA configuration
+# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
+relora_steps: # number of steps per ReLoRA restart
+relora_warmup_steps: # number of per-restart warmup steps
+relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
+
 # wandb configuration if you're using it
 wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
 wandb_project: # your wandb project name
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 78eca05b9..3255a623f 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -242,6 +242,21 @@ def train(
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
         return
 
+    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
+        possible_checkpoints = [
+            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
+        ]
+        if len(possible_checkpoints) > 0:
+            sorted_paths = sorted(
+                possible_checkpoints,
+                key=lambda path: int(path.split("-")[-1]),
+            )
+            cfg.resume_from_checkpoint = sorted_paths[-1]
+            LOG.info(
+                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
+            )
+    resume_from_checkpoint = cfg.resume_from_checkpoint
+
     trainer = setup_trainer(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
     )
@@ -273,20 +288,6 @@ def train(
     LOG.info("Starting trainer...")
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")
-    resume_from_checkpoint = cfg.resume_from_checkpoint
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [
-            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
-        ]
-        if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(
-                possible_checkpoints,
-                key=lambda path: int(path.split("-")[-1]),
-            )
-            resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
-            )
 
     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
@@ -301,6 +302,13 @@ def train(
 
     LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
+    if cfg.relora_steps:
+        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
+            model = model.merge_and_unload()
+        else:
+            # final model weights have already been saved by `ReLoRACallback.on_train_end`
+            return
+
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
@@ -308,6 +316,7 @@ def train(
     elif cfg.local_rank == 0:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
+
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
 
diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py
new file mode 100644
index 000000000..e247fafd2
--- /dev/null
+++ b/src/axolotl/monkeypatch/relora.py
@@ -0,0 +1,393 @@
+"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
+import glob
+import json
+import logging
+import os.path
+import shutil
+from pathlib import Path
+from typing import Dict, List, Sequence
+
+import bitsandbytes as bnb
+import peft
+import safetensors.torch as st
+import torch
+from huggingface_hub import snapshot_download
+from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.optimizer import Optimizer
+from transformers import (
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+    TrainingArguments,
+)
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import is_main_process
+
+LOG = logging.getLogger("axolotl.relora")
+
+
+def reset_optimizer(optimizer: torch.optim.Optimizer):
+    for group in optimizer.param_groups:
+        for param in group["params"]:
+            param_state = optimizer.state[param]
+            for key in param_state:
+                if "qmap" in key:
+                    continue
+
+                if key == "step" and isinstance(param_state[key], int):
+                    param_state[key] = 0
+                else:
+                    param_state[key] = torch.zeros_like(param_state[key])
+
+
+class ReLoRACallback(TrainerCallback):
+    """Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
+
+    def __init__(self, cfg: DictDefault):
+        self.relora_steps = cfg.relora_steps
+        self.cpu_offload = cfg.relora_cpu_offload
+        self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
+        self.last_full_model = cfg.base_model
+        self.resume_from_checkpoint = cfg.resume_from_checkpoint
+
+        if not os.path.exists(self.last_full_model):
+            self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
+
+        assert os.path.exists(
+            self.last_full_model
+        ), "for ReLORA base_model must be a local path"
+
+        self.num_lora_restarts = 0
+        self.need_full_save = False
+
+    def on_train_begin(
+        self,
+        _args: TrainingArguments,
+        _state: TrainerState,
+        control: TrainerControl,
+        model: peft.LoraModel,
+        **_kwargs,
+    ):
+        if self.resume_from_checkpoint:
+            weight_path = os.path.join(self.resume_from_checkpoint, "relora")
+            if not os.path.exists(weight_path):
+                LOG.warning(
+                    "Resuming ReLoRA from checkpoint, but no full-weight save found"
+                )
+            else:
+                LOG.info(f"Loading adjusted base weights from {weight_path}")
+                load_weight_checkpoint(model, weight_path)
+        return control
+
+    def on_step_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        model: peft.LoraModel,
+        optimizer: torch.optim.Optimizer,
+        **_kwargs,
+    ):
+        if state.global_step > 0 and state.global_step % self.relora_steps == 0:
+            checkpoint_folder = os.path.join(
+                args.output_dir,
+                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+                "relora",
+            )
+
+            with torch.no_grad():
+                merge_and_save(
+                    model,
+                    self.last_full_model,
+                    checkpoint_folder,
+                    reinit=True,
+                    quantized=self.quantized,
+                    actually_save=is_main_process(),
+                    cpu_offload=self.cpu_offload,
+                )
+                reset_optimizer(optimizer)
+
+            if self.quantized:
+                self.last_full_model = checkpoint_folder
+            self.num_lora_restarts += 1
+
+        return control
+
+    def on_save(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        model: peft.LoraModel,
+        **_kwargs,
+    ):
+        checkpoint_folder = os.path.join(
+            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
+        )
+        if (
+            state.global_step >= self.relora_steps
+            and state.global_step % self.relora_steps != 0
+        ):
+            if self.quantized:
+                if self.last_full_model != checkpoint_folder:
+                    # ensure the latest full parameter save is in the latest checkpoint
+                    # folder, so that automatic pruning of checkpoints does not remove it
+                    LOG.info(f"moving last full parameter save to {checkpoint_folder}")
+                    os.makedirs(checkpoint_folder, exist_ok=True)
+                    chunks = glob.glob(
+                        f"{self.last_full_model}/model*.safetensors"
+                    ) + glob.glob(f"{self.last_full_model}/model*.index.json")
+                    for path in chunks:
+                        new_path = os.path.abspath(shutil.move(path, checkpoint_folder))
+                        try:
+                            os.symlink(new_path, path)
+                        except OSError:
+                            # probably on windows without permission to symlink
+                            pass
+
+                    self.last_full_model = checkpoint_folder
+            else:
+                model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
+
+        return control
+
+    def on_log(
+        self,
+        _args: TrainingArguments,
+        _state: TrainerState,
+        control: TrainerControl,
+        logs: Dict[str, float],
+        **_kwargs,
+    ):
+        logs["num_lora_restarts"] = self.num_lora_restarts
+        return control
+
+    def on_train_end(
+        self,
+        args: TrainingArguments,
+        _state: TrainerState,
+        control: TrainerControl,
+        model: peft.LoraModel,
+        **_kwargs,
+    ):
+        if self.quantized:
+            # perform final merge and save
+            with torch.no_grad():
+                merge_and_save(
+                    model,
+                    self.last_full_model,
+                    args.output_dir,
+                    reinit=False,
+                    quantized=self.quantized,
+                    actually_save=is_main_process(),
+                    cpu_offload=self.cpu_offload,
+                )
+        # no need to save if unquantized, as finetune.py will call merge_and_unload()
+        return control
+
+
+class ReLoRAScheduler(LRScheduler):
+    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        inner_schedule: LRScheduler,
+        relora_steps: int,
+        warmup_steps: int,
+        min_lr_scale: float = 0.001,
+    ) -> None:
+        self.inner_schedule = inner_schedule
+        self.relora_steps = relora_steps
+        self.warmup_steps = warmup_steps
+        self.min_lr_scale = min_lr_scale
+        super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
+
+    def get_lr(self) -> float:
+        self.inner_schedule.last_epoch = self.last_epoch
+
+        original = self.inner_schedule.get_lr()
+        step = self.last_epoch
+        if step < self.relora_steps:
+            scale = 1
+        else:
+            cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
+            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
+
+        if isinstance(original, Sequence):
+            return [lr * scale for lr in original]
+        return original * scale
+
+
+def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
+    model_name = "model.safetensors"
+    if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
+        str(Path(path) / f"{model_name}.index.json")
+    ):
+        model_name = "pytorch_model.bin"
+
+    index_path = str(Path(path) / f"{model_name}.index.json")
+    if os.path.exists(index_path):
+        with open(index_path, "r", encoding="utf-8") as file:
+            data = json.load(file)
+        return data["weight_map"]
+
+    return {(module_name + ".weight"): model_name for module_name in module_names}
+
+
+def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
+    if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
+        adapter = layer.active_adapter
+        return (
+            peft.utils.transpose(
+                layer.lora_B[adapter].weight.detach().to(device)
+                @ layer.lora_A[adapter].weight.detach().to(device),
+                getattr(layer, "fan_in_fan_out", False),
+            )
+            * layer.scaling[adapter]
+        )
+
+    return layer.get_delta_weight().to(device)
+
+
+def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
+    modules: Dict[str, peft.tuners.lora.LoraLayer] = {}
+
+    key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
+    for key in key_list:
+        try:
+            # pylint: disable=protected-access
+            _parent, target, _target_name = peft.utils._get_submodules(model.model, key)
+        except AttributeError:
+            continue
+
+        if isinstance(target, peft.tuners.lora.LoraLayer):
+            modules[key] = target
+
+    return modules
+
+
+def update_weights(
+    target: peft.tuners.lora.LoraLayer, new_weight: torch.Tensor, reinit: bool, device
+):
+    if reinit:
+        for adapter_name in target.lora_A:
+            target.reset_lora_parameters(adapter_name)
+        for adapter_name in target.lora_embedding_A:
+            target.reset_lora_parameters(adapter_name)
+
+    if isinstance(target, peft.tuners.lora.Linear4bit):
+        # This could be faster, but the quantization of Linear4bit weights occurs
+        # when the module is moved from cpu to gpu. Without meddling *too* deeply in
+        # PEFT's innards or maintaining a duplicate of that codepath, this is good
+        # enough for now.
+        target.weight.quant_state = None
+        target.weight.data = new_weight.cpu()
+        target.to(device)
+    elif isinstance(target, peft.tuners.lora.Linear8bitLt):
+        target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
+    else:
+        target.weight.data = new_weight.to(device)
+
+
+def merge_and_save(
+    model: peft.LoraModel,
+    model_src: str,
+    model_dst: str,
+    reinit: bool = False,
+    quantized: bool = False,
+    cpu_offload: bool = False,
+    actually_save: bool = True,
+):
+    modules = find_lora_modules(model)
+
+    if not quantized:
+        for module_name, target in modules.items():
+            update = target.get_delta_weight(target.active_adapter).detach()
+            target.weight.data += update
+
+            if reinit:
+                for adapter_name in target.lora_A:
+                    target.reset_lora_parameters(adapter_name)
+                for adapter_name in target.lora_embedding_A:
+                    target.reset_lora_parameters(adapter_name)
+        return
+
+    os.makedirs(model_dst, exist_ok=True)
+    shard_paths = sharded_paths(model_src, modules.keys())
+    out_shard_paths = {}
+
+    unique_shards = list(set(shard_paths.values()))
+    for shard_path in unique_shards:
+        out_tensors = {}
+        if shard_path.endswith(".safetensors"):
+            in_tensors = st.load_file(str(Path(model_src) / shard_path))
+        else:
+            in_tensors = torch.load(Path(model_src) / shard_path)
+            if "state_dict" in in_tensors:
+                in_tensors = in_tensors["state_dict"]
+
+        for module_name, target in modules.items():
+            key = module_name + ".weight"
+            if key not in shard_paths or shard_paths[key] != shard_path:
+                continue
+
+            orig_weight = in_tensors[key]
+            old_dev = target.weight.device
+            math_dev = "cpu" if cpu_offload else old_dev
+
+            delta_weight = lora_delta_weight(target, math_dev)
+            new_weight = orig_weight.to(math_dev) + delta_weight
+            del delta_weight
+
+            if actually_save:
+                out_tensors[key] = new_weight.half().cpu()
+
+            update_weights(target, new_weight, reinit=reinit, device=old_dev)
+        if actually_save:
+            out_shard_name = shard_path
+            if out_shard_name.startswith("pytorch_model"):
+                out_shard_name = (
+                    out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
+                    + ".safetensors"
+                )
+
+            for module_name in in_tensors:
+                if module_name not in out_tensors:
+                    out_tensors[module_name] = in_tensors[module_name].half()
+                out_shard_paths[module_name] = out_shard_name
+
+            shard_fn = str(Path(model_dst) / out_shard_name)
+            LOG.info(f"saving tensors to {shard_fn}")
+            st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
+
+        del in_tensors
+        del out_tensors
+        torch.cuda.empty_cache()
+
+    if actually_save and len(unique_shards) > 1:
+        with open(
+            str(Path(model_dst, "model.safetensors.index.json")), "w", encoding="utf-8"
+        ) as file:
+            json.dump({"metadata": {}, "weight_map": out_shard_paths}, file)
+
+
+def load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str):
+    modules = find_lora_modules(model)
+    shard_paths = sharded_paths(checkpoint_path, modules.keys())
+    unique_shards = list(set(shard_paths.values()))
+
+    for shard_path in unique_shards:
+        tensors = st.load_file(os.path.join(checkpoint_path, shard_path))
+
+        for module_name, target in modules.items():
+            key = module_name + ".weight"
+            if key not in shard_paths or shard_paths[key] != shard_path:
+                continue
+
+            new_weight = tensors[key]
+            update_weights(
+                target, new_weight, reinit=False, device=target.weight.device
+            )
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 32a7f0c99..ddc179f39 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -33,7 +33,9 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
         )
         peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
 
-        kwargs["model"].save_pretrained(peft_model_path)
+        kwargs["model"].save_pretrained(
+            peft_model_path, save_safetensors=args.save_safetensors
+        )
 
         return control
 
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 3d2e029c3..abb3154d2 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -126,6 +126,19 @@ def validate_config(cfg):
     if not cfg.load_in_8bit and cfg.adapter == "lora":
         LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
 
+    if cfg.relora_steps:
+        if cfg.adapter not in ("lora", "qlora"):
+            raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
+
+        if cfg.fsdp:
+            raise ValueError("fsdp not supported with ReLoRA")
+
+        if cfg.deepspeed:
+            raise ValueError("deepspeed not supported with ReLoRA")
+
+        if cfg.lr_scheduler == "one_cycle":
+            raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
+
     if cfg.trust_remote_code:
         LOG.warning(
             "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index c9c17fe33..c73b4a713 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -24,6 +24,7 @@ from transformers.trainer_pt_utils import (
     get_parameter_names,
 )
 
+from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
     GPUStatsCallback,
     SaveBetterTransformerModelCallback,
@@ -127,6 +128,14 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=1,
         metadata={"help": "the multiplier for the max len for packed sequences"},
     )
+    relora_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to reset for ReLoRA"},
+    )
+    relora_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -265,6 +274,39 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
         return self.lr_scheduler
 
 
+class ReLoRATrainer(AxolotlTrainer):
+    """
+    Trainer subclass that uses the ReLoRA learning rate scheduler
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lr_scheduler = None
+
+    def create_scheduler(
+        self,
+        num_training_steps: int,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+    ):
+        optimizer = self.optimizer if optimizer is None else optimizer
+        lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
+
+        if self.args.relora_steps:
+            warmup_steps = (
+                self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
+            )
+            self.lr_scheduler = ReLoRAScheduler(
+                optimizer,
+                lr_scheduler,
+                self.args.relora_steps,
+                warmup_steps,
+            )
+        else:
+            self.lr_scheduler = lr_scheduler
+
+        return self.lr_scheduler
+
+
 def add_position_ids(sample):
     sample["position_ids"] = torch.arange(len(sample["input_ids"]))
     return sample
@@ -517,6 +559,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         sample_packing=cfg.sample_packing if cfg.sample_packing else False,
         sample_packing_seq_len_multiplier=cfg.micro_batch_size,
+        relora_steps=cfg.relora_steps,
+        relora_warmup_steps=cfg.relora_warmup_steps,
         **training_arguments_kwargs,
     )
 
@@ -589,6 +633,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
 
     callbacks = []
     callbacks.append(GPUStatsCallback(cfg))
+
+    if cfg.relora_steps:
+        callbacks.append(ReLoRACallback(cfg))
+
     # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
@@ -633,11 +681,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
             num_proc=32,
         )
 
-    trainer_cls = (
-        OneCycleLRSchedulerTrainer
-        if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
-        else AxolotlTrainer
-    )
+    trainer_cls = AxolotlTrainer
+    if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
+        trainer_cls = OneCycleLRSchedulerTrainer
+    elif cfg.relora_steps:
+        trainer_cls = ReLoRATrainer
     trainer = trainer_cls(
         model=model,
         train_dataset=train_dataset,
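
For context when reviewing, a minimal config fragment wiring these options together might look like the sketch below. The values are placeholders rather than recommendations; the key names come from the README block above, and the comments mirror the checks added to `validate_config` and the default warmup in `ReLoRATrainer.create_scheduler`.

```yaml
# Illustrative values only; ReLoRA needs a lora or qlora adapter and no fsdp/deepspeed.
adapter: qlora
lr_scheduler: cosine        # any scheduler except one_cycle, which validate_config rejects
relora_steps: 150           # merge LoRA weights into the base model and reset every 150 steps
relora_warmup_steps: 10     # per-restart LR warmup; ReLoRATrainer defaults to 10 if unset
relora_cpu_offload: false   # true performs the merges on cpu for modest gpu memory savings
```

With `relora_steps` set, `setup_trainer` selects `ReLoRATrainer`, registers `ReLoRACallback`, and wraps the underlying schedule in `ReLoRAScheduler`, so the learning rate warms back up over `relora_warmup_steps` steps after every restart.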