Compare commits

..

29 Commits

Author SHA1 Message Date
Wing Lian
791c38dcc3 chore: lint 2025-01-24 13:29:54 -05:00
Wing Lian
0af78a9882 rescale the norm for lora 2025-01-24 13:11:26 -05:00
Wing Lian
fa5efbf235 don't scale delta before decomposing 2025-01-24 13:11:26 -05:00
Wing Lian
59a7ac427d make sure to scale too 2025-01-24 13:11:25 -05:00
Wing Lian
e3393042e5 hopefully fix the lora/dora logic 2025-01-24 13:11:25 -05:00
Wing Lian
08a4e8a7fb refactor a bit 2025-01-24 13:11:25 -05:00
Wing Lian
b582d340b0 save tokenizer too 2025-01-24 13:11:25 -05:00
Wing Lian
474ba1a1b8 chore: lint/formatting 2025-01-24 13:11:25 -05:00
Wing Lian
de771fcb05 fix convert logger and registration 2025-01-24 13:11:25 -05:00
Wing Lian
f32d429db5 fix import path to args 2025-01-24 13:11:25 -05:00
Wing Lian
82005f8eeb auto modeling for rrt 2025-01-24 13:11:25 -05:00
Wing Lian
b439ed3345 support optional dora 2025-01-24 13:11:24 -05:00
Wing Lian
623eaca740 more fixes to conversion 2025-01-24 13:11:24 -05:00
Wing Lian
38dfd3fadb wip conversion cli 2025-01-24 13:11:24 -05:00
Wing Lian
daa9408233 more wip 2025-01-24 13:11:24 -05:00
Wing Lian
257231ac46 wip rrt 2025-01-24 13:11:24 -05:00
Wing Lian
887513285d support for custom lr groups for non-embedding modules (#2213)
* support for custom lr groups for non-embedding modules

invert name check for group modules
include lr_groups in training args
additional conditional for creating optimizer
fix regular params as w weight decay
fix lookup and add docs

* address pr feedback
2025-01-24 12:56:28 -05:00
Wing Lian
20620771f1 Pretrain multipack (#2278)
* fix for pretrain with packing

* fix model name and loss expected

* make sure to check with micro batch size for pretraining

* change loss threshholds based on parametrization

* make tests smaller for CI

* fix pretrain packing

* fix pretrain packing test

* address pr feedback
2025-01-24 12:55:20 -05:00
NanoCode012
6086162488 chore(doc): improve explanation for *_steps and *_strategy (#2270) 2025-01-24 10:07:02 -05:00
mashdragon
b2774af66c Take split param from config in all load_dataset instances (#2281) 2025-01-24 10:06:50 -05:00
NanoCode012
74f9782fc3 chore(doc): fix explanation on gcs creds retrieval (#2272) 2025-01-24 10:05:58 -05:00
Wing Lian
8a7a0b07dc support for latest transformers release 4.48.1 (#2256) 2025-01-23 21:17:57 -05:00
Wing Lian
8fb72cbc0b use the extracted field_messages to parse the role fields (#2265) 2025-01-21 15:39:30 -05:00
Adithya Kamath
bb9d4102c4 Add 5000 line history limit to tmux for docker cloud (#2268) 2025-01-21 15:39:17 -05:00
Wing Lian
af727eedf7 option to not concatenate during pretraining (#2263)
* option to not concatenate during pretraining

* simplify conditional and add doc to config.qmd
2025-01-20 14:07:34 -05:00
jwongTensora
8606093921 fix for indexing error from token/embeddings mismatch (#2257)
Co-authored-by: jwong <jwongTensora@gmail.com>
2025-01-14 22:09:29 -05:00
NanoCode012
cba5a457d9 fix: use text_column even when not packing for pretraining (#2254)
* fix: use text_column even when not packing for pretraining

* feat: update test to check when not packing

* chore: lint

* Update src/axolotl/utils/data/pretraining.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-01-14 22:08:56 -05:00
Wing Lian
19cd83d408 rename references to dpo dataset prep to pref data (#2258) 2025-01-14 22:07:55 -05:00
Dan Saunders
1ed4de73b6 CLI cleanup and documentation (#2244)
* CLI init refactor

* fix

* cleanup and (partial) docs

* Adding documentation and continuing cleanup (in progress)

* remove finetune.py script

* continued cleanup and documentation

* pytest fixes

* review comments

* fix

* Fix

* typing fixes

* make sure the batch dataset patcher for multipack is always loaded when handling datasets

* review comments

* fix

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-01-13 17:55:29 +00:00
38 changed files with 1382 additions and 472 deletions

View File

@@ -519,8 +519,8 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
train_on_split: validation
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
# s3 creds will be loaded from the system default / gcs will attempt to load from gcloud creds, google metadata service, or anon
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above
...
# Loading Data From a Public URL

View File

@@ -6,5 +6,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -20,7 +20,8 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

View File

@@ -244,6 +244,8 @@ total_num_tokens:
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
# whether to concatenate samples during pretraining
pretraining_sample_concatenation:
# Use batch flattening for speedups when not using sample_packing
batch_flattening:
@@ -358,10 +360,11 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
save_strategy: # Set to `"no"` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch
eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`.
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that

docs/lr_groups.qmd Normal file
View File

@@ -0,0 +1,29 @@
---
title: Learning Rate Groups
description: "Setting different learning rates by module name"
---
## Background
Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of
modules in a model.
## Example
```yaml
lr_groups:
- name: o_proj
modules:
- self_attn.o_proj.weight
lr: 1e-6
- name: q_proj
modules:
- model.layers.2.self_attn.q_proj.weight
lr: 1e-5
learning_rate: 2e-5
```
In this example, we have a default learning rate of 2e-5 across the entire model, but we set a separate learning rate
of 1e-6 for all the self-attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
self-attention `q_proj` module.
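As a rough illustration (not part of the config file above): the `modules` entries are matched as substrings against parameter names, so each matching parameter inherits its group's learning rate while everything else falls back to `learning_rate`. A minimal sketch of how such groups could be turned into plain PyTorch optimizer parameter groups, using a hypothetical `build_param_groups` helper:
```python
# Sketch only; assumes substring matching of module names against parameter names.
def build_param_groups(model, lr_groups, default_lr, weight_decay=0.0):
    groups = {
        g["name"]: {"params": [], "lr": g["lr"], "weight_decay": weight_decay}
        for g in lr_groups
    }
    default = {"params": [], "lr": default_lr, "weight_decay": weight_decay}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        match = next(
            (g for g in lr_groups if any(m in name for m in g["modules"])), None
        )
        (groups[match["name"]] if match else default)["params"].append(param)
    return [default, *groups.values()]

lr_groups = [
    {"name": "o_proj", "modules": ["self_attn.o_proj.weight"], "lr": 1e-6},
    {"name": "q_proj", "modules": ["model.layers.2.self_attn.q_proj.weight"], "lr": 1e-5},
]
# e.g. torch.optim.AdamW(build_param_groups(model, lr_groups, default_lr=2e-5))
```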

View File

@@ -13,9 +13,9 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.47.1
transformers==4.48.1
tokenizers>=0.21.0
accelerate==1.2.1
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.13.0

View File

@@ -30,7 +30,7 @@ def parse_dataset(dataset=None, split="train"):
)
ds_cfg["field_messages"] = field_messages
message_fields = features["conversations"][0].keys()
message_fields = features[field_messages][0].keys()
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:

View File

@@ -11,7 +11,7 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
@@ -103,9 +103,9 @@ def load_preference_datasets(
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for DPO training, calling
`axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug
information.
Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -115,7 +115,7 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)

View File

@@ -243,6 +243,10 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
@@ -461,11 +465,95 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
@@ -479,59 +567,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -548,6 +590,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
@@ -1079,6 +1122,7 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
@@ -1664,6 +1708,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -1877,6 +1922,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if self.cfg.micro_batch_size > 1:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
if self.cfg.model_config_type == "mamba":

View File

@@ -48,9 +48,9 @@ class BasePlugin:
Initializes the BasePlugin.
"""
def register(self, cfg): # pylint: disable=unused-argument
def register(self): # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
Registers the plugin
Parameters:
cfg (dict): The configuration for the plugin.
@@ -274,6 +274,7 @@ class PluginManager:
try:
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
plugin.register()
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")

View File

View File

@@ -0,0 +1,25 @@
"""
Axolotl Plugin for Relaxed Recursive Transformers
"""
import logging
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rrt.modeling import register_rrt_model
LOG = logging.getLogger(__name__)
class RelaxedRecursiveTransformerPlugin(BasePlugin):
"""
Plugin for Relaxed Recursive Transformers integration with Axolotl
"""
def get_input_args(self):
return "axolotl.integrations.rrt.args.RelaxedRecursiveTransformerArgs"
def register(self):
LOG.info(
"Registering Relaxed Recursive Transformers modeling with transformers"
)
register_rrt_model()

View File

@@ -0,0 +1,11 @@
"""
Axolotl config args for Relaxed Recursive Transformers plugin
"""
from pydantic import BaseModel
class RelaxedRecursiveTransformerArgs(BaseModel):
"""
Arguments pertaining to the Relaxed Recursive Transformer model.
"""

View File

@@ -0,0 +1,370 @@
"""
cli script for converting a pretrained model to a relaxed recursive transformer model
"""
import json
import logging
import math
import os
import re
from pathlib import Path
from typing import Tuple
import safetensors
import torch
from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
RelaxedRecursiveLlamaConfig,
)
logger = logging.getLogger(__name__)
def extract_layer_number(key):
"""Extract layer number from parameter key."""
match = re.search(r"layers\.(\d+)\.", key)
return int(match.group(1)) if match else None
def iter_parameter_weights(model_path, device="mps"):
"""
iterator over parameter weights in the model shards
:param model_path: Path to model shards
:param device: Computing device
:return: generator yielding (parameter key, parameter weight, layer index) tuples
"""
shards = list(model_path.glob("model*.safetensors"))
if not shards:
raise ValueError(f"No model shards found in {model_path}")
for shard in tqdm(shards, desc="Processing shards"):
with safetensors.safe_open(shard, framework="pt", device=device) as f:
for key in f.keys():
layer_idx = extract_layer_number(key)
weight = f.get_tensor(key)
yield key, weight, layer_idx
def iter_recursive_parameter_weights(
model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12
):
# setup placeholder state_dict for recursive weights, need to keep in float32 precision
# to avoid precision loss when averaging weights across layers
rrt_avg_model_state_dict: dict[str, list[torch.Tensor]] = {}
# iterate over all parameter weights in the model shards
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
# get the matching module name in modules_to_recurse for the current parameter key
matched_module_name = next(
(module for module in modules_to_recurse if module in key), None
)
if matched_module_name is None:
continue
recurse_idx = layer_idx % recurse_layers
suffix = f"{recurse_idx}.{matched_module_name}"
if rrt_avg_model_state_dict.get(suffix) is None:
# setup as storage for suffix with torch.stack
rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()]
else:
rrt_avg_model_state_dict[suffix].append(
weight.to(torch.float32).detach().cpu()
)
for module_name in modules_to_recurse:
for recurse_idx in range(recurse_layers):
suffix = f"{recurse_idx}.{module_name}"
prefix = f"model.layers.{suffix}"
avg_weight = torch.stack(rrt_avg_model_state_dict[suffix]).mean(dim=0)
yield f"{prefix}.weight_base", avg_weight
# compute the decomposed lora diff from the weight base to the actual weight for each module
def low_rank_decomposition(
weight: torch.Tensor, max_rank: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Decompose a 2D matrix into low-rank matrices L and R using SVD.
:param weight: The matrix to decompose, of shape (H, W)
:param max_rank: The maximum rank of the decomposition
:return: A tuple of tensors (L, R)
"""
# pylint: disable=invalid-name
assert (
weight.dim() == 2
), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
assert (
max_rank >= 1
), f"Maximum rank must be a positive integer, but input max_rank={max_rank}."
dtype = weight.dtype
U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
# Distribute S to both to improve numerical precision
sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
A = sqrt_S @ Vh[:max_rank, :] # shape: [r, cols]
B = U[:, :max_rank] @ sqrt_S # shape: [rows, r]
return A.to(dtype), B.to(dtype)
def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
weight = weight + scaling * lora_weight
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
return weight_norm
def decompose_delta_weight(layer_weight, avg_weight, alpha, rank, use_dora=True):
"""
Decompose the difference in directions (ΔV) via SVD,
and return (magnitudes, L, R).
"""
device = "cuda" if torch.cuda.is_available() else "mps"
# rslora
scaling = alpha / math.sqrt(rank)
base_weight = avg_weight.to(device)
final_weight = layer_weight.to(device)
delta_for_svd = final_weight - base_weight
# Low-rank factorization of the delta direction
lora_A, lora_B = low_rank_decomposition( # pylint: disable=invalid-name
delta_for_svd, rank
)
if use_dora:
lora_weight = lora_B @ lora_A
weight_norm = get_weight_norm(
base_weight.to(lora_A.device), lora_weight, scaling
)
return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
# let's rescale the lora weight to have the same magnitude as the base weight
return lora_A.cpu(), lora_B.cpu(), None
def iter_dora_parameter_weights(
model_path,
avg_recursive_weights,
modules_to_recurse: list[str],
alpha,
rank,
device="mps",
recurse_layers=12,
use_dora=True,
):
# iterate over all parameter weights in the model shards
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
# get the matching module name in modules_to_recurse for the current parameter key
matched_module_name = next(
(module for module in modules_to_recurse if module in key), None
)
if matched_module_name is None:
if "input_layernorm" in key:
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
loop_idx = layer_idx // recurse_layers
layer_idx = layer_idx % recurse_layers
layernorm_key = (
f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
)
yield layernorm_key, weight
elif "post_attention_layernorm" in key:
# map to post_attention_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
loop_idx = layer_idx // recurse_layers
layer_idx = layer_idx % recurse_layers
layernorm_key = f"model.layers.{layer_idx}.post_attention_layernorm_list.{loop_idx}.weight"
yield layernorm_key, weight
else:
yield key, weight
continue
# figure out the base weight layer for this key
loop_idx = layer_idx // recurse_layers
layer_idx = layer_idx % recurse_layers
suffix = f"{layer_idx}.{matched_module_name}"
prefix = f"model.layers.{suffix}.weight_base"
avg_weight = avg_recursive_weights[prefix]
lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
lora_magnitude_key = (
f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
)
lora_a, lora_b, lora_magnitude = decompose_delta_weight(
weight,
avg_weight,
alpha,
rank,
use_dora=use_dora,
)
yield lora_a_key, lora_a
yield lora_b_key, lora_b
if use_dora:
yield lora_magnitude_key, lora_magnitude
def save_state_dict_to_safetensors(state_dict, save_directory):
os.makedirs(save_directory, exist_ok=True)
weights_name = SAFE_WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
)
# pylint: disable=duplicate-code
# Save index if sharded
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
if (
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in state_dict_split.filename_to_tensors.keys()
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
shard[tensor] = state_dict[tensor].contiguous()
del state_dict[tensor]
save_file(
shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}
)
del state_dict
if index is None:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
else:
save_index_file = SAFE_WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, save_index_file)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
def convert_llama_to_rrt(
model_name,
output_dir,
recurse_layers: int = 12,
rank=32,
alpha=32,
device=None,
use_dora=True,
):
if not device:
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
modules_to_recurse = [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"mlp.down_proj",
"mlp.gate_proj",
"mlp.up_proj",
]
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
num_hidden_layers = config.num_hidden_layers
if num_hidden_layers % recurse_layers != 0:
raise ValueError(
f"The number of hidden layers ({num_hidden_layers}) in the model must be "
f"divisible by the recurse layers ({recurse_layers})"
)
config = RelaxedRecursiveLlamaConfig.from_dict(
{
**config.to_dict(),
"recurse_layers": recurse_layers,
"rank": rank,
"alpha": alpha,
"use_dora": use_dora,
}
)
config.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
# create a new state_dict to store the RRT model weights
rrt_model_state_dict = {}
logger.info("Calculating average recursive weights...")
for key, weight in iter_recursive_parameter_weights(
model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
):
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
logger.info("Calculating decomposed lora diff...")
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
rrt_lora_state_dict = {}
for key, weight in iter_dora_parameter_weights(
model_path,
rrt_model_state_dict,
modules_to_recurse,
alpha=32,
rank=rank,
device=device,
recurse_layers=recurse_layers,
use_dora=use_dora,
):
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
# combine state dicts into a single state_dict
rrt_model_state_dict.update(rrt_lora_state_dict)
# save state dict as sharded safetensors to disk using split_torch_state_dict_into_shards
save_state_dict_to_safetensors(rrt_model_state_dict, output_dir)
if __name__ == "__main__":
# meta-llama/Llama-3.2-1B has 16 hidden layers
# meta-llama/Llama-3.2-3B has 28 hidden layers
convert_llama_to_rrt(
"meta-llama/Llama-3.2-3B",
"/tmp/rrt_model", # nosec
recurse_layers=4,
rank=256,
alpha=512,
use_dora=False,
)
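
The SVD split in `low_rank_decomposition` above can be sanity-checked in isolation. A minimal, self-contained sketch (it re-implements the helper locally so it runs standalone; not part of the diff):
```python
# Illustrative only: verify that B @ A reproduces the delta at full rank and
# approximates it at a truncated rank, mirroring the helper above.
import torch

def low_rank_decomposition(weight, max_rank):
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
    sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))  # split S across both factors
    A = sqrt_S @ Vh[:max_rank, :]                  # [rank, cols]
    B = U[:, :max_rank] @ sqrt_S                   # [rows, rank]
    return A, B

delta = torch.randn(64, 64)

A, B = low_rank_decomposition(delta, max_rank=64)  # full rank: exact up to fp32
print(f"full-rank error: {torch.linalg.norm(delta - B @ A) / torch.linalg.norm(delta):.2e}")

A, B = low_rank_decomposition(delta, max_rank=8)   # truncated: low-rank approximation
print(f"rank-8 error:    {torch.linalg.norm(delta - B @ A) / torch.linalg.norm(delta):.2e}")
```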

View File

@@ -0,0 +1,25 @@
"""
module for modeling relaxed recursive transformers model
"""
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
from .modeling_rrt_llama import (
RelaxedRecursiveLlamaForCausalLM,
RelaxedRecursiveLlamaModel,
)
def register_rrt_model():
"""
Register Relaxed Recursive Transformers model with transformers
"""
# Register configs
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
# Register models
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
AutoModelForCausalLM.register(
RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
)

View File

@@ -0,0 +1,16 @@
"""
module for custom configuration for relaxed recursive transformers model
"""
from transformers import LlamaConfig
class RelaxedRecursiveLlamaConfig(LlamaConfig):
"""
Configuration for Relaxed Recursive Llama.
"""
model_type: str = "llama-rrt"
recurse_layers: int = 4
rank: int
alpha: int
use_dora: bool = True

View File

@@ -0,0 +1,116 @@
"""
module for the shared linear layer for the relaxed recursive transformers model
"""
import math
import torch
import torch.nn.functional as F
from peft.utils import transpose
from torch import nn
class RelaxedRecursiveDoraLinear(nn.Module):
"""
A single linear layer that is "shared" across multiple loop iterations,
but each iteration has its own DoRA offsets (A_i, B_i, magnitude_i).
The constructor expects you to specify:
- in_features, out_features
- B: number of loop iterations (i.e., how many times we "unroll")
- fan_in_fan_out: pass True if your underlying base weight is transposed, etc.
The forward(...) expects an additional argument "loop_idx" in [0..B-1],
which picks out the iteration-specific DoRA offsets.
"""
def __init__(
self,
in_features: int,
out_features: int,
B: int, # pylint: disable=invalid-name
rank: int,
alpha: int,
fan_in_fan_out: bool = False,
bias: bool = True,
use_dora: bool = True,
):
super().__init__()
self.B = B # pylint: disable=invalid-name
self.fan_in_fan_out = fan_in_fan_out
self.weight_base = nn.Parameter(torch.empty(out_features, in_features))
self.use_bias = bias
if self.use_bias:
self.bias = nn.Parameter(torch.zeros(out_features))
else:
self.register_parameter("bias", None)
self.lora_A_list = nn.ParameterList( # pylint: disable=invalid-name
[nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)]
)
self.lora_B_list = nn.ParameterList( # pylint: disable=invalid-name
[nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)]
)
# rslora
self.scaling = alpha / math.sqrt(rank)
self.use_dora = use_dora
if use_dora:
self.lora_magnitude_vector_list = nn.ParameterList(
[nn.Parameter(torch.ones(out_features)) for _ in range(B)]
)
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
weight = transpose(weight, self.fan_in_fan_out)
weight = weight + scaling * lora_weight
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
return weight_norm
def forward(self, x, loop_idx: int):
"""
:param x: hidden state of shape (batch_size, seq_len, in_features)
:param loop_idx:
:return:
"""
eps = 1e-6
w_base = self.weight_base
w_base = w_base.to(x.dtype)
lora_A: torch.Tensor = self.lora_A_list[ # pylint: disable=invalid-name
loop_idx
]
lora_B: torch.Tensor = self.lora_B_list[ # pylint: disable=invalid-name
loop_idx
]
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling
if self.use_dora:
x_eye: torch.Tensor = torch.eye(
lora_A.shape[1], device=lora_A.device, dtype=x.dtype
)
tmp = F.linear(x_eye, lora_A) # [hidden_size, rank]
w_dora_full: torch.Tensor = F.linear(tmp, lora_B)
w_dora_full = w_dora_full.t()
magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
w_dora_norm: torch.Tensor = self.get_weight_norm(
w_base, w_dora_full.detach(), self.scaling
)
w_dora_norm = w_dora_norm.detach()
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(
0
) # shape [1, out_features]
result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
return result_dora
# scale the lora norm to prevent gradient explosion
orig_norm = torch.linalg.norm(w_base)
update_norm = torch.linalg.norm(lora_out)
scale = orig_norm / (update_norm + eps)
return base_out + lora_out * scale
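
For orientation, a usage sketch of the layer above (illustrative only; the import path follows the modeling code later in this diff, and the random initialization of `weight_base` is an assumption, since the constructor leaves it uninitialized):
```python
import torch
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear

# One shared base weight; each loop iteration gets its own LoRA/DoRA offsets.
layer = RelaxedRecursiveDoraLinear(
    in_features=64, out_features=64, B=2, rank=8, alpha=16,
    bias=False, use_dora=True,
)
with torch.no_grad():
    layer.weight_base.normal_(0, 0.02)  # weight_base is created with torch.empty

x = torch.randn(1, 10, 64)
y0 = layer(x, loop_idx=0)  # uses lora_A_list[0] / lora_B_list[0] / magnitude[0]
y1 = layer(x, loop_idx=1)  # same weight_base, iteration-specific offsets
print(y0.shape, y1.shape)  # torch.Size([1, 10, 64]) for both passes
```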

View File

@@ -0,0 +1,471 @@
import logging
from typing import Callable, Optional, Tuple, Union, Unpack
import torch
from torch import nn
from transformers import Cache, DynamicCache, LlamaConfig
from transformers.activations import ACT2FN
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
logger = logging.getLogger(__name__)
# pylint: skip-file
# mypy: ignore-errors
class RelaxedRecursiveLlamaMLP(nn.Module):
def __init__(self, config: RelaxedRecursiveLlamaConfig):
super().__init__()
recurse_loops = config.num_hidden_layers // config.recurse_layers
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = RelaxedRecursiveDoraLinear(
self.hidden_size,
self.intermediate_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.up_proj = RelaxedRecursiveDoraLinear(
self.hidden_size,
self.intermediate_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.down_proj = RelaxedRecursiveDoraLinear(
self.intermediate_size,
self.hidden_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x, loop_idx: int):
down_proj = self.down_proj(
self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx),
loop_idx,
)
return down_proj
class RelaxedRecursiveLlamaAttention(nn.Module):
"""
A single attention layer of the Relaxed Recursive Llama.
"""
def __init__(self, config: RelaxedRecursiveLlamaConfig, layer_idx: int):
super().__init__()
recurse_loops = config.num_hidden_layers // config.recurse_layers
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = RelaxedRecursiveDoraLinear(
config.hidden_size,
config.num_attention_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.k_proj = RelaxedRecursiveDoraLinear(
config.hidden_size,
config.num_key_value_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.v_proj = RelaxedRecursiveDoraLinear(
config.hidden_size,
config.num_key_value_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.o_proj = RelaxedRecursiveDoraLinear(
config.num_attention_heads * self.head_dim,
config.hidden_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
loop_idx: int,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = (
self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get(
"output_attentions", False
):
logger.warning(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output, loop_idx)
return attn_output, attn_weights # pylint: disable=return-value
class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
"""
A single layer of the Relaxed Recursive Llama decoder.
"""
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
recurse_loops = config.num_hidden_layers // config.recurse_layers
self.hidden_size = config.hidden_size
self.self_attn = RelaxedRecursiveLlamaAttention(
config=config, layer_idx=layer_idx
)
self.mlp = RelaxedRecursiveLlamaMLP(config)
self.input_layernorm_list = nn.ModuleList(
[
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for _ in range(recurse_loops)
]
)
self.post_attention_layernorm_list = nn.ModuleList(
[
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for _ in range(recurse_loops)
]
)
def forward(
self,
hidden_states: torch.Tensor,
loop_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
hidden_states = self.input_layernorm_list[loop_idx](hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
loop_idx=loop_idx,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm_list[loop_idx](hidden_states)
hidden_states = self.mlp(hidden_states, loop_idx)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class RelaxedRecursiveLlamaModel(LlamaModel):
config_class = RelaxedRecursiveLlamaConfig
def __init__(self, config):
super(LlamaModel, self).__init__(config)
self.recurse_loops = config.num_hidden_layers // config.recurse_layers
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
RelaxedRecursiveLlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.recurse_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for loop_idx in range(self.recurse_loops):
for decoder_layer in self.layers[: self.config.recurse_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
loop_idx,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
loop_idx,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()
class RelaxedRecursiveLlamaForCausalLM(LlamaForCausalLM):
config_class = RelaxedRecursiveLlamaConfig
def __init__(self, config):
super(LlamaForCausalLM, self).__init__(config)
self.model = RelaxedRecursiveLlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_nb_trainable_parameters(self) -> tuple[int, int, int]:
r"""
Returns the number of trainable parameters and the number of all parameters in the model.
"""
trainable_params = 0
all_param = 0
lora_params = 0
for name, param in self.named_parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
# Due to the design of 4bit linear layers from bitsandbytes
# one needs to multiply the number of parameters by 2 to get
# the correct number of parameters
if param.__class__.__name__ == "Params4bit":
if hasattr(param, "element_size"):
num_bytes = param.element_size()
elif not hasattr(param, "quant_storage"):
num_bytes = 1
else:
num_bytes = param.quant_storage.itemsize
num_params = num_params * 2 * num_bytes
all_param += num_params
if param.requires_grad:
trainable_params += num_params
if "lora_" in name:
lora_params += num_params
return trainable_params, all_param, lora_params
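
To make the unrolling concrete (illustrative arithmetic, not part of the diff): the conversion script maps original layer `i` to shared slot `i % recurse_layers` with per-loop offsets indexed by `i // recurse_layers`, while the model above replays the shared layers `num_hidden_layers // recurse_layers` times in order, so the two orderings line up.
```python
# Mapping between original layers and (shared slot, loop) pairs, matching the
# modulo / floor-division logic used in the conversion script above.
num_hidden_layers, recurse_layers = 16, 4
recurse_loops = num_hidden_layers // recurse_layers  # 4 passes over 4 shared layers

for layer_idx in range(num_hidden_layers):
    slot = layer_idx % recurse_layers       # which shared decoder layer is reused
    loop_idx = layer_idx // recurse_layers  # which LoRA offsets / layernorms apply
    print(f"original layer {layer_idx:2d} -> shared layer {slot}, loop {loop_idx}")
```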

View File

@@ -1,308 +0,0 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
"""
PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
else:
loss = self.compute_loss(model, inputs)
"""
ORIGINAL_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
"""
PATCHED_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
"""
def get_training_step_code() -> str:
training_step = inspect.getsource(
Trainer.training_step # pylint: disable=protected-access
)
return training_step
def check_training_step_is_patchable() -> bool:
training_step = get_training_step_code()
training_step, _ = detab_code(training_step)
return ORIGINAL_CONTEXT_CODE in training_step
def patch_training_step_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step:
return
# assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace(
"def training_step(",
"def _fixed_training_step(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_step:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
def get_model_forward_code() -> str:
forward = inspect.getsource(
LlamaForCausalLM.forward # pylint: disable=protected-access
)
return forward
def check_forward_is_patchable() -> bool:
forward = get_model_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_LLAMA_FCLM_CODE in forward
def patch_forward_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace(
"def forward(",
"def _fixed_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)
ORIGINAL_TRAINER_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
PATCHED_TRAINER_CODE = """
disable_deepspeed_no_sync = (
self.accelerator.distributed_type == DistributedType.DEEPSPEED
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
)
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_deepspeed_0_16_x():
"""
monkeypatch for fixing the training loop for deepspeed GA
see https://github.com/huggingface/transformers/pull/35157
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)
def patch_flash_attention_forward():
"""
monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch
"""
import transformers.modeling_flash_attention_utils
def proxy_flash_attention_forward(*args, **kwargs):
kwargs.pop("num_items_in_batch", None)
return _flash_attention_forward(*args, **kwargs)
transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)
transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)

View File

@@ -0,0 +1,67 @@
"""
see https://github.com/huggingface/transformers/pull/35834
"""
import logging
from functools import partial
from typing import Optional
import torch
logger = logging.getLogger(__name__)
def fixed_fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
target_dtype: Optional[torch.dtype] = None,
preferred_dtype: Optional[torch.dtype] = None,
):
"""
PEFT usually casts the layer norms in float32 for training stability reasons
therefore the input hidden states gets silently casted in float32. Hence, we need
cast them back in float16 / bfloat16 just to be sure everything works as expected.
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
Args:
query (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value (`torch.Tensor`):
Input value states to be passed to Flash Attention API
target_dtype (`torch.dtype`, *optional*):
The dtype to convert the attention tensors to. Conversion can be ignored by
not providing the target dtype.
preferred_dtype (`torch.dtype`, *optional*):
The preferred dtype to convert the attention tensors to regardless of the
target dtype.
"""
if target_dtype is None and preferred_dtype is None:
return query, key, value
if preferred_dtype and target_dtype != preferred_dtype:
target_dtype = preferred_dtype
# check if any of query, key, or value are in float32. If so, cast them back to target dtype.
if any(module.dtype == torch.float32 for module in [query, key, value]):
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)
return query, key, value
def patch_fa_peft_integration():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils.fa_peft_integration_check = partial(
fixed_fa_peft_integration_check, preferred_dtype=None
)
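A minimal sketch, not part of this diff, showing where the fix is installed; the model-loader hunk further down makes the same call when an adapter is configured.
from axolotl.monkeypatch.transformers_fa_utils import patch_fa_peft_integration
# replace fa_peft_integration_check with the fixed version, binding preferred_dtype=None
# so only the caller-provided target_dtype is used when casting q/k/v back from float32
patch_fa_peft_integration()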

View File

@@ -147,6 +147,14 @@ class UserDefinedPrompterType(BaseModel):
field: Optional[str] = None
class LrGroup(BaseModel):
"""Custom learning rate group configuration"""
name: str
modules: List[str]
lr: float
class SFTDataset(BaseModel):
"""SFT configuration subset"""
@@ -475,6 +483,7 @@ class HyperparametersConfig(BaseModel):
cosine_min_lr_ratio: Optional[float] = None
cosine_constant_lr_ratio: Optional[float] = None
lr_div_factor: Optional[float] = None
lr_groups: Optional[List[LrGroup]] = None
adam_epsilon: Optional[float] = None
adam_beta1: Optional[float] = None
@@ -706,6 +715,12 @@ class AxolotlInputConfig(
pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None
pretraining_sample_concatenation: Optional[bool] = Field(
default=None,
json_schema_extra={
"description": "whether to soft pack/concatenate samples during pretraining",
},
)
batch_flattening: Optional[Union[Literal["auto"], bool]] = None
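For reference, a hedged sketch, not part of this diff, of how the two new options might be set, using the DictDefault style of the e2e tests below; module names and values are illustrative only.
from axolotl.utils.dict import DictDefault
cfg = DictDefault(
    {
        "learning_rate": 2e-5,
        # optional per-module overrides validated by LrGroup (name, modules, lr)
        "lr_groups": [
            {"name": "o_proj", "modules": ["o_proj"], "lr": 1e-5},
        ],
        # keep pretraining samples separate instead of soft packing/concatenating them
        "pretraining_sample_concatenation": False,
    }
)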

View File

@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper,
load_prepare_datasets,

View File

@@ -18,10 +18,14 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
concatenate: bool = True,
) -> Dict[str, List]:
res = tokenizer(
examples["text"],
examples[text_column],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
@@ -30,6 +34,13 @@ def encode_pretraining(
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
if not concatenate:
return {
"input_ids": [seq.tolist() for seq in input_ids],
"labels": [seq.tolist() for seq in targets],
"attention_mask": [seq.tolist() for seq in attention_mask],
}
new_input_ids = []
new_labels = []
new_attention_mask = []
@@ -180,7 +191,7 @@ def wrap_pretraining_dataset(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens * batch_size,
pad_to_multiple_of=max_tokens,
multipack_attn=cfg.pretrain_multipack_attn,
)
encode = functools.partial(
@@ -190,13 +201,17 @@ def wrap_pretraining_dataset(
max_seq_length=max_tokens,
batch_size=batch_size,
multipack_attn=cfg.pretrain_multipack_attn,
group_size=cfg.sample_packing_group_size,
bin_size=cfg.sample_packing_bin_size,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
encode = functools.partial(
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
concatenate=cfg.pretraining_sample_concatenation is True,
)
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
@@ -230,9 +245,7 @@ def encode_packed_pretraining(
examples: Dict[str, List],
max_seq_length: int = 2048,
batch_size: int = 4,
multipack_attn: Optional[bool] = False,
group_size: int = 100000,
bin_size: int = 200,
multipack_attn: Optional[bool] = True,
) -> Dict[str, List]:
# pylint: disable=duplicate-code
# tokenize all the examples
@@ -243,6 +256,9 @@ def encode_packed_pretraining(
train_dataset,
max_seq_length,
skip_position_ids=not multipack_attn,
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
# workaround by using the position id logic for now in trainer
drop_attention_mask=multipack_attn,
)
sampler = MultipackBatchSampler(
@@ -250,8 +266,6 @@ def encode_packed_pretraining(
lengths=get_dataset_lengths(train_dataset),
batch_size=1,
batch_max_len=batch_size * max_seq_length,
group_size=group_size,
bin_size=bin_size,
drop_last=True,
)
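A minimal sketch, not part of this diff, of the non-concatenated encode path, assuming the SmolLM2 tokenizer used in the pretraining test below and a hypothetical "content" text column.
import functools
from transformers import AutoTokenizer
from axolotl.utils.data.pretraining import encode_pretraining
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
encode = functools.partial(
    encode_pretraining,
    tokenizer,
    2048,  # max_tokens
    text_column="content",  # hypothetical column name
    concatenate=False,  # emit one example per input row instead of soft packing
)
batch = encode({"content": ["first document", "second document"]})
# with concatenate=False, input_ids / labels / attention_mask keep one entry per row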

View File

@@ -115,7 +115,7 @@ def drop_long_rl_seq(
raise ValueError("Unknown RL type")
def load_prepare_dpo_datasets(cfg):
def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
for i, ds_cfg in enumerate(dataset_cfgs):

View File

@@ -107,6 +107,13 @@ def load_dataset_w_config(config_dataset, auth_token):
except (FileNotFoundError, ConnectionError):
pass
# gather extra args from the config
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
else:
load_ds_kwargs["split"] = None
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
@@ -118,7 +125,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
**load_ds_kwargs,
)
else:
try:
@@ -130,7 +137,7 @@ def load_dataset_w_config(config_dataset, auth_token):
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
**load_ds_kwargs,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
@@ -140,16 +147,13 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
**load_ds_kwargs,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
@@ -173,9 +177,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
@@ -184,9 +188,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
else:
if isinstance(config_dataset.data_files, str):
@@ -214,7 +218,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
**load_ds_kwargs,
)
if not ds:
raise ValueError("unhandled dataset load")
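A hedged sketch, not part of this diff, of a dataset entry whose split is now forwarded for local paths as well; the path is illustrative.
from axolotl.utils.dict import DictDefault
config_dataset = DictDefault(
    {
        "path": "data/my_corpus",  # hypothetical local directory
        "split": "train",  # previously dropped for local loads, now passed via load_ds_kwargs
    }
)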

View File

@@ -380,23 +380,19 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
if self.cfg.adapter:
from axolotl.monkeypatch.transformers_fa_utils import (
patch_fa_peft_integration,
)
patch_fa_peft_integration()
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.flash_attention:
self.patch_attention()
if self.cfg.model_config_type == "llama":
from axolotl.monkeypatch.trainer_grad_accum import (
patch_flash_attention_forward,
patch_forward_for_ga,
patch_training_step_for_ga,
)
patch_flash_attention_forward()
patch_forward_for_ga()
patch_training_step_for_ga()
if self.cfg.sample_packing and self.cfg.s2_attention:
raise ValueError(
"Received `sample_packing=true` and `s2_attention=true`; however, \
@@ -1057,7 +1053,7 @@ class ModelLoader:
)
if (
hasattr(self.model, "get_input_embeddings")
and self.model.get_input_embeddings().num_embeddings < embeddings_len
and self.model.get_input_embeddings().num_embeddings != embeddings_len
):
resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None:

View File

@@ -310,19 +310,22 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
def process_pretraining_datasets_for_packing(
train_dataset, sequence_len, skip_position_ids=True
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
train_dataset = train_dataset.filter(
drop_long,
desc="Dropping Long Sequences",
load_from_cache_file=False,
)
if skip_position_ids:
if not skip_position_ids:
train_dataset = train_dataset.map(
add_position_ids,
desc="Add position_id column (Pretraining Sample Packing)",
)
if drop_attention_mask:
train_dataset = train_dataset.remove_columns("attention_mask")
return train_dataset
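A minimal sketch, not part of this diff, of the new drop_attention_mask flag in use, assuming the helper is importable from axolotl.utils.trainer as in current releases; the dataset is a toy example.
from datasets import Dataset
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
ds = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3]],
        "labels": [[1, 2, 3]],
        "attention_mask": [[1, 1, 1]],
    }
)
# multipack-attention workaround: keep the position_ids path and drop attention_mask,
# mirroring how encode_packed_pretraining now calls this helper
ds = process_pretraining_datasets_for_packing(
    ds, 2048, skip_position_ids=False, drop_attention_mask=True
)
assert "attention_mask" not in ds.column_names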

View File

@@ -63,6 +63,7 @@ class TestMultiGPULlama:
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
}
)
@@ -127,6 +128,7 @@ class TestMultiGPULlama:
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
}
)
@@ -201,6 +203,7 @@ class TestMultiGPULlama:
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
}
)
@@ -223,8 +226,12 @@ class TestMultiGPULlama:
]
)
loss_threshold = 2.3
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
)
def test_dpo_qlora_ddp(self, temp_dir):
@@ -275,6 +282,7 @@ class TestMultiGPULlama:
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
}
)
@@ -297,8 +305,12 @@ class TestMultiGPULlama:
]
)
loss_threshold = 2.3
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
)
@pytest.mark.parametrize(

View File

@@ -102,9 +102,5 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -49,12 +49,7 @@ class TestModelPatches(unittest.TestCase):
)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer, inference=False)
assert (
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
load_model(cfg, tokenizer, inference=False)
@with_temp_dir
def test_mistral_multipack(self, temp_dir):

View File

@@ -3,8 +3,6 @@ import unittest
import pytest
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
@@ -13,6 +11,8 @@ class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""
def test_is_self_attn_patchable(self):
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_self_attn_is_patchable(),

View File

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -4,7 +4,8 @@ E2E tests for llama pretrain
import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
@@ -12,31 +13,40 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPretrainLlama(unittest.TestCase):
class TestPretrainLlama:
"""
Test case for Llama models w pretraining
"""
@with_temp_dir
def test_pretrain_w_sample_packing(self, temp_dir):
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
@pytest.mark.parametrize(
"pretrain_multipack_attn",
[True, False],
)
def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn):
if not sample_packing and pretrain_multipack_attn:
return
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": True,
"sample_packing": sample_packing,
"pretrain_multipack_attn": pretrain_multipack_attn,
"dataset_processes": 1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"pretraining_dataset": [
{
@@ -47,7 +57,7 @@ class TestPretrainLlama(unittest.TestCase):
],
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
@@ -56,6 +66,7 @@ class TestPretrainLlama(unittest.TestCase):
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
normalize_config(cfg)
@@ -64,3 +75,12 @@ class TestPretrainLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
loss_threshold = 3.5
if sample_packing and not pretrain_multipack_attn:
loss_threshold = 6.5
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
)

View File

@@ -1,25 +0,0 @@
""""Test module for checking whether the Hugging Face Transformers is working as expected."""
import unittest
from axolotl.monkeypatch.trainer_grad_accum import (
check_forward_is_patchable,
check_training_step_is_patchable,
)
class TestTrainerGAIntegration(unittest.TestCase):
"""llama monkeypatch integration tests."""
def test_train_step_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_training_step_is_patchable(),
"HF transformers Trainer.training_step has changed and isn't patchable",
)
def test_model_forward_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_forward_is_patchable(),
"HF transformers LlamaForCausalLM.forward has changed and isn't patchable",
)

View File

@@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
@@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
@@ -329,7 +329,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features

View File

@@ -12,7 +12,7 @@ from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
@@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
"""Verify that loading with deduplication removes duplicates."""
# Load the dataset using the deduplication setting
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
"""Verify that loading without deduplication retains duplicates."""
self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset retains duplicates
assert (

View File

@@ -41,6 +41,7 @@ class TestPretrainingPacking(unittest.TestCase):
}
],
"sample_packing": True,
"pretrain_multipack_attn": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"micro_batch_size": 2,
@@ -87,9 +88,11 @@ class TestPretrainingPacking(unittest.TestCase):
assert data["labels"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]
)
assert data["attention_mask"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]
)
assert "attention_mask" not in data
# FIXME add back once we fix packing unpad/pad with attention mask
# assert data["attention_mask"].shape == torch.Size(
# [1, original_bsz * cfg.sequence_len]
# )
idx += 1