Compare commits

..

44 Commits

Author SHA1 Message Date
Dan Saunders  2daa94080c  Merge branch 'main' into diff-transformer  2025-01-27 14:46:17 +00:00
Dan Saunders  0e9bfa6dee  small fixes, improvements  2025-01-24 19:53:54 +00:00
Dan Saunders  ef38f10274  merging into main  2025-01-24 18:03:27 +00:00
Dan Saunders  66262c3092  moving out all diff attn code to plugin repo  2025-01-24 17:46:11 +00:00
Dan Saunders  016ba124e4  README update  2025-01-23 22:11:35 +00:00
Dan Saunders  7145d52d99  moving diff attn code to separate repo  2025-01-23 21:33:53 +00:00
Dan Saunders  28694219a5  inline comment change  2025-01-14 16:59:43 +00:00
Dan Saunders  fd8ad6fcbf  fixing negative component mixing  2025-01-13 19:21:55 +00:00
Dan Saunders  661d71a14b  adding diff attn negative component warmup (in progress)  2025-01-10 21:57:31 +00:00
Dan Saunders  6dd47edcb8  fire CLI fixes  2025-01-10 18:24:16 +00:00
Dan Saunders  7aca08ff60  adding guard statements  2025-01-10 16:39:21 +00:00
Dan Saunders  4f804f6d88  adding diff attn callback, adding documentation  2025-01-10 16:28:51 +00:00
Dan Saunders  443327c585  CLI build_command bugfix  2025-01-10 16:28:51 +00:00
Dan Saunders  70c4e6fbe6  updates and cleanup  2025-01-10 16:28:51 +00:00
Dan Saunders  2a7f139ad2  pre-commit fix  2025-01-10 16:28:51 +00:00
Dan Saunders  332ce0ae85  fixes and cleanup  2025-01-10 16:28:51 +00:00
Dan Saunders  e5fa842ff8  update  2025-01-10 16:28:51 +00:00
Dan Saunders  78e0ec0aa5  changes  2025-01-10 16:28:51 +00:00
Dan Saunders  3bc568eb27  adding registration function  2025-01-10 16:28:51 +00:00
Dan Saunders  eb6611d55f  progress on modeling code  2025-01-10 16:28:51 +00:00
Dan Saunders  4ff3328e66  updated custom modeling code  2025-01-10 16:28:51 +00:00
Dan Saunders  a3fd5074a9  fix duplicate-code warnings  2025-01-10 16:28:51 +00:00
Dan Saunders  5b90da0be3  added modeling code; cleanup + refactor  2025-01-10 16:28:51 +00:00
Dan Saunders  fcbfa86373  refactor and fixing test isolation issues  2025-01-10 16:28:51 +00:00
Dan Saunders  0d56582090  adding yaml dumper preserving input config format  2025-01-10 16:28:51 +00:00
Dan Saunders  390cb5742e  removing extra pytest xdist args  2025-01-10 16:28:51 +00:00
Dan Saunders  1d935f65c3  moving tests around for flash_attn install  2025-01-10 16:28:51 +00:00
Dan Saunders  66176b3e07  adding split_heads argument for retaining original (Q, K) dimensionanlity  2025-01-10 16:28:51 +00:00
Dan Saunders  505321ac95  isolating problematic test  2025-01-10 16:28:51 +00:00
Dan Saunders  0b382c88da  fixes post-rebase  2025-01-10 16:28:51 +00:00
Dan Saunders  ea07a7086e  plugin implementation  2025-01-10 16:28:51 +00:00
Dan Saunders  d22e1136bc  convert-differential-transformer test coverage  2025-01-10 16:28:51 +00:00
Dan Saunders  63b8e42c6b  duplicate code ignore  2025-01-10 16:28:51 +00:00
Dan Saunders  bda1eed59e  differential flash attention 2; cleanup  2025-01-10 16:28:51 +00:00
Dan Saunders  41ebd93158  moving monkeypatch  2025-01-10 16:28:51 +00:00
Dan Saunders  4c050ce807  pre-commit fix  2025-01-10 16:28:51 +00:00
Dan Saunders  6665acf63d  fix model save / load logic  2025-01-10 16:28:51 +00:00
Dan Saunders  2f9fa4c465  various improvemnents  2025-01-10 16:28:51 +00:00
Dan Saunders  849bc94112  various improvemnents  2025-01-10 16:28:51 +00:00
Dan Saunders  e484ec778d  training fixes, patching, minor cleanup  2025-01-10 16:28:51 +00:00
Dan Saunders  df1504ae14  adding CLI command for convert-diff-transformer  2025-01-10 16:28:51 +00:00
Dan Saunders  7be0d7496c  Adding script for doing conversion; fixes and updates  2025-01-10 16:28:51 +00:00
Dan Saunders  13cdffa91f  initial diff attn layer / model conversion implementation (support for llama arch)  2025-01-10 16:28:51 +00:00
Dan Saunders  7a4b296f60  Basic evaluate CLI command / codepath (#2188)  2025-01-10 16:28:51 +00:00
    * basic evaluate CLI command / codepath
    * tests for evaluate CLI command
    * fixes and cleanup
    * review comments; slightly DRYing up things
    Co-authored-by: Dan Saunders <danjsaund@gmail.com>
23 changed files with 218 additions and 1049 deletions

.gitignore (vendored, 3 lines added)

@@ -186,3 +186,6 @@ out/
 # vim
 *.swp
+# symlinked to axolotl-artifacts in docker containers
+outputs


@@ -4,7 +4,6 @@ set -e
 python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
 pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
-# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
 pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/


@@ -1,6 +1,6 @@
 """
 modal application to run axolotl gpu tests in Modal
 """
 # pylint: disable=duplicate-code
 import os


@@ -19,7 +19,7 @@ from axolotl.utils.dict import DictDefault
 LOG = logging.getLogger(__name__)

-def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
+def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
     """
     Evaluates a `transformers` model by first loading the dataset(s) specified in the
     `axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
@@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-    evaluate(cfg=cfg, dataset_meta=dataset_meta)
+    return evaluate(cfg=cfg, dataset_meta=dataset_meta)

 def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
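
Because do_evaluate now returns the computed metrics instead of discarding them, callers can consume the result directly. A minimal, hedged sketch (assuming cfg and cli_args have already been built the way the CLI builds them; the metric keys depend on the configured datasets):

metrics = do_evaluate(cfg, cli_args)  # cfg: DictDefault, cli_args: TrainerCliArgs (assumed prebuilt)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")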


@@ -8,6 +8,7 @@ import click
 import axolotl
 from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
+from axolotl.cli.plugins import setup_plugin_commands
 from axolotl.cli.utils import (
     add_options_from_config,
     add_options_from_dataclass,
@@ -222,6 +223,9 @@ def fetch(directory: str, dest: Optional[str]) -> None:
     fetch_from_github(f"{directory}/", dest)

+setup_plugin_commands(cli)

 def main():
     cli()


@@ -0,0 +1,36 @@
"""Module for adding click CLI commands from axolotl plugins."""

import logging

import click

from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
from axolotl.logging_config import configure_logging
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig

configure_logging()
LOG = logging.getLogger(__name__)


def setup_plugin_commands(cli: click.core.Group) -> None:
    """
    Setup CLI commands for available plugins.

    Args:
        cli: Click CLI object to add plugin CLI options to.
    """
    try:
        from axolotl_diff_transformer.convert_diff_transformer import do_cli
        from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs

        @cli.command()
        @click.argument("config", type=click.Path(exists=True, path_type=str))
        @add_options_from_dataclass(ConvertDiffTransformerCliArgs)
        @add_options_from_config(AxolotlInputConfig)
        def convert_diff_transformer(config: str, **kwargs):
            """Convert model attention layers to differential attention layers."""
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            do_cli(config=config, **kwargs)

    except ImportError as exc:
        LOG.debug("axolotl-diff-transformer not found: %s", exc)
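
With this module in place, the subcommand only appears when the external axolotl-diff-transformer package is importable. A hedged sketch of exercising the registered command through Click's test runner (the CLI group is assumed to live in axolotl.cli.main, and the config path is a placeholder):

from click.testing import CliRunner

from axolotl.cli.main import cli  # assumed location of the top-level click group

runner = CliRunner()
# "path/to/config.yaml" is a placeholder for a real axolotl config file
result = runner.invoke(cli, ["convert-diff-transformer", "path/to/config.yaml"])
print(result.exit_code, result.output)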


@@ -157,6 +157,8 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
         if isinstance(value, bool):
             if value:
                 cmd.append(f"--{key}")
+            else:
+                cmd.append(f"--no{key}")
         else:
             cmd.extend([f"--{key}", str(value)])
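
With this change, boolean options that are explicitly False are forwarded as --no<key> flags instead of being silently dropped. A minimal sketch of the flag-building behavior, assuming the surrounding loop also normalizes option names from snake_case to kebab-case (as the updated test at the bottom of this diff suggests); this is not the full helper:

def build_flags(options: dict) -> list[str]:
    # True booleans -> "--key", False booleans -> "--nokey", other values -> "--key value"
    cmd: list[str] = []
    for key, value in options.items():
        key = key.replace("_", "-")  # assumed normalization, per the test expectations
        if isinstance(value, bool):
            cmd.append(f"--{key}" if value else f"--no{key}")
        else:
            cmd.extend([f"--{key}", str(value)])
    return cmd

# build_flags({"batch_size": 8, "debug": True, "use_fp16": False})
# -> ["--batch-size", "8", "--debug", "--nouse-fp16"]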


@@ -297,7 +297,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
     """
     Training arguments for Causal trainer

-    This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
+    This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
     so it can't be used as a mixin.
     """


@@ -4,7 +4,7 @@ import csv
 import os
 import sys
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Optional

 import torch
 from accelerate.logging import get_logger
@@ -26,7 +26,7 @@ LOG = get_logger("axolotl.evaluate")

 def evaluate_dataset(
     trainer, dataset, dataset_type: str, flash_optimum: bool = False
-) -> Optional[Dict[str, float]]:
+) -> Optional[dict[str, float]]:
     """Helper function to evaluate a single dataset safely.

     Args:
@@ -61,7 +61,7 @@ def evaluate_dataset(
     return metrics

-def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
+def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]:
     """
     Evaluate a model on training and validation datasets


@@ -48,9 +48,9 @@ class BasePlugin:
         Initializes the BasePlugin.
         """

-    def register(self):  # pylint: disable=unused-argument
+    def register(self, cfg):  # pylint: disable=unused-argument
         """
-        Registers the plugin
+        Registers the plugin with the given configuration.

         Parameters:
         cfg (dict): The configuration for the plugin.
@@ -274,7 +274,6 @@ class PluginManager:
         try:
             plugin = load_plugin(plugin_name)
             self.plugins[plugin_name] = plugin
-            plugin.register()
         except ImportError:
             logging.error(f"Failed to load plugin: {plugin_name}")
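
Plugins now receive the loaded config when register is called, and registration is no longer triggered automatically when the PluginManager loads a plugin. A hedged sketch of a plugin adapting to the new signature (the class name and config key below are made up for illustration):

from axolotl.integrations.base import BasePlugin


class ExamplePlugin(BasePlugin):
    """Hypothetical plugin illustrating the updated register(cfg) hook."""

    def register(self, cfg):
        # cfg is the loaded axolotl configuration; "example_option" is a made-up key
        if cfg.get("example_option"):
            print("registering example modeling code")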


@@ -43,10 +43,12 @@ def merge_input_args():
     input_args: List[str] = plugin_manager.get_input_args()
     plugin_classes = []
     dynamic_input = ""

     for plugin_args in input_args:
         plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
         dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
         plugin_classes.append(plugin_cls)

     if dynamic_input:
         dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
         dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
@@ -62,4 +64,5 @@ def merge_input_args():
             "AxolotlConfigWCapabilities"
         ]
         return AxolotlConfigWCapabilities, AxolotlInputConfig

     return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
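
For reference, when a single plugin contributes input args, the dynamic_input string assembled above amounts to source along these lines (module and class names are placeholders, not real packages):

from my_plugin.args import MyPluginArgs

class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, MyPluginArgs):
    pass

class AxolotlInputConfig(AxolotlInputConfigBase, MyPluginArgs):
    pass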


@@ -1,25 +0,0 @@
"""
Axolotl Plugin for Relaxed Recursive Transformers
"""

import logging

from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rrt.modeling import register_rrt_model

LOG = logging.getLogger(__name__)


class RelaxedRecursiveTransformerPlugin(BasePlugin):
    """
    Plugin for Relaxed Recursive Transformers integration with Axolotl
    """

    def get_input_args(self):
        return "axolotl.integrations.rrt.args.RelaxedRecursiveTransformerArgs"

    def register(self):
        LOG.info(
            "Registering Relaxed Recursive Transformers modeling with transformers"
        )
        register_rrt_model()


@@ -1,11 +0,0 @@
"""
Axolotl config args for Relaxed Recursive Transformers plugin
"""

from pydantic import BaseModel


class RelaxedRecursiveTransformerArgs(BaseModel):
    """
    Arguments pertaining to the Relaxed Recursive Transformer model.
    """


@@ -1,370 +0,0 @@
"""
cli script for converting a pretrained model to a relaxed recursive transformer model
"""
import json
import logging
import math
import os
import re
from pathlib import Path
from typing import Tuple
import safetensors
import torch
from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
RelaxedRecursiveLlamaConfig,
)
logger = logging.getLogger(__name__)
def extract_layer_number(key):
"""Extract layer number from parameter key."""
match = re.search(r"layers\.(\d+)\.", key)
return int(match.group(1)) if match else None
def iter_parameter_weights(model_path, device="mps"):
"""
iterator over parameter weights in the model shards
:param model_path: Path to model shards
:param device: Computing device
:return: generator yielding (parameter key, parameter weight, layer index) tuples
"""
shards = list(model_path.glob("model*.safetensors"))
if not shards:
raise ValueError(f"No model shards found in {model_path}")
for shard in tqdm(shards, desc="Processing shards"):
with safetensors.safe_open(shard, framework="pt", device=device) as f:
for key in f.keys():
layer_idx = extract_layer_number(key)
weight = f.get_tensor(key)
yield key, weight, layer_idx
def iter_recursive_parameter_weights(
model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12
):
# setup placeholder state_dict for recursive weights, need to keep in float32 precision
# to avoid precision loss when averaging weights across layers
rrt_avg_model_state_dict: dict[str, list[torch.Tensor]] = {}
# iterate over all parameter weights in the model shards
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
# get the matching module name in modules_to_recurse for the current parameter key
matched_module_name = next(
(module for module in modules_to_recurse if module in key), None
)
if matched_module_name is None:
continue
recurse_idx = layer_idx % recurse_layers
suffix = f"{recurse_idx}.{matched_module_name}"
if rrt_avg_model_state_dict.get(suffix) is None:
# setup as storage for suffix with torch.stack
rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()]
else:
rrt_avg_model_state_dict[suffix].append(
weight.to(torch.float32).detach().cpu()
)
for module_name in modules_to_recurse:
for recurse_idx in range(recurse_layers):
suffix = f"{recurse_idx}.{module_name}"
prefix = f"model.layers.{suffix}"
avg_weight = torch.stack(rrt_avg_model_state_dict[suffix]).mean(dim=0)
yield f"{prefix}.weight_base", avg_weight
# compute the decomposed lora diff from the weight base to the actual weight for each module
def low_rank_decomposition(
weight: torch.Tensor, max_rank: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Decompose a 2D matrix into low-rank matrices L and R using SVD.
:param weight: The matrix to decompose, of shape (H, W)
:param max_rank: The maximum rank of the decomposition
:return: A tuple of tensors (L, R)
"""
# pylint: disable=invalid-name
assert (
weight.dim() == 2
), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
assert (
max_rank >= 1
), f"Maximum rank must be a positive integer, but input max_rank={max_rank}."
dtype = weight.dtype
U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
# Distribute S to both to improve numerical precision
sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
A = sqrt_S @ Vh[:max_rank, :] # shape: [r, cols]
B = U[:, :max_rank] @ sqrt_S # shape: [rows, r]
return A.to(dtype), B.to(dtype)
def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
weight = weight + scaling * lora_weight
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
return weight_norm
def decompose_delta_weight(layer_weight, avg_weight, alpha, rank, use_dora=True):
"""
Decompose the difference in directions (ΔV) via SVD,
and return (magnitudes, L, R).
"""
device = "cuda" if torch.cuda.is_available() else "mps"
# rslora
scaling = alpha / math.sqrt(rank)
base_weight = avg_weight.to(device)
final_weight = layer_weight.to(device)
delta_for_svd = final_weight - base_weight
# Low-rank factorization of the delta direction
lora_A, lora_B = low_rank_decomposition( # pylint: disable=invalid-name
delta_for_svd, rank
)
if use_dora:
lora_weight = lora_B @ lora_A
weight_norm = get_weight_norm(
base_weight.to(lora_A.device), lora_weight, scaling
)
return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
# let's rescale the lora weight to have the same magnitude as the base weight
return lora_A.cpu(), lora_B.cpu(), None
def iter_dora_parameter_weights(
model_path,
avg_recursive_weights,
modules_to_recurse: list[str],
alpha,
rank,
device="mps",
recurse_layers=12,
use_dora=True,
):
# iterate over all parameter weights in the model shards
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
# get the matching module name in modules_to_recurse for the current parameter key
matched_module_name = next(
(module for module in modules_to_recurse if module in key), None
)
if matched_module_name is None:
if "input_layernorm" in key:
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
loop_idx = layer_idx // recurse_layers
layer_idx = layer_idx % recurse_layers
layernorm_key = (
f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
)
yield layernorm_key, weight
elif "post_attention_layernorm" in key:
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
loop_idx = layer_idx // recurse_layers
layer_idx = layer_idx % recurse_layers
layernorm_key = f"model.layers.{layer_idx}.post_attention_layernorm_list.{loop_idx}.weight"
yield layernorm_key, weight
else:
yield key, weight
continue
# figure out the base weight layer for this key
loop_idx = layer_idx // recurse_layers
layer_idx = layer_idx % recurse_layers
suffix = f"{layer_idx}.{matched_module_name}"
prefix = f"model.layers.{suffix}.weight_base"
avg_weight = avg_recursive_weights[prefix]
lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
lora_magnitude_key = (
f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
)
lora_a, lora_b, lora_magnitude = decompose_delta_weight(
weight,
avg_weight,
alpha,
rank,
use_dora=use_dora,
)
yield lora_a_key, lora_a
yield lora_b_key, lora_b
if use_dora:
yield lora_magnitude_key, lora_magnitude
def save_state_dict_to_safetensors(state_dict, save_directory):
os.makedirs(save_directory, exist_ok=True)
weights_name = SAFE_WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
)
# pylint: disable=duplicate-code
# Save index if sharded
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
if (
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in state_dict_split.filename_to_tensors.keys()
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
shard[tensor] = state_dict[tensor].contiguous()
del state_dict[tensor]
save_file(
shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}
)
del state_dict
if index is None:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
else:
save_index_file = SAFE_WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, save_index_file)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
def convert_llama_to_rrt(
model_name,
output_dir,
recurse_layers: int = 12,
rank=32,
alpha=32,
device=None,
use_dora=True,
):
if not device:
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
modules_to_recurse = [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"mlp.down_proj",
"mlp.gate_proj",
"mlp.up_proj",
]
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
num_hidden_layers = config.num_hidden_layers
if num_hidden_layers % recurse_layers != 0:
raise ValueError(
f"The number of hidden layers ({num_hidden_layers}) in the model must be "
f"divisible by the recurse layers ({recurse_layers})"
)
config = RelaxedRecursiveLlamaConfig.from_dict(
{
**config.to_dict(),
"recurse_layers": recurse_layers,
"rank": rank,
"alpha": alpha,
"use_dora": use_dora,
}
)
config.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
# create a new state_dict to store the RRT model weights
rrt_model_state_dict = {}
logger.info("Calculating average recursive weights...")
for key, weight in iter_recursive_parameter_weights(
model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
):
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
logger.info("Calculating decomposed lora diff...")
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
rrt_lora_state_dict = {}
for key, weight in iter_dora_parameter_weights(
model_path,
rrt_model_state_dict,
modules_to_recurse,
alpha=32,
rank=rank,
device=device,
recurse_layers=recurse_layers,
use_dora=use_dora,
):
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
# combine state dicts into a single state_dict
rrt_model_state_dict.update(rrt_lora_state_dict)
# save state dict as sharded safetensors to disk using split_torch_state_dict_into_shards
save_state_dict_to_safetensors(rrt_model_state_dict, output_dir)
if __name__ == "__main__":
# meta-llama/Llama-3.2-1B has 16 hidden layers
# meta-llama/Llama-3.2-3B has 28 hidden layers
convert_llama_to_rrt(
"meta-llama/Llama-3.2-3B",
"/tmp/rrt_model", # nosec
recurse_layers=4,
rank=256,
alpha=512,
use_dora=False,
)


@@ -1,25 +0,0 @@
"""
module for modeling relaxed recursive transformers model
"""

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
from .modeling_rrt_llama import (
    RelaxedRecursiveLlamaForCausalLM,
    RelaxedRecursiveLlamaModel,
)


def register_rrt_model():
    """
    Register Relaxed Recursive Transformers model with transformers
    """
    # Register configs
    AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)

    # Register models
    AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
    AutoModelForCausalLM.register(
        RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
    )


@@ -1,16 +0,0 @@
"""
module for custom configuration for relaxed recursive transformers model
"""

from transformers import LlamaConfig


class RelaxedRecursiveLlamaConfig(LlamaConfig):
    """
    Configuration for Relaxed Recursive Llama.
    """

    model_type: str = "llama-rrt"
    recurse_layers: int = 4
    rank: int
    alpha: int
    use_dora: bool = True


@@ -1,116 +0,0 @@
"""
module for the shared linear layer for the relaxed recursive transformers model
"""
import math
import torch
import torch.nn.functional as F
from peft.utils import transpose
from torch import nn
class RelaxedRecursiveDoraLinear(nn.Module):
"""
A single linear layer that is "shared" across multiple loop iterations,
but each iteration has its own DoRA offsets (A_i, B_i, magnitude_i).
The constructor expects you to specify:
- in_features, out_features
- B: number of loop iterations (i.e., how many times we "unroll")
- fan_in_fan_out: pass True if your underlying base weight is transposed, etc.
The forward(...) expects an additional argument "loop_idx" in [0..B-1],
which picks out the iteration-specific DoRA offsets.
"""
def __init__(
self,
in_features: int,
out_features: int,
B: int, # pylint: disable=invalid-name
rank: int,
alpha: int,
fan_in_fan_out: bool = False,
bias: bool = True,
use_dora: bool = True,
):
super().__init__()
self.B = B # pylint: disable=invalid-name
self.fan_in_fan_out = fan_in_fan_out
self.weight_base = nn.Parameter(torch.empty(out_features, in_features))
self.use_bias = bias
if self.use_bias:
self.bias = nn.Parameter(torch.zeros(out_features))
else:
self.register_parameter("bias", None)
self.lora_A_list = nn.ParameterList( # pylint: disable=invalid-name
[nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)]
)
self.lora_B_list = nn.ParameterList( # pylint: disable=invalid-name
[nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)]
)
# rslora
self.scaling = alpha / math.sqrt(rank)
self.use_dora = use_dora
if use_dora:
self.lora_magnitude_vector_list = nn.ParameterList(
[nn.Parameter(torch.ones(out_features)) for _ in range(B)]
)
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
weight = transpose(weight, self.fan_in_fan_out)
weight = weight + scaling * lora_weight
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
return weight_norm
def forward(self, x, loop_idx: int):
"""
:param x: hidden state of shape (batch_size, seq_len, in_features)
:param loop_idx:
:return:
"""
eps = 1e-6
w_base = self.weight_base
w_base = w_base.to(x.dtype)
lora_A: torch.Tensor = self.lora_A_list[ # pylint: disable=invalid-name
loop_idx
]
lora_B: torch.Tensor = self.lora_B_list[ # pylint: disable=invalid-name
loop_idx
]
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling
if self.use_dora:
x_eye: torch.Tensor = torch.eye(
lora_A.shape[1], device=lora_A.device, dtype=x.dtype
)
tmp = F.linear(x_eye, lora_A) # [hidden_size, rank]
w_dora_full: torch.Tensor = F.linear(tmp, lora_B)
w_dora_full = w_dora_full.t()
magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
w_dora_norm: torch.Tensor = self.get_weight_norm(
w_base, w_dora_full.detach(), self.scaling
)
w_dora_norm = w_dora_norm.detach()
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(
0
) # shape [1, out_features]
result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
return result_dora
# scale the lora norm to prevent gradient explosion
orig_norm = torch.linalg.norm(w_base)
update_norm = torch.linalg.norm(lora_out)
scale = orig_norm / (update_norm + eps)
return base_out + lora_out * scale


@@ -1,471 +0,0 @@
import logging
from typing import Callable, Optional, Tuple, Union, Unpack
import torch
from torch import nn
from transformers import Cache, DynamicCache, LlamaConfig
from transformers.activations import ACT2FN
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
logger = logging.getLogger(__name__)
# pylint: skip-file
# mypy: ignore-errors
class RelaxedRecursiveLlamaMLP(nn.Module):
def __init__(self, config: RelaxedRecursiveLlamaConfig):
super().__init__()
recurse_loops = config.num_hidden_layers // config.recurse_layers
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = RelaxedRecursiveDoraLinear(
self.hidden_size,
self.intermediate_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.up_proj = RelaxedRecursiveDoraLinear(
self.hidden_size,
self.intermediate_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.down_proj = RelaxedRecursiveDoraLinear(
self.intermediate_size,
self.hidden_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x, loop_idx: int):
down_proj = self.down_proj(
self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx),
loop_idx,
)
return down_proj
class RelaxedRecursiveLlamaAttention(nn.Module):
"""
A single attention layer of the Relaxed Recursive Llama.
"""
def __init__(self, config: RelaxedRecursiveLlamaConfig, layer_idx: int):
super().__init__()
recurse_loops = config.num_hidden_layers // config.recurse_layers
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = RelaxedRecursiveDoraLinear(
config.hidden_size,
config.num_attention_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.k_proj = RelaxedRecursiveDoraLinear(
config.hidden_size,
config.num_key_value_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.v_proj = RelaxedRecursiveDoraLinear(
config.hidden_size,
config.num_key_value_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.o_proj = RelaxedRecursiveDoraLinear(
config.num_attention_heads * self.head_dim,
config.hidden_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
loop_idx: int,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = (
self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get(
"output_attentions", False
):
logger.warning(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output, loop_idx)
return attn_output, attn_weights # pylint: disable=return-value
class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
"""
A single layer of the Relaxed Recursive Llama decoder.
"""
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
recurse_loops = config.num_hidden_layers // config.recurse_layers
self.hidden_size = config.hidden_size
self.self_attn = RelaxedRecursiveLlamaAttention(
config=config, layer_idx=layer_idx
)
self.mlp = RelaxedRecursiveLlamaMLP(config)
self.input_layernorm_list = nn.ModuleList(
[
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for _ in range(recurse_loops)
]
)
self.post_attention_layernorm_list = nn.ModuleList(
[
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for _ in range(recurse_loops)
]
)
def forward(
self,
hidden_states: torch.Tensor,
loop_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
hidden_states = self.input_layernorm_list[loop_idx](hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
loop_idx=loop_idx,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm_list[loop_idx](hidden_states)
hidden_states = self.mlp(hidden_states, loop_idx)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class RelaxedRecursiveLlamaModel(LlamaModel):
config_class = RelaxedRecursiveLlamaConfig
def __init__(self, config):
super(LlamaModel, self).__init__(config)
self.recurse_loops = config.num_hidden_layers // config.recurse_layers
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
RelaxedRecursiveLlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.recurse_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for loop_idx in range(self.recurse_loops):
for decoder_layer in self.layers[: self.config.recurse_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
loop_idx,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
loop_idx,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()
class RelaxedRecursiveLlamaForCausalLM(LlamaForCausalLM):
config_class = RelaxedRecursiveLlamaConfig
def __init__(self, config):
super(LlamaForCausalLM, self).__init__(config)
self.model = RelaxedRecursiveLlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_nb_trainable_parameters(self) -> tuple[int, int, int]:
r"""
Returns the number of trainable parameters and the number of all parameters in the model.
"""
trainable_params = 0
all_param = 0
lora_params = 0
for name, param in self.named_parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
# Due to the design of 4bit linear layers from bitsandbytes
# one needs to multiply the number of parameters by 2 to get
# the correct number of parameters
if param.__class__.__name__ == "Params4bit":
if hasattr(param, "element_size"):
num_bytes = param.element_size()
elif not hasattr(param, "quant_storage"):
num_bytes = 1
else:
num_bytes = param.quant_storage.itemsize
num_params = num_params * 2 * num_bytes
all_param += num_params
if param.requires_grad:
trainable_params += num_params
if "lora_" in name:
lora_params += num_params
return trainable_params, all_param, lora_params


@@ -812,6 +812,7 @@ class ModelLoader:
         if self.cfg.is_multimodal:
             self.model_config.text_config = self.text_model_config

         self.model = self.AutoModelLoader.from_pretrained(
             self.base_model,
             config=self.model_config,

src/axolotl/utils/yaml.py (new file, 157 lines)

@@ -0,0 +1,157 @@
"""Utilities for YAML files."""

from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union

import yaml


class YAMLOrderTracker:
    """Tracks the order of keys and section breaks in YAML files."""

    def __init__(self, yaml_path: str):
        self.yaml_path = yaml_path
        self.structure, self.needs_break = self._parse_yaml_structure()

    def _get_indentation_level(self, line: str) -> int:
        """Get the indentation level of a line."""
        return len(line) - len(line.lstrip())

    def _parse_yaml_structure(
        self,
    ) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
        """Parse the YAML file to extract structure and identify section breaks."""
        with open(self.yaml_path, "r", encoding="utf-8") as file:
            contents = file.readlines()

        structure: OrderedDict = OrderedDict()
        needs_break = set()  # Track which keys should have a break before them
        current_path = []
        last_indentation = -1
        had_empty_line = False

        for line in contents:
            # Track empty lines and comments
            if not line.strip() or line.strip().startswith("#"):
                had_empty_line = True
                continue

            # Get indentation level and content
            indentation = self._get_indentation_level(line)
            content = line.strip()

            # Skip lines that don't define keys
            if ":" not in content:
                continue

            # Extract key
            key = content.split(":")[0].strip()

            # If this is a top-level key and we had an empty line, mark it
            if indentation == 0:
                if had_empty_line:
                    needs_break.add(key)
                had_empty_line = False

            # Handle indentation changes
            if indentation > last_indentation:
                current_path.append(key)
            elif indentation < last_indentation:
                levels_up = (last_indentation - indentation) // 2
                current_path = current_path[:-levels_up]
                current_path[-1] = key
            else:
                if current_path:
                    current_path[-1] = key

            # Update structure
            current_dict = structure
            for path_key in current_path[:-1]:
                if path_key not in current_dict:
                    current_dict[path_key] = OrderedDict()
                current_dict = current_dict[path_key]

            if current_path:
                if current_path[-1] not in current_dict:
                    current_dict[current_path[-1]] = OrderedDict()

            last_indentation = indentation

        return structure, needs_break


class OrderedDumper(yaml.SafeDumper):
    """Custom YAML dumper that maintains dictionary order."""

    def represent_none(self, _):
        """Represent None values as empty fields."""
        return self.represent_scalar("tag:yaml.org,2002:null", "")


def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
    """Custom representer for dictionaries that maintains order."""
    return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())


def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
    """Reorder a dictionary based on a reference structure."""
    ordered = OrderedDict()

    # First add keys that are in the reference order
    for key in reference_structure:
        if key in data:
            if isinstance(reference_structure[key], dict) and isinstance(
                data[key], dict
            ):
                ordered[key] = reorder_dict(data[key], reference_structure[key])
            else:
                ordered[key] = data[key]

    # Then add any remaining keys that weren't in the reference
    for key in data:
        if key not in ordered:
            ordered[key] = data[key]

    return ordered


def dump_yaml_preserved_order(
    data: Dict, reference_yaml_path: str, output_path: str
) -> None:
    """Dump YAML file while preserving nested order and normalized spacing."""
    # Get reference structure and spacing
    tracker = YAMLOrderTracker(reference_yaml_path)

    # Reorder the data
    ordered_data = reorder_dict(data, tracker.structure)

    # Register the custom representers
    OrderedDumper.add_representer(type(None), represent_none)
    OrderedDumper.add_representer(dict, ordered_dict_representer)
    OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)

    # First dump to string
    yaml_str = yaml.dump(
        ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
    )

    # Add spacing according to reference
    lines = yaml_str.split("\n")
    result_lines: List[str] = []
    current_line = 0

    while current_line < len(lines):
        line = lines[current_line]
        if line.strip() and ":" in line and not line.startswith(" "):  # Top-level key
            key = line.split(":")[0].strip()
            if key in tracker.needs_break:
                # Add single empty line before this key
                if result_lines and result_lines[-1] != "":
                    result_lines.append("")
        result_lines.append(line)
        current_line += 1

    # Write the final result
    with open(output_path, "w", encoding="utf-8") as file:
        file.write("\n".join(result_lines))
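
A hedged usage sketch of the new dumper: it rewrites a modified config while keeping the key ordering and blank-line grouping of the reference file (file paths below are placeholders):

import yaml

from axolotl.utils.yaml import dump_yaml_preserved_order

# Load an existing config, tweak a value, and write it back with the original
# key order and section spacing preserved. Paths are placeholders.
with open("config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

cfg["learning_rate"] = 2e-5

dump_yaml_preserved_order(
    data=cfg,
    reference_yaml_path="config.yaml",
    output_path="config.modified.yaml",
)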


@@ -43,14 +43,12 @@ class BaseCliTest:
         result = cli_runner.invoke(cli, [command, str(config_path)])

         assert mock.called
-        assert mock.call_args.args[0] == [
+        assert mock.call_args.args[0][:5] == [
             "accelerate",
             "launch",
             "-m",
             f"axolotl.cli.{command}",
             str(config_path),
-            "--debug-num-examples",
-            "0",
         ]
         assert mock.call_args.kwargs == {"check": True}
         assert result.exit_code == 0


@@ -23,6 +23,7 @@ def test_build_command():
         "--batch-size",
         "8",
         "--debug",
+        "--nouse-fp16",
     ]
] ]