relora: magnitude pruning of the optimizer (#1245)
* magnitude pruning of the optimizer * add alpaca chat template and fix relora patch * fix handling of lora adapter for relora * fix merge and save call * fixes for 8-bit lora merge * save intermediate checkpoint adapters * auto merge * fix eval check * handle relora annealing * fix anneal step logic * chore: lint * misc fix * fix types * Update tests/e2e/test_relora_llama.py * check for safetensors saved from relora
This commit is contained in:
@@ -126,6 +126,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
)
|
)
|
||||||
|
# Number of steps at the end of each ReLoRA cycle over which the learning
# rate is annealed; the trainer falls back to 1 when this is unset.
relora_anneal_steps: Optional[int] = field(
    default=None,
    metadata={"help": "how many anneal steps to take before reset for ReLoRA"},
)
|
||||||
bench_split: Optional[str] = field(
|
bench_split: Optional[str] = field(
|
||||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
)
|
)
|
||||||
@@ -478,10 +482,14 @@ class ReLoRATrainer(AxolotlTrainer):
|
|||||||
# Fall back to built-in defaults when the ReLoRA warmup/anneal step counts
# are unset (None) or zero.
warmup_steps = (
    self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
    self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
# Pass warmup/anneal by keyword: ReLoRAScheduler declares them in the order
# (relora_steps, warmup_steps, anneal_steps), so passing them positionally
# as (anneal_steps, warmup_steps) would silently swap the two values.
self.lr_scheduler = ReLoRAScheduler(
    optimizer,
    lr_scheduler,
    self.args.relora_steps,
    warmup_steps=warmup_steps,
    anneal_steps=anneal_steps,
)
|
||||||
else:
|
else:
|
||||||
@@ -893,6 +901,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
] = self.cfg.micro_batch_size
|
] = self.cfg.micro_batch_size
|
||||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||||
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
||||||
|
training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
|
||||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||||
training_arguments_kwargs
|
training_arguments_kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,14 +4,16 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os.path
|
import os.path
|
||||||
import shutil
|
import shutil
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Sequence
|
from typing import Dict, List, Sequence, Union
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import peft
|
import peft
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from torch.distributed.optim import ZeroRedundancyOptimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from torch.optim.optimizer import Optimizer
|
from torch.optim.optimizer import Optimizer
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -23,23 +25,50 @@ from transformers import (
|
|||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import barrier, is_main_process
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.relora")
|
LOG = logging.getLogger("axolotl.relora")
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def magnitude_pruning_(tensor, prune_ratio):
    """Zero, in place, the entries of ``tensor`` with the smallest magnitudes.

    The ``prune_ratio`` quantile of ``abs(tensor)`` is used as a threshold and
    every entry whose magnitude is not strictly greater than it is set to 0.

    Args:
        tensor: tensor to prune; modified in place.
        prune_ratio: quantile in ``[0, 1]`` used as the pruning threshold.
    """
    if tensor.numel() == 0:
        # torch.quantile rejects empty inputs; there is nothing to prune.
        return

    tensor_magnitude = torch.abs(tensor)
    # quantile is computed in float32 (the original casts before calling it),
    # then the threshold is cast back to the tensor's own dtype.
    threshold = torch.quantile(
        tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
    ).to(dtype=tensor.dtype)

    mask = tensor_magnitude > threshold
    tensor.mul_(mask.to(dtype=tensor.dtype))


def reset_optimizer(
    optimizer: torch.optim.Optimizer,
    *,
    reset_params: list[str],  # keys into the optimizer's per-parameter state
    optimizer_state_keys: list[str],
    prune_ratio: float = 0.9,
):
    """Partially reset optimizer state by magnitude-pruning it in place.

    Args:
        optimizer: optimizer whose state is pruned. For a
            ``ZeroRedundancyOptimizer`` the wrapped ``optimizer.optim`` state
            is used instead.
        reset_params: keys of the state entries to prune.
            NOTE(review): torch optimizers key ``state`` by parameter object;
            if parameter *names* are passed here no entry will match and
            nothing is pruned -- confirm against the caller.
        optimizer_state_keys: names of the state tensors to prune for each
            entry (e.g. ``exp_avg``).
        prune_ratio: quantile of magnitudes to zero out; defaults to 0.9.
    """
    pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
    n_zeros = 0
    n_total = 0

    optimizer_state = optimizer.state
    if isinstance(optimizer, ZeroRedundancyOptimizer):
        optimizer_state = optimizer.optim.state

    for param in reset_params:
        # .get() instead of indexing: the state mapping is a defaultdict, so
        # indexing an unknown key would insert a junk empty entry into it.
        param_state = optimizer_state.get(param)
        if not param_state:  # no state for this param, happens for ZeRo optimizer
            continue
        for key in optimizer_state_keys:
            # pruning fn has to be inplace to keep the same keys in the dict
            pruning_fn(param_state[key])
            n_total += param_state[key].numel()
            n_zeros += torch.sum(param_state[key] == 0).item()

    if reset_params and n_total == 0:
        LOG.warning("no optimizer state matched reset_params; nothing was pruned")

    # the epsilon keeps the percentage defined when nothing was visited
    _zeroed = n_zeros / (1e-7 + n_total) * 100
    LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
    LOG.info(f"absolute n of optimizer states zeroed: {n_zeros}")
|
||||||
|
|
||||||
|
|
||||||
class ReLoRACallback(TrainerCallback):
|
class ReLoRACallback(TrainerCallback):
|
||||||
@@ -97,6 +126,25 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
"relora",
|
"relora",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "adam" in args.optim.lower():
|
||||||
|
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
||||||
|
|
||||||
|
lora_params = [
|
||||||
|
n
|
||||||
|
for n, p in model.named_parameters()
|
||||||
|
if p.requires_grad and "lora_" in n
|
||||||
|
]
|
||||||
|
|
||||||
|
model.save_pretrained(
|
||||||
|
os.path.join(
|
||||||
|
args.output_dir,
|
||||||
|
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||||
|
"adapter",
|
||||||
|
),
|
||||||
|
safe_serialization=True,
|
||||||
|
)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
merge_and_save(
|
merge_and_save(
|
||||||
model,
|
model,
|
||||||
@@ -107,7 +155,11 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
actually_save=is_main_process(),
|
actually_save=is_main_process(),
|
||||||
cpu_offload=self.cpu_offload,
|
cpu_offload=self.cpu_offload,
|
||||||
)
|
)
|
||||||
reset_optimizer(optimizer)
|
reset_optimizer(
|
||||||
|
optimizer,
|
||||||
|
reset_params=lora_params,
|
||||||
|
optimizer_state_keys=optimizer_state_keys,
|
||||||
|
)
|
||||||
|
|
||||||
if self.quantized:
|
if self.quantized:
|
||||||
self.last_full_model = checkpoint_folder
|
self.last_full_model = checkpoint_folder
|
||||||
@@ -197,11 +249,13 @@ class ReLoRAScheduler(LRScheduler):
|
|||||||
inner_schedule: LRScheduler,
|
inner_schedule: LRScheduler,
|
||||||
relora_steps: int,
|
relora_steps: int,
|
||||||
warmup_steps: int,
|
warmup_steps: int,
|
||||||
|
anneal_steps: int = 1,
|
||||||
min_lr_scale: float = 0.001,
|
min_lr_scale: float = 0.001,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.inner_schedule = inner_schedule
|
self.inner_schedule = inner_schedule
|
||||||
self.relora_steps = relora_steps
|
self.relora_steps = relora_steps
|
||||||
self.warmup_steps = warmup_steps
|
self.warmup_steps = warmup_steps
|
||||||
|
self.anneal_steps = anneal_steps
|
||||||
self.min_lr_scale = min_lr_scale
|
self.min_lr_scale = min_lr_scale
|
||||||
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
||||||
|
|
||||||
@@ -210,10 +264,20 @@ class ReLoRAScheduler(LRScheduler):
|
|||||||
|
|
||||||
original = self.inner_schedule.get_lr()
|
original = self.inner_schedule.get_lr()
|
||||||
step = self.last_epoch
|
step = self.last_epoch
|
||||||
|
|
||||||
if step < self.relora_steps:
|
if step < self.relora_steps:
|
||||||
scale = 1
|
scale = 1
|
||||||
else:
|
else:
|
||||||
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
|
per_relora_progress = step % self.relora_steps
|
||||||
|
if per_relora_progress < self.warmup_steps:
|
||||||
|
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
|
||||||
|
elif per_relora_progress > (self.relora_steps - self.anneal_steps):
|
||||||
|
cycle_t = min(
|
||||||
|
1.0,
|
||||||
|
(self.relora_steps - per_relora_progress) / self.anneal_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cycle_t = 1
|
||||||
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
||||||
|
|
||||||
if isinstance(original, Sequence):
|
if isinstance(original, Sequence):
|
||||||
@@ -238,7 +302,11 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
|
|||||||
|
|
||||||
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
|
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
|
||||||
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
|
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
|
||||||
adapter = layer.active_adapter
|
adapter: Union[List[str], str] = layer.active_adapter
|
||||||
|
if isinstance(adapter, list):
|
||||||
|
if len(adapter) > 1:
|
||||||
|
raise ValueError("unhandled relora for multiple adapters")
|
||||||
|
adapter = adapter[0]
|
||||||
return (
|
return (
|
||||||
peft.utils.transpose(
|
peft.utils.transpose(
|
||||||
layer.lora_B[adapter].weight.detach().to(device)
|
layer.lora_B[adapter].weight.detach().to(device)
|
||||||
@@ -248,7 +316,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
|
|||||||
* layer.scaling[adapter]
|
* layer.scaling[adapter]
|
||||||
)
|
)
|
||||||
|
|
||||||
return layer.get_delta_weight().to(device)
|
raise ValueError("unhandled lora layer type")
|
||||||
|
|
||||||
|
|
||||||
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
|
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
|
||||||
@@ -273,9 +341,9 @@ def update_weights(
|
|||||||
):
|
):
|
||||||
if reinit:
|
if reinit:
|
||||||
for adapter_name in target.lora_A:
|
for adapter_name in target.lora_A:
|
||||||
target.reset_lora_parameters(adapter_name)
|
target.reset_lora_parameters(adapter_name, True)
|
||||||
for adapter_name in target.lora_embedding_A:
|
for adapter_name in target.lora_embedding_A:
|
||||||
target.reset_lora_parameters(adapter_name)
|
target.reset_lora_parameters(adapter_name, True)
|
||||||
|
|
||||||
if isinstance(target, peft.tuners.lora.Linear4bit):
|
if isinstance(target, peft.tuners.lora.Linear4bit):
|
||||||
# This could be faster, but the quantization of Linear4bit weights occurs
|
# This could be faster, but the quantization of Linear4bit weights occurs
|
||||||
@@ -286,7 +354,9 @@ def update_weights(
|
|||||||
target.weight.data = new_weight.cpu()
|
target.weight.data = new_weight.cpu()
|
||||||
target.to(device)
|
target.to(device)
|
||||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||||
target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
|
target.weight.data = (
|
||||||
|
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
target.weight.data = new_weight.to(device)
|
target.weight.data = new_weight.to(device)
|
||||||
|
|
||||||
@@ -304,14 +374,17 @@ def merge_and_save(
|
|||||||
|
|
||||||
if not quantized:
|
if not quantized:
|
||||||
for module_name, target in modules.items():
|
for module_name, target in modules.items():
|
||||||
update = target.get_delta_weight(target.active_adapter).detach()
|
active_adapter = target.active_adapter
|
||||||
|
if isinstance(active_adapter, list):
|
||||||
|
active_adapter = active_adapter[0]
|
||||||
|
update = target.get_delta_weight(active_adapter).detach()
|
||||||
target.weight.data += update
|
target.weight.data += update
|
||||||
|
|
||||||
if reinit:
|
if reinit:
|
||||||
for adapter_name in target.lora_A:
|
for adapter_name in target.lora_A:
|
||||||
target.reset_lora_parameters(adapter_name)
|
target.reset_lora_parameters(adapter_name, True)
|
||||||
for adapter_name in target.lora_embedding_A:
|
for adapter_name in target.lora_embedding_A:
|
||||||
target.reset_lora_parameters(adapter_name)
|
target.reset_lora_parameters(adapter_name, True)
|
||||||
return
|
return
|
||||||
|
|
||||||
os.makedirs(model_dst, exist_ok=True)
|
os.makedirs(model_dst, exist_ok=True)
|
||||||
@@ -363,6 +436,7 @@ def merge_and_save(
|
|||||||
LOG.info(f"saving tensors to {shard_fn}")
|
LOG.info(f"saving tensors to {shard_fn}")
|
||||||
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
|
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
|
||||||
|
|
||||||
|
barrier()
|
||||||
del in_tensors
|
del in_tensors
|
||||||
del out_tensors
|
del out_tensors
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|||||||
33
src/axolotl/prompt_strategies/instruct.py
Normal file
33
src/axolotl/prompt_strategies/instruct.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||||
|
from axolotl.prompters import ShareGPTPrompterV2
|
||||||
|
|
||||||
|
|
||||||
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """Build the instruct ShareGPT tokenizing strategy.

    The ``conversation`` entry of ``ds_cfg`` (if ``ds_cfg`` is given and has
    one) is forwarded to the prompter; otherwise ``None`` is passed.
    ``cfg.train_on_inputs`` and ``cfg.sequence_len`` are forwarded to the
    strategy unchanged.
    """
    conversation = None
    if ds_cfg and "conversation" in ds_cfg:
        conversation = ds_cfg["conversation"]

    # pylint: disable=duplicate-code
    prompter = ShareGPTPrompterV2(conversation=conversation)
    return InstructShareGPTPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
    """
    ShareGPT strategy that turns an instruction/output row into a two-turn thread
    """

    def get_conversation_thread(self, prompt):
        """Return a human turn from ``prompt["instruction"]`` followed by a
        gpt turn from ``prompt["output"]``."""
        human_turn = {"from": "human", "value": prompt["instruction"]}
        gpt_turn = {"from": "gpt", "value": prompt["output"]}
        return [human_turn, gpt_turn]
|
||||||
@@ -19,6 +19,7 @@ def chat_templates(user_choice: str):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
templates = {
|
templates = {
|
||||||
|
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||||
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -447,7 +447,11 @@ def validate_config(cfg):
|
|||||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
|
if (
|
||||||
|
cfg.val_set_size == 0
|
||||||
|
and (cfg.eval_steps or cfg.evaluation_strategy)
|
||||||
|
and not cfg.test_datasets
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
+ "|".join(
|
+ "|".join(
|
||||||
sorted(
|
sorted(
|
||||||
[
|
[
|
||||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
|
f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
|
||||||
for d in cfg_datasets
|
for d in cfg_datasets
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,7 +8,13 @@ import addict
|
|||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
|
from peft import (
|
||||||
|
LoftQConfig,
|
||||||
|
PeftConfig,
|
||||||
|
PeftModel,
|
||||||
|
PeftModelForCausalLM,
|
||||||
|
prepare_model_for_kbit_training,
|
||||||
|
)
|
||||||
from peft.tuners.lora import QuantLinear
|
from peft.tuners.lora import QuantLinear
|
||||||
from transformers import ( # noqa: F401
|
from transformers import ( # noqa: F401
|
||||||
AddedToken,
|
AddedToken,
|
||||||
@@ -628,6 +634,9 @@ def load_model(
|
|||||||
LOG.exception(err)
|
LOG.exception(err)
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
|
if isinstance(model, (PeftModel, PeftModelForCausalLM)):
|
||||||
|
model = model.merge_and_unload()
|
||||||
|
|
||||||
embeddings_len = (
|
embeddings_len = (
|
||||||
math.ceil(len(tokenizer) / 32) * 32
|
math.ceil(len(tokenizer) / 32) * 32
|
||||||
if cfg.resize_token_embeddings_to_32x
|
if cfg.resize_token_embeddings_to_32x
|
||||||
@@ -782,7 +791,7 @@ def load_adapter(model, cfg, adapter, inference=False):
|
|||||||
|
|
||||||
def load_llama_adapter(model, cfg):
|
def load_llama_adapter(model, cfg):
|
||||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
|
from peft import AdaptionPromptConfig, get_peft_model
|
||||||
|
|
||||||
peft_config = AdaptionPromptConfig(
|
peft_config = AdaptionPromptConfig(
|
||||||
adapter_layers=cfg.peft_adapter.layers, # layers (L)
|
adapter_layers=cfg.peft_adapter.layers, # layers (L)
|
||||||
@@ -828,7 +837,7 @@ def find_all_linear_names(model):
|
|||||||
def load_lora(model, cfg, inference=False, config_only=False):
|
def load_lora(model, cfg, inference=False, config_only=False):
|
||||||
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
|
||||||
|
|
||||||
from peft import LoraConfig, PeftModel, get_peft_model
|
from peft import LoraConfig, get_peft_model
|
||||||
|
|
||||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
@@ -63,6 +61,7 @@ class TestMistral(unittest.TestCase):
|
|||||||
"max_steps": 20,
|
"max_steps": 20,
|
||||||
"save_steps": 10,
|
"save_steps": 10,
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
@@ -103,12 +102,9 @@ class TestMistral(unittest.TestCase):
|
|||||||
"max_steps": 20,
|
"max_steps": 20,
|
||||||
"save_steps": 10,
|
"save_steps": 10,
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if is_torch_bf16_gpu_available():
|
|
||||||
cfg.bf16 = True
|
|
||||||
else:
|
|
||||||
cfg.fp16 = True
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
68
tests/e2e/test_relora_llama.py
Normal file
68
tests/e2e/test_relora_llama.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for relora llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestReLoraLlama(unittest.TestCase):
    """
    End-to-end ReLoRA training test on a Llama model
    """

    @with_temp_dir
    def test_relora(self, temp_dir):
        """Train with ReLoRA enabled and check a safetensors model is written."""
        # pylint: disable=duplicate-code
        settings = {
            "base_model": "JackFram/llama-68m",
            "tokenizer_type": "LlamaTokenizer",
            "sequence_len": 1024,
            "load_in_8bit": True,
            "adapter": "lora",
            "lora_r": 32,
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lora_target_modules": ["q_proj", "v_proj"],
            "relora_steps": 25,
            "relora_warmup_steps": 5,
            "relora_anneal_steps": 5,
            "relora_cpu_offload": True,
            "val_set_size": 0.0,
            "special_tokens": {},
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                },
            ],
            "warmup_steps": 15,
            "num_epochs": 2,
            "micro_batch_size": 4,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch",
            "lr_scheduler": "cosine",
        }
        cfg = DictDefault(settings)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

        # the run must leave the final weights in the output directory
        saved_model = Path(temp_dir) / "model.safetensors"
        assert saved_model.exists()
|
||||||
Reference in New Issue
Block a user