relora: magnitude pruning of the optimizer (#1245)
* magnitude pruning of the optimizer
* add alpaca chat template and fix relora patch
* fix handling of lora adapter for relora
* fix merge and save call
* fixes for 8-bit lora merge
* save intermediate checkpoint adapters
* auto merge
* fix eval check
* handle relora annealing
* fix anneal step logic
* chore: lint
* misc fix
* fix types
* Update tests/e2e/test_relora_llama.py
* check for safetensors saved from relora
src/axolotl/core/trainer_builder.py

@@ -126,6 +126,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
     )
+    relora_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many anneal steps to take after reset for ReLoRA"},
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
@@ -478,10 +482,14 @@ class ReLoRATrainer(AxolotlTrainer):
             warmup_steps = (
                 self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
             )
+            anneal_steps = (
+                self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
+            )
             self.lr_scheduler = ReLoRAScheduler(
                 optimizer,
                 lr_scheduler,
                 self.args.relora_steps,
+                anneal_steps,
                 warmup_steps,
             )
         else:
@@ -893,6 +901,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ] = self.cfg.micro_batch_size
         training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
+        training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
src/axolotl/monkeypatch/relora.py

@@ -4,14 +4,16 @@ import json
 import logging
 import os.path
 import shutil
+from functools import partial
 from pathlib import Path
-from typing import Dict, List, Sequence
+from typing import Dict, List, Sequence, Union
 
 import bitsandbytes as bnb
 import peft
 import safetensors.torch as st
 import torch
 from huggingface_hub import snapshot_download
+from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.optim.lr_scheduler import LRScheduler
 from torch.optim.optimizer import Optimizer
 from transformers import (
@@ -23,23 +25,50 @@ from transformers import (
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_main_process
+from axolotl.utils.distributed import barrier, is_main_process
 
 LOG = logging.getLogger("axolotl.relora")
 
 
-def reset_optimizer(optimizer: torch.optim.Optimizer):
-    for group in optimizer.param_groups:
-        for param in group["params"]:
-            param_state = optimizer.state[param]
-            for key in param_state:
-                if "qmap" in key:
-                    continue
-
-                if key == "step" and isinstance(param_state[key], int):
-                    param_state[key] = 0
-                else:
-                    param_state[key] = torch.zeros_like(param_state[key])
+@torch.no_grad()
+def magnitude_pruning_(tensor, prune_ratio):
+    tensor_magnitude = torch.abs(tensor)
+    threshold = torch.quantile(
+        tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
+    ).to(dtype=tensor.dtype)
+
+    mask = tensor_magnitude > threshold
+    tensor.mul_(mask.to(dtype=tensor.dtype))
+
+
+def reset_optimizer(
+    optimizer: torch.optim.Optimizer,
+    *,
+    reset_params: list[str],  # where str is the key to a torch.nn.Parameter
+    optimizer_state_keys: list[str],
+):
+    pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
+    n_zeros = 0
+    n_total = 0
+
+    optimizer_state = optimizer.state
+    if isinstance(optimizer, ZeroRedundancyOptimizer):
+        optimizer_state = optimizer.optim.state
+
+    for param in reset_params:
+        param_state = optimizer_state[param]
+        if len(param_state) == 0:  # no state for this param, happens for ZeRo optimizer
+            continue
+        for key in optimizer_state_keys:
+            pruning_fn(
+                param_state[key]
+            )  # pruning fn has to be inplace to keep the same keys in the dict
+            n_total += param_state[key].numel()
+            n_zeros += torch.sum(param_state[key] == 0).item()
+
+    _zeroed = n_zeros / (1e-7 + n_total) * 100
+    LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
+    LOG.info(f"absolute n of optimizer states zeroed: {n_zeros}")
 
 
 class ReLoRACallback(TrainerCallback):
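The new magnitude_pruning_ helper replaces the old zero-everything reset: instead of wiping the Adam moments, it keeps only the largest 10% of entries by magnitude. A quick standalone sanity check of that behavior (illustrative sketch, not part of the commit):

    import torch


    @torch.no_grad()
    def magnitude_pruning_(tensor: torch.Tensor, prune_ratio: float) -> None:
        # same logic as the helper above: zero the smallest-magnitude
        # `prune_ratio` fraction of entries, in place
        tensor_magnitude = torch.abs(tensor)
        threshold = torch.quantile(
            tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
        ).to(dtype=tensor.dtype)
        tensor.mul_((tensor_magnitude > threshold).to(dtype=tensor.dtype))


    exp_avg_sq = torch.rand(1024, 1024)  # stand-in for an Adam second-moment tensor
    magnitude_pruning_(exp_avg_sq, prune_ratio=0.9)
    print(f"fraction zeroed: {(exp_avg_sq == 0).float().mean().item():.3f}")  # ~0.900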
@@ -97,6 +126,25 @@ class ReLoRACallback(TrainerCallback):
                 "relora",
             )
 
+            if "adam" in args.optim.lower():
+                optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
+            else:
+                raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
+
+            lora_params = [
+                n
+                for n, p in model.named_parameters()
+                if p.requires_grad and "lora_" in n
+            ]
+
+            model.save_pretrained(
+                os.path.join(
+                    args.output_dir,
+                    f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+                    "adapter",
+                ),
+                safe_serialization=True,
+            )
             with torch.no_grad():
                 merge_and_save(
                     model,
@@ -107,7 +155,11 @@ class ReLoRACallback(TrainerCallback):
                     actually_save=is_main_process(),
                     cpu_offload=self.cpu_offload,
                 )
-            reset_optimizer(optimizer)
+            reset_optimizer(
+                optimizer,
+                reset_params=lora_params,
+                optimizer_state_keys=optimizer_state_keys,
+            )
 
         if self.quantized:
             self.last_full_model = checkpoint_folder
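The callback only knows how to prune Adam-family state (exp_avg, exp_avg_sq) and only targets trainable parameters whose names contain lora_. A toy sketch of where those two inputs to reset_optimizer come from, using a hypothetical module with lora_A/lora_B attributes rather than a real peft model:

    import torch
    from torch import nn


    class TinyLoraLayer(nn.Module):
        """Hypothetical stand-in: frozen base weight plus trainable lora_A/lora_B."""

        def __init__(self, dim: int = 8, rank: int = 2):
            super().__init__()
            self.base = nn.Linear(dim, dim)
            self.base.weight.requires_grad_(False)
            self.lora_A = nn.Parameter(torch.randn(rank, dim))
            self.lora_B = nn.Parameter(torch.zeros(dim, rank))

        def forward(self, x):
            return self.base(x) + x @ self.lora_A.t() @ self.lora_B.t()


    model = TinyLoraLayer()
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=1e-3
    )
    model(torch.randn(4, 8)).sum().backward()
    optimizer.step()

    # the names the callback collects for reset_params:
    lora_params = [
        n for n, p in model.named_parameters() if p.requires_grad and "lora_" in n
    ]
    print(lora_params)  # ['lora_A', 'lora_B']

    # Adam keeps these per-parameter tensors; ReLoRA prunes the two moments
    print(sorted(optimizer.state[model.lora_B].keys()))  # ['exp_avg', 'exp_avg_sq', 'step']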
@@ -197,11 +249,13 @@ class ReLoRAScheduler(LRScheduler):
         inner_schedule: LRScheduler,
         relora_steps: int,
         warmup_steps: int,
+        anneal_steps: int = 1,
         min_lr_scale: float = 0.001,
     ) -> None:
         self.inner_schedule = inner_schedule
         self.relora_steps = relora_steps
         self.warmup_steps = warmup_steps
+        self.anneal_steps = anneal_steps
         self.min_lr_scale = min_lr_scale
         super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
 
@@ -210,10 +264,20 @@ class ReLoRAScheduler(LRScheduler):
 
         original = self.inner_schedule.get_lr()
         step = self.last_epoch
 
         if step < self.relora_steps:
             scale = 1
         else:
-            cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
+            per_relora_progress = step % self.relora_steps
+            if per_relora_progress < self.warmup_steps:
+                cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
+            elif per_relora_progress > (self.relora_steps - self.anneal_steps):
+                cycle_t = min(
+                    1.0,
+                    (self.relora_steps - per_relora_progress) / self.anneal_steps,
+                )
+            else:
+                cycle_t = 1
             scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
 
         if isinstance(original, Sequence):
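With anneal_steps wired in, each ReLoRA cycle ramps the learning rate up over warmup_steps, holds it, then decays it over the last anneal_steps of the cycle. A standalone sketch of the resulting scale factor, using the same formulas with illustrative constants:

    relora_steps, warmup_steps, anneal_steps, min_lr_scale = 100, 10, 20, 0.001


    def relora_scale(step: int) -> float:
        # mirrors ReLoRAScheduler.get_lr once training is past the first cycle
        per_relora_progress = step % relora_steps
        if per_relora_progress < warmup_steps:
            cycle_t = min(1.0, per_relora_progress / warmup_steps)
        elif per_relora_progress > (relora_steps - anneal_steps):
            cycle_t = min(1.0, (relora_steps - per_relora_progress) / anneal_steps)
        else:
            cycle_t = 1
        return cycle_t * (1 - min_lr_scale) + min_lr_scale


    for step in (100, 105, 110, 150, 185, 199):
        print(step, round(relora_scale(step), 3))
    # 100 0.001 | 105 0.5 | 110 1.0 | 150 1.0 | 185 0.75 | 199 0.051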
@@ -238,7 +302,11 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
 
 
 def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
     if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
-        adapter = layer.active_adapter
+        adapter: Union[List[str], str] = layer.active_adapter
+        if isinstance(adapter, list):
+            if len(adapter) > 1:
+                raise ValueError("unhandled relora for multiple adapters")
+            adapter = adapter[0]
         return (
             peft.utils.transpose(
                 layer.lora_B[adapter].weight.detach().to(device)
@@ -248,7 +316,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
             * layer.scaling[adapter]
         )
 
-    return layer.get_delta_weight().to(device)
+    raise ValueError("unhandled lora layer type")
 
 
 def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
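For quantized layers the merged update is the standard LoRA product scaled by lora_alpha / r; peft.utils.transpose only transposes it when the layer is fan_in_fan_out. A toy version with plain tensors (shapes and scaling value are illustrative):

    import torch

    out_features, in_features, rank = 16, 32, 4
    lora_A = torch.randn(rank, in_features)   # stands in for layer.lora_A[adapter].weight
    lora_B = torch.randn(out_features, rank)  # stands in for layer.lora_B[adapter].weight
    scaling = 2.0                             # lora_alpha / r

    delta_w = (lora_B @ lora_A) * scaling     # the update folded into the base weight
    print(delta_w.shape)                      # torch.Size([16, 32])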
@@ -273,9 +341,9 @@ def update_weights(
 ):
     if reinit:
         for adapter_name in target.lora_A:
-            target.reset_lora_parameters(adapter_name)
+            target.reset_lora_parameters(adapter_name, True)
         for adapter_name in target.lora_embedding_A:
-            target.reset_lora_parameters(adapter_name)
+            target.reset_lora_parameters(adapter_name, True)
 
     if isinstance(target, peft.tuners.lora.Linear4bit):
         # This could be faster, but the quantization of Linear4bit weights occurs
@@ -286,7 +354,9 @@ def update_weights(
         target.weight.data = new_weight.cpu()
         target.to(device)
     elif isinstance(target, peft.tuners.lora.Linear8bitLt):
-        target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
+        target.weight.data = (
+            bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
+        )
     else:
         target.weight.data = new_weight.to(device)
 
@@ -304,14 +374,17 @@ def merge_and_save(
 
     if not quantized:
         for module_name, target in modules.items():
-            update = target.get_delta_weight(target.active_adapter).detach()
+            active_adapter = target.active_adapter
+            if isinstance(active_adapter, list):
+                active_adapter = active_adapter[0]
+            update = target.get_delta_weight(active_adapter).detach()
             target.weight.data += update
 
             if reinit:
                 for adapter_name in target.lora_A:
-                    target.reset_lora_parameters(adapter_name)
+                    target.reset_lora_parameters(adapter_name, True)
                 for adapter_name in target.lora_embedding_A:
-                    target.reset_lora_parameters(adapter_name)
+                    target.reset_lora_parameters(adapter_name, True)
         return
 
     os.makedirs(model_dst, exist_ok=True)
@@ -363,6 +436,7 @@ def merge_and_save(
         LOG.info(f"saving tensors to {shard_fn}")
         st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
 
+    barrier()
     del in_tensors
     del out_tensors
     torch.cuda.empty_cache()
src/axolotl/prompt_strategies/instruct.py (new file, 33 additions)

@@ -0,0 +1,33 @@
+"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
+from typing import Any, Dict, Optional
+
+from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+from axolotl.prompters import ShareGPTPrompterV2
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    conversation = (
+        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
+    )
+    strategy = InstructShareGPTPromptTokenizingStrategy(
+        # pylint: disable=duplicate-code
+        ShareGPTPrompterV2(
+            conversation=conversation,
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+    return strategy
+
+
+class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+    """
+    basic sharegpt strategy to grab conversations from the sample row
+    """
+
+    def get_conversation_thread(self, prompt):
+        return [
+            {"from": "human", "value": prompt["instruction"]},
+            {"from": "gpt", "value": prompt["output"]},
+        ]
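The new strategy simply reshapes an instruction/output row into a two-turn ShareGPT conversation before the usual ShareGPT tokenization runs. The mapping on a sample row (plain Python, no tokenizer required):

    row = {
        "instruction": "Summarize the commit in one sentence.",
        "output": "It prunes optimizer state at each ReLoRA reset instead of zeroing it.",
    }

    # what InstructShareGPTPromptTokenizingStrategy.get_conversation_thread returns
    conversation = [
        {"from": "human", "value": row["instruction"]},
        {"from": "gpt", "value": row["output"]},
    ]
    print(conversation)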
src/axolotl/utils/chat_templates.py

@@ -19,6 +19,7 @@ def chat_templates(user_choice: str):
     """
 
     templates = {
+        "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
         "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
     }
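The templates are plain Jinja2, so the new alpaca entry can be exercised without loading a tokenizer; transformers would normally render it through tokenizer.apply_chat_template. A minimal sketch with jinja2 directly (the eos_token value is illustrative):

    from jinja2 import Template

    alpaca = (
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}"
        "{{ '### Instruction: ' + message['content'] + '\\n\\n' }}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ '### Response: ' + message['content'] + eos_token }}"
        "{% endif %}"
        "{% endfor %}"
    )

    messages = [
        {"role": "user", "content": "Name a prime number."},
        {"role": "assistant", "content": "7"},
    ]
    print(Template(alpaca).render(messages=messages, eos_token="</s>"))
    # ### Instruction: Name a prime number.
    #
    # ### Response: 7</s>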
src/axolotl/utils/config.py

@@ -447,7 +447,11 @@ def validate_config(cfg):
             "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
         )
 
-    if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
+    if (
+        cfg.val_set_size == 0
+        and (cfg.eval_steps or cfg.evaluation_strategy)
+        and not cfg.test_datasets
+    ):
         raise ValueError(
             "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
         )
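The relaxed check lets a run with val_set_size == 0 keep eval_steps as long as test_datasets supplies the evaluation data. A toy predicate mirroring the new condition (SimpleNamespace stands in for the axolotl cfg object):

    from types import SimpleNamespace


    def eval_config_conflict(cfg) -> bool:
        # mirrors the updated validate_config condition
        return bool(
            cfg.val_set_size == 0
            and (cfg.eval_steps or cfg.evaluation_strategy)
            and not cfg.test_datasets
        )


    common = {"val_set_size": 0, "eval_steps": 100, "evaluation_strategy": None}
    bad = SimpleNamespace(**common, test_datasets=None)
    ok = SimpleNamespace(**common, test_datasets=[{"path": "./test.jsonl"}])
    print(eval_config_conflict(bad), eval_config_conflict(ok))  # True False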
src/axolotl/utils/data.py

@@ -140,7 +140,7 @@ def load_tokenized_prepared_datasets(
             + "|".join(
                 sorted(
                     [
-                        f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
+                        f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
                         for d in cfg_datasets
                     ]
                 )
             )
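Appending d.split to the per-dataset spec string means two configs that differ only in their split no longer collide on the same prepared-dataset hash. A sketch of the effect (field values illustrative):

    def ds_spec(d: dict) -> str:
        # mirrors the updated f-string: split is now part of the dataset identity
        return f"{d['path']}:{d['type']}:{d.get('shards')}:{d.get('conversation')}{d['split']}"


    train = ds_spec({"path": "tatsu-lab/alpaca", "type": "alpaca", "split": "train"})
    test = ds_spec({"path": "tatsu-lab/alpaca", "type": "alpaca", "split": "test"})
    print(train == test)  # False -- the prepared-dataset cache hash now differs per split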
src/axolotl/utils/models.py

@@ -8,7 +8,13 @@ import addict
 import bitsandbytes as bnb
 import torch
 import transformers
-from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
+from peft import (
+    LoftQConfig,
+    PeftConfig,
+    PeftModel,
+    PeftModelForCausalLM,
+    prepare_model_for_kbit_training,
+)
 from peft.tuners.lora import QuantLinear
 from transformers import (  # noqa: F401
     AddedToken,
@@ -628,6 +634,9 @@ def load_model(
         LOG.exception(err)
         raise err
 
+    if isinstance(model, (PeftModel, PeftModelForCausalLM)):
+        model = model.merge_and_unload()
+
     embeddings_len = (
         math.ceil(len(tokenizer) / 32) * 32
         if cfg.resize_token_embeddings_to_32x
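merge_and_unload folds any loaded adapter back into the base weights and returns a plain transformers model, which is what this branch relies on when a PeftModel is loaded. A hedged sketch of the standard peft flow (paths are placeholders):

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("base-model-path")     # placeholder path
    peft_model = PeftModel.from_pretrained(base, "lora-adapter-path")  # placeholder path
    merged = peft_model.merge_and_unload()  # returns the base model with LoRA folded in
    merged.save_pretrained("merged-model-path")                        # placeholder path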
@@ -782,7 +791,7 @@ def load_adapter(model, cfg, adapter, inference=False):
 
 def load_llama_adapter(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
-    from peft import AdaptionPromptConfig, PeftModel, get_peft_model
+    from peft import AdaptionPromptConfig, get_peft_model
 
     peft_config = AdaptionPromptConfig(
         adapter_layers=cfg.peft_adapter.layers,  # layers (L)
@@ -828,7 +837,7 @@ def find_all_linear_names(model):
 
 def load_lora(model, cfg, inference=False, config_only=False):
     # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
 
-    from peft import LoraConfig, PeftModel, get_peft_model
+    from peft import LoraConfig, get_peft_model
 
     lora_target_modules = list(cfg.lora_target_modules or [])