diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 5e52dafab..85c0dc7db 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -126,6 +126,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
     )
+    relora_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many anneal steps to take before reset for ReLoRA"},
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
@@ -478,10 +482,14 @@ class ReLoRATrainer(AxolotlTrainer):
             warmup_steps = (
                 self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
             )
+            anneal_steps = (
+                self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
+            )
             self.lr_scheduler = ReLoRAScheduler(
                 optimizer,
                 lr_scheduler,
                 self.args.relora_steps,
                 warmup_steps,
+                anneal_steps,
             )
         else:
@@ -893,6 +901,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.micro_batch_size
         training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
+        training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py
index 9dac77e18..2d396e080 100644
--- a/src/axolotl/monkeypatch/relora.py
+++ b/src/axolotl/monkeypatch/relora.py
@@ -4,14 +4,16 @@ import json
 import logging
 import os.path
 import shutil
+from functools import partial
 from pathlib import Path
-from typing import Dict, List, Sequence
+from typing import Dict, List, Sequence, Union
 
 import bitsandbytes as bnb
 import peft
 import safetensors.torch as st
 import torch
 from huggingface_hub import snapshot_download
+from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.optim.lr_scheduler import LRScheduler
 from torch.optim.optimizer import Optimizer
 from transformers import (
@@ -23,23 +25,50 @@ from transformers import (
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_main_process
+from axolotl.utils.distributed import barrier, is_main_process
 
 LOG = logging.getLogger("axolotl.relora")
 
 
-def reset_optimizer(optimizer: torch.optim.Optimizer):
-    for group in optimizer.param_groups:
-        for param in group["params"]:
-            param_state = optimizer.state[param]
-            for key in param_state:
-                if "qmap" in key:
-                    continue
+@torch.no_grad()
+def magnitude_pruning_(tensor, prune_ratio):
+    tensor_magnitude = torch.abs(tensor)
+    threshold = torch.quantile(
+        tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
+    ).to(dtype=tensor.dtype)
 
-                if key == "step" and isinstance(param_state[key], int):
-                    param_state[key] = 0
-                else:
-                    param_state[key] = torch.zeros_like(param_state[key])
+    mask = tensor_magnitude > threshold
+    tensor.mul_(mask.to(dtype=tensor.dtype))
+
+
+def reset_optimizer(
+    optimizer: torch.optim.Optimizer,
+    *,
+    reset_params: list[str],  # where str is the name of a torch.nn.Parameter
+    optimizer_state_keys: list[str],
+):
+    pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
+    n_zeros = 0
+    n_total = 0
+
+    optimizer_state = optimizer.state
+    if isinstance(optimizer, ZeroRedundancyOptimizer):
+        optimizer_state = optimizer.optim.state
+
+    for param in reset_params:
+        param_state = optimizer_state[param]
+        if len(param_state) == 0:  # no state for this param, happens with the ZeRO optimizer
+            continue
+        for key in optimizer_state_keys:
+            pruning_fn(
+                param_state[key]
+            )  # pruning fn has to be in-place to keep the same keys in the dict
+            n_total += param_state[key].numel()
+            n_zeros += torch.sum(param_state[key] == 0).item()
+
+    _zeroed = n_zeros / (1e-7 + n_total) * 100
+    LOG.info(f"percent of optimizer states zeroed: {_zeroed:.2f}")
+    LOG.info(f"absolute number of optimizer states zeroed: {n_zeros}")
 
 
 class ReLoRACallback(TrainerCallback):
@@ -97,6 +126,25 @@ class ReLoRACallback(TrainerCallback):
                 "relora",
             )
 
+            if "adam" in args.optim.lower():
+                optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
+            else:
+                raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
+
+            lora_params = [
+                n
+                for n, p in model.named_parameters()
+                if p.requires_grad and "lora_" in n
+            ]
+
+            model.save_pretrained(
+                os.path.join(
+                    args.output_dir,
+                    f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+                    "adapter",
+                ),
+                safe_serialization=True,
+            )
             with torch.no_grad():
                 merge_and_save(
                     model,
@@ -107,7 +155,11 @@ class ReLoRACallback(TrainerCallback):
                     actually_save=is_main_process(),
                     cpu_offload=self.cpu_offload,
                 )
-            reset_optimizer(optimizer)
+            reset_optimizer(
+                optimizer,
+                reset_params=lora_params,
+                optimizer_state_keys=optimizer_state_keys,
+            )
 
             if self.quantized:
                 self.last_full_model = checkpoint_folder
@@ -197,11 +249,13 @@ class ReLoRAScheduler(LRScheduler):
         inner_schedule: LRScheduler,
         relora_steps: int,
         warmup_steps: int,
+        anneal_steps: int = 1,
         min_lr_scale: float = 0.001,
     ) -> None:
         self.inner_schedule = inner_schedule
         self.relora_steps = relora_steps
         self.warmup_steps = warmup_steps
+        self.anneal_steps = anneal_steps
         self.min_lr_scale = min_lr_scale
         super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
 
@@ -210,10 +264,20 @@ class ReLoRAScheduler(LRScheduler):
         original = self.inner_schedule.get_lr()
         step = self.last_epoch
+
         if step < self.relora_steps:
             scale = 1
         else:
-            cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
+            per_relora_progress = step % self.relora_steps
+            if per_relora_progress < self.warmup_steps:
+                cycle_t = min(1.0, per_relora_progress / self.warmup_steps)
+            elif per_relora_progress > (self.relora_steps - self.anneal_steps):
+                cycle_t = min(
+                    1.0,
+                    (self.relora_steps - per_relora_progress) / self.anneal_steps,
+                )
+            else:
+                cycle_t = 1
             scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
 
         if isinstance(original, Sequence):
@@ -238,7 +302,11 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
 
 def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
     if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
-        adapter = layer.active_adapter
+        adapter: Union[List[str], str] = layer.active_adapter
+        if isinstance(adapter, list):
+            if len(adapter) > 1:
+                raise ValueError("unhandled relora for multiple adapters")
+            adapter = adapter[0]
         return (
             peft.utils.transpose(
                 layer.lora_B[adapter].weight.detach().to(device)
@@ -248,7 +316,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
             * layer.scaling[adapter]
         )
 
-    return layer.get_delta_weight().to(device)
+    raise ValueError("unhandled lora layer type")
 
 
 def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
@@ -273,9 +341,9 @@ def update_weights(
 ):
     if reinit:
         for adapter_name in target.lora_A:
-            target.reset_lora_parameters(adapter_name)
+            target.reset_lora_parameters(adapter_name, True)
         for adapter_name in target.lora_embedding_A:
-            target.reset_lora_parameters(adapter_name)
+            target.reset_lora_parameters(adapter_name, True)
 
     if isinstance(target, peft.tuners.lora.Linear4bit):
         # This could be faster, but the quantization of Linear4bit weights occurs
@@ -286,7 +354,9 @@ def update_weights(
         target.weight.data = new_weight.cpu()
         target.to(device)
     elif isinstance(target, peft.tuners.lora.Linear8bitLt):
-        target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
+        target.weight.data = (
+            bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
+        )
     else:
         target.weight.data = new_weight.to(device)
 
@@ -304,14 +374,17 @@ def merge_and_save(
 
     if not quantized:
         for module_name, target in modules.items():
-            update = target.get_delta_weight(target.active_adapter).detach()
+            active_adapter = target.active_adapter
+            if isinstance(active_adapter, list):
+                active_adapter = active_adapter[0]
+            update = target.get_delta_weight(active_adapter).detach()
             target.weight.data += update
 
             if reinit:
                 for adapter_name in target.lora_A:
-                    target.reset_lora_parameters(adapter_name)
+                    target.reset_lora_parameters(adapter_name, True)
                 for adapter_name in target.lora_embedding_A:
-                    target.reset_lora_parameters(adapter_name)
+                    target.reset_lora_parameters(adapter_name, True)
         return
 
     os.makedirs(model_dst, exist_ok=True)
@@ -363,6 +436,7 @@ def merge_and_save(
         LOG.info(f"saving tensors to {shard_fn}")
         st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
 
+    barrier()
     del in_tensors
     del out_tensors
     torch.cuda.empty_cache()
diff --git a/src/axolotl/prompt_strategies/instruct.py b/src/axolotl/prompt_strategies/instruct.py
new file mode 100644
index 000000000..3d6367489
--- /dev/null
+++ b/src/axolotl/prompt_strategies/instruct.py
@@ -0,0 +1,33 @@
+"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
+from typing import Any, Dict, Optional
+
+from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+from axolotl.prompters import ShareGPTPrompterV2
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    conversation = (
+        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
+    )
+    strategy = InstructShareGPTPromptTokenizingStrategy(
+        # pylint: disable=duplicate-code
+        ShareGPTPrompterV2(
+            conversation=conversation,
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+    return strategy
+
+
+class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+    """
+    basic sharegpt strategy to grab conversations from the sample row
+    """
+
+    def get_conversation_thread(self, prompt):
+        return [
+            {"from": "human", "value": prompt["instruction"]},
+            {"from": "gpt", "value": prompt["output"]},
+        ]
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index 2470809d4..bcd20fb3a 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -19,6 +19,7 @@ def chat_templates(user_choice: str):
     """
 
     templates = {
+        "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
        "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
        "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
     }
 
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 75b4e5220..5c56db9f1 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -447,7 +447,11 @@ def validate_config(cfg):
             "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
         )
 
-    if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
+    if (
+        cfg.val_set_size == 0
+        and (cfg.eval_steps or cfg.evaluation_strategy)
+        and not cfg.test_datasets
+    ):
         raise ValueError(
             "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
         )
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 5e6ceb6cb..105e7416c 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -140,7 +140,7 @@ def load_tokenized_prepared_datasets(
         + "|".join(
             sorted(
                 [
-                    f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
+                    f"{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}"
                     for d in cfg_datasets
                 ]
             )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 52a81ea2c..44a93a36b 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -8,7 +8,13 @@ import addict
 import bitsandbytes as bnb
 import torch
 import transformers
-from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
+from peft import (
+    LoftQConfig,
+    PeftConfig,
+    PeftModel,
+    PeftModelForCausalLM,
+    prepare_model_for_kbit_training,
+)
 from peft.tuners.lora import QuantLinear
 from transformers import (  # noqa: F401
     AddedToken,
@@ -628,6 +634,9 @@ def load_model(
         LOG.exception(err)
         raise err
 
+    if isinstance(model, (PeftModel, PeftModelForCausalLM)):
+        model = model.merge_and_unload()
+
     embeddings_len = (
         math.ceil(len(tokenizer) / 32) * 32
         if cfg.resize_token_embeddings_to_32x
@@ -782,7 +791,7 @@ def load_adapter(model, cfg, adapter, inference=False):
 
 def load_llama_adapter(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
-    from peft import AdaptionPromptConfig, PeftModel, get_peft_model
+    from peft import AdaptionPromptConfig, get_peft_model
 
     peft_config = AdaptionPromptConfig(
         adapter_layers=cfg.peft_adapter.layers,  # layers (L)
@@ -828,7 +837,7 @@ def find_all_linear_names(model):
 
 def load_lora(model, cfg, inference=False, config_only=False):
     # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
-    from peft import LoraConfig, PeftModel, get_peft_model
+    from peft import LoraConfig, get_peft_model
 
     lora_target_modules = list(cfg.lora_target_modules or [])
 
diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py
index c0327d7ef..a56c530b2 100644
--- a/tests/e2e/patched/test_mistral_samplepack.py
+++ b/tests/e2e/patched/test_mistral_samplepack.py
@@ -7,8 +7,6 @@ import os
 import unittest
 from pathlib import Path
 
-from transformers.utils import is_torch_bf16_gpu_available
-
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
@@ -63,6 +61,7 @@ class TestMistral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "bf16": "auto",
             }
         )
         normalize_config(cfg)
@@ -103,12 +102,9 @@ class TestMistral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "bf16": "auto",
             }
         )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
diff --git a/tests/e2e/test_relora_llama.py b/tests/e2e/test_relora_llama.py
new file mode 100644
index 000000000..4ba130c9d
--- /dev/null
+++ b/tests/e2e/test_relora_llama.py
@@ -0,0 +1,68 @@
+"""
+E2E tests for relora llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestReLoraLlama(unittest.TestCase):
+    """
+    Test case for ReLoRA training with Llama models
+    """
+
+    @with_temp_dir
+    def test_relora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 32,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_modules": ["q_proj", "v_proj"],
+                "relora_steps": 25,
+                "relora_warmup_steps": 5,
+                "relora_anneal_steps": 5,
+                "relora_cpu_offload": True,
+                "val_set_size": 0.0,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "warmup_steps": 15,
+                "num_epochs": 2,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()
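---

A note on the shape of the new schedule. Below is a minimal, standalone sketch (not part of the patch) that mirrors the cycle logic in ReLoRAScheduler.get_lr() above; the helper name relora_lr_scale is hypothetical, and the default values are taken from the e2e test config (relora_steps=25, relora_warmup_steps=5, relora_anneal_steps=5):

    def relora_lr_scale(
        step: int,
        relora_steps: int = 25,
        warmup_steps: int = 5,
        anneal_steps: int = 5,
        min_lr_scale: float = 0.001,
    ) -> float:
        # Scale factor applied on top of the inner (e.g. cosine) schedule's LR.
        if step < relora_steps:
            return 1.0  # the first cycle runs purely on the inner schedule
        progress = step % relora_steps
        if progress < warmup_steps:
            # ramp back up after each merge/reset
            cycle_t = min(1.0, progress / warmup_steps)
        elif progress > relora_steps - anneal_steps:
            # new in this patch: ramp down ahead of the next reset
            cycle_t = min(1.0, (relora_steps - progress) / anneal_steps)
        else:
            cycle_t = 1.0  # hold between warmup and anneal
        return cycle_t * (1 - min_lr_scale) + min_lr_scale

With those values, one cycle looks like: steps 25-29 ramp from ~0.001 up to 0.8, steps 30-45 hold at 1.0, steps 46-49 anneal from 0.8 down to 0.2, and the scale bottoms out at min_lr_scale right at the reset boundary (step 50). With the default anneal_steps=1 the anneal branch never fires (progress is at most relora_steps - 1), so configs that do not set relora_anneal_steps keep the previous warmup-then-hold behavior.

The optimizer reset also changes from zeroing all state to magnitude pruning of the LoRA parameters' state. A toy check of magnitude_pruning_ as defined in the patch (assuming it remains importable from axolotl.monkeypatch.relora): it zeroes, in place, the prune_ratio fraction of entries with the smallest absolute values:

    import torch
    from axolotl.monkeypatch.relora import magnitude_pruning_

    t = torch.arange(1.0, 11.0)  # magnitudes 1..10
    magnitude_pruning_(t, prune_ratio=0.9)
    print(t)  # only 10.0 survives; the 0.9-quantile of |t| is 9.1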