Compare commits
44 Commits
no-zero-ds
...
diff-trans
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2daa94080c | ||
|
|
0e9bfa6dee | ||
|
|
ef38f10274 | ||
|
|
66262c3092 | ||
|
|
016ba124e4 | ||
|
|
7145d52d99 | ||
|
|
28694219a5 | ||
|
|
fd8ad6fcbf | ||
|
|
661d71a14b | ||
|
|
6dd47edcb8 | ||
|
|
7aca08ff60 | ||
|
|
4f804f6d88 | ||
|
|
443327c585 | ||
|
|
70c4e6fbe6 | ||
|
|
2a7f139ad2 | ||
|
|
332ce0ae85 | ||
|
|
e5fa842ff8 | ||
|
|
78e0ec0aa5 | ||
|
|
3bc568eb27 | ||
|
|
eb6611d55f | ||
|
|
4ff3328e66 | ||
|
|
a3fd5074a9 | ||
|
|
5b90da0be3 | ||
|
|
fcbfa86373 | ||
|
|
0d56582090 | ||
|
|
390cb5742e | ||
|
|
1d935f65c3 | ||
|
|
66176b3e07 | ||
|
|
505321ac95 | ||
|
|
0b382c88da | ||
|
|
ea07a7086e | ||
|
|
d22e1136bc | ||
|
|
63b8e42c6b | ||
|
|
bda1eed59e | ||
|
|
41ebd93158 | ||
|
|
4c050ce807 | ||
|
|
6665acf63d | ||
|
|
2f9fa4c465 | ||
|
|
849bc94112 | ||
|
|
e484ec778d | ||
|
|
df1504ae14 | ||
|
|
7be0d7496c | ||
|
|
13cdffa91f | ||
|
|
7a4b296f60 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -186,3 +186,6 @@ out/
|
|||||||
|
|
||||||
# vim
|
# vim
|
||||||
*.swp
|
*.swp
|
||||||
|
|
||||||
|
# symlinked to axolotl-artifacts in docker containers
|
||||||
|
outputs
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ set -e
|
|||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||||
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
|
||||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
modal application to run axolotl gpu tests in Modal
|
modal application to run axolotl gpu tests in Modal
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluates a `transformers` model by first loading the dataset(s) specified in the
|
Evaluates a `transformers` model by first loading the dataset(s) specified in the
|
||||||
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
|
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
|
||||||
@@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
evaluate(cfg=cfg, dataset_meta=dataset_meta)
|
return evaluate(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import click
|
|||||||
|
|
||||||
import axolotl
|
import axolotl
|
||||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||||
|
from axolotl.cli.plugins import setup_plugin_commands
|
||||||
from axolotl.cli.utils import (
|
from axolotl.cli.utils import (
|
||||||
add_options_from_config,
|
add_options_from_config,
|
||||||
add_options_from_dataclass,
|
add_options_from_dataclass,
|
||||||
@@ -222,6 +223,9 @@ def fetch(directory: str, dest: Optional[str]) -> None:
|
|||||||
fetch_from_github(f"{directory}/", dest)
|
fetch_from_github(f"{directory}/", dest)
|
||||||
|
|
||||||
|
|
||||||
|
setup_plugin_commands(cli)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
cli()
|
cli()
|
||||||
|
|
||||||
|
|||||||
36
src/axolotl/cli/plugins.py
Normal file
36
src/axolotl/cli/plugins.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""Module for adding click CLI commands from axolotl plugins."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
|
|
||||||
|
configure_logging()
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_plugin_commands(cli: click.core.Group) -> None:
|
||||||
|
"""
|
||||||
|
Setup CLI commands for available plugins.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cli: Click CLI object to add plugin CLI options to.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from axolotl_diff_transformer.convert_diff_transformer import do_cli
|
||||||
|
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
|
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
|
||||||
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
|
def convert_diff_transformer(config: str, **kwargs):
|
||||||
|
"""Convert model attention layers to differential attention layers."""
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
do_cli(config=config, **kwargs)
|
||||||
|
|
||||||
|
except ImportError as exc:
|
||||||
|
LOG.debug("axolotl-diff-transformer not found: %s", exc)
|
||||||
@@ -157,6 +157,8 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
|||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
if value:
|
if value:
|
||||||
cmd.append(f"--{key}")
|
cmd.append(f"--{key}")
|
||||||
|
else:
|
||||||
|
cmd.append(f"--no{key}")
|
||||||
else:
|
else:
|
||||||
cmd.extend([f"--{key}", str(value)])
|
cmd.extend([f"--{key}", str(value)])
|
||||||
|
|
||||||
|
|||||||
@@ -297,7 +297,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
|||||||
"""
|
"""
|
||||||
Training arguments for Causal trainer
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
|
||||||
so it can't be used as a mixin.
|
so it can't be used as a mixin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import csv
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
@@ -26,7 +26,7 @@ LOG = get_logger("axolotl.evaluate")
|
|||||||
|
|
||||||
def evaluate_dataset(
|
def evaluate_dataset(
|
||||||
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
||||||
) -> Optional[Dict[str, float]]:
|
) -> Optional[dict[str, float]]:
|
||||||
"""Helper function to evaluate a single dataset safely.
|
"""Helper function to evaluate a single dataset safely.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -61,7 +61,7 @@ def evaluate_dataset(
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate a model on training and validation datasets
|
Evaluate a model on training and validation datasets
|
||||||
|
|
||||||
|
|||||||
@@ -43,10 +43,12 @@ def merge_input_args():
|
|||||||
input_args: List[str] = plugin_manager.get_input_args()
|
input_args: List[str] = plugin_manager.get_input_args()
|
||||||
plugin_classes = []
|
plugin_classes = []
|
||||||
dynamic_input = ""
|
dynamic_input = ""
|
||||||
|
|
||||||
for plugin_args in input_args:
|
for plugin_args in input_args:
|
||||||
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
|
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
|
||||||
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
|
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
|
||||||
plugin_classes.append(plugin_cls)
|
plugin_classes.append(plugin_cls)
|
||||||
|
|
||||||
if dynamic_input:
|
if dynamic_input:
|
||||||
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
|
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
|
||||||
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
|
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
|
||||||
@@ -62,4 +64,5 @@ def merge_input_args():
|
|||||||
"AxolotlConfigWCapabilities"
|
"AxolotlConfigWCapabilities"
|
||||||
]
|
]
|
||||||
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
||||||
|
|
||||||
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
||||||
|
|||||||
@@ -812,6 +812,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
|
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.AutoModelLoader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
|
|||||||
157
src/axolotl/utils/yaml.py
Normal file
157
src/axolotl/utils/yaml.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""Utilities for YAML files."""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Dict, List, Set, Tuple, Union
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
class YAMLOrderTracker:
|
||||||
|
"""Tracks the order of keys and section breaks in YAML files."""
|
||||||
|
|
||||||
|
def __init__(self, yaml_path: str):
|
||||||
|
self.yaml_path = yaml_path
|
||||||
|
self.structure, self.needs_break = self._parse_yaml_structure()
|
||||||
|
|
||||||
|
def _get_indentation_level(self, line: str) -> int:
|
||||||
|
"""Get the indentation level of a line."""
|
||||||
|
return len(line) - len(line.lstrip())
|
||||||
|
|
||||||
|
def _parse_yaml_structure(
|
||||||
|
self,
|
||||||
|
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
|
||||||
|
"""Parse the YAML file to extract structure and identify section breaks."""
|
||||||
|
with open(self.yaml_path, "r", encoding="utf-8") as file:
|
||||||
|
contents = file.readlines()
|
||||||
|
|
||||||
|
structure: OrderedDict = OrderedDict()
|
||||||
|
needs_break = set() # Track which keys should have a break before them
|
||||||
|
current_path = []
|
||||||
|
last_indentation = -1
|
||||||
|
had_empty_line = False
|
||||||
|
|
||||||
|
for line in contents:
|
||||||
|
# Track empty lines and comments
|
||||||
|
if not line.strip() or line.strip().startswith("#"):
|
||||||
|
had_empty_line = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get indentation level and content
|
||||||
|
indentation = self._get_indentation_level(line)
|
||||||
|
content = line.strip()
|
||||||
|
|
||||||
|
# Skip lines that don't define keys
|
||||||
|
if ":" not in content:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract key
|
||||||
|
key = content.split(":")[0].strip()
|
||||||
|
|
||||||
|
# If this is a top-level key and we had an empty line, mark it
|
||||||
|
if indentation == 0:
|
||||||
|
if had_empty_line:
|
||||||
|
needs_break.add(key)
|
||||||
|
had_empty_line = False
|
||||||
|
|
||||||
|
# Handle indentation changes
|
||||||
|
if indentation > last_indentation:
|
||||||
|
current_path.append(key)
|
||||||
|
elif indentation < last_indentation:
|
||||||
|
levels_up = (last_indentation - indentation) // 2
|
||||||
|
current_path = current_path[:-levels_up]
|
||||||
|
current_path[-1] = key
|
||||||
|
else:
|
||||||
|
if current_path:
|
||||||
|
current_path[-1] = key
|
||||||
|
|
||||||
|
# Update structure
|
||||||
|
current_dict = structure
|
||||||
|
for path_key in current_path[:-1]:
|
||||||
|
if path_key not in current_dict:
|
||||||
|
current_dict[path_key] = OrderedDict()
|
||||||
|
current_dict = current_dict[path_key]
|
||||||
|
|
||||||
|
if current_path:
|
||||||
|
if current_path[-1] not in current_dict:
|
||||||
|
current_dict[current_path[-1]] = OrderedDict()
|
||||||
|
|
||||||
|
last_indentation = indentation
|
||||||
|
|
||||||
|
return structure, needs_break
|
||||||
|
|
||||||
|
|
||||||
|
class OrderedDumper(yaml.SafeDumper):
|
||||||
|
"""Custom YAML dumper that maintains dictionary order."""
|
||||||
|
|
||||||
|
|
||||||
|
def represent_none(self, _):
|
||||||
|
"""Represent None values as empty fields."""
|
||||||
|
return self.represent_scalar("tag:yaml.org,2002:null", "")
|
||||||
|
|
||||||
|
|
||||||
|
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
|
||||||
|
"""Custom representer for dictionaries that maintains order."""
|
||||||
|
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
|
||||||
|
|
||||||
|
|
||||||
|
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
|
||||||
|
"""Reorder a dictionary based on a reference structure."""
|
||||||
|
ordered = OrderedDict()
|
||||||
|
|
||||||
|
# First add keys that are in the reference order
|
||||||
|
for key in reference_structure:
|
||||||
|
if key in data:
|
||||||
|
if isinstance(reference_structure[key], dict) and isinstance(
|
||||||
|
data[key], dict
|
||||||
|
):
|
||||||
|
ordered[key] = reorder_dict(data[key], reference_structure[key])
|
||||||
|
else:
|
||||||
|
ordered[key] = data[key]
|
||||||
|
|
||||||
|
# Then add any remaining keys that weren't in the reference
|
||||||
|
for key in data:
|
||||||
|
if key not in ordered:
|
||||||
|
ordered[key] = data[key]
|
||||||
|
|
||||||
|
return ordered
|
||||||
|
|
||||||
|
|
||||||
|
def dump_yaml_preserved_order(
|
||||||
|
data: Dict, reference_yaml_path: str, output_path: str
|
||||||
|
) -> None:
|
||||||
|
"""Dump YAML file while preserving nested order and normalized spacing."""
|
||||||
|
# Get reference structure and spacing
|
||||||
|
tracker = YAMLOrderTracker(reference_yaml_path)
|
||||||
|
|
||||||
|
# Reorder the data
|
||||||
|
ordered_data = reorder_dict(data, tracker.structure)
|
||||||
|
|
||||||
|
# Register the custom representers
|
||||||
|
OrderedDumper.add_representer(type(None), represent_none)
|
||||||
|
OrderedDumper.add_representer(dict, ordered_dict_representer)
|
||||||
|
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
|
||||||
|
|
||||||
|
# First dump to string
|
||||||
|
yaml_str = yaml.dump(
|
||||||
|
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add spacing according to reference
|
||||||
|
lines = yaml_str.split("\n")
|
||||||
|
result_lines: List[str] = []
|
||||||
|
current_line = 0
|
||||||
|
|
||||||
|
while current_line < len(lines):
|
||||||
|
line = lines[current_line]
|
||||||
|
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
|
||||||
|
key = line.split(":")[0].strip()
|
||||||
|
if key in tracker.needs_break:
|
||||||
|
# Add single empty line before this key
|
||||||
|
if result_lines and result_lines[-1] != "":
|
||||||
|
result_lines.append("")
|
||||||
|
result_lines.append(line)
|
||||||
|
current_line += 1
|
||||||
|
|
||||||
|
# Write the final result
|
||||||
|
with open(output_path, "w", encoding="utf-8") as file:
|
||||||
|
file.write("\n".join(result_lines))
|
||||||
@@ -43,14 +43,12 @@ class BaseCliTest:
|
|||||||
result = cli_runner.invoke(cli, [command, str(config_path)])
|
result = cli_runner.invoke(cli, [command, str(config_path)])
|
||||||
|
|
||||||
assert mock.called
|
assert mock.called
|
||||||
assert mock.call_args.args[0] == [
|
assert mock.call_args.args[0][:5] == [
|
||||||
"accelerate",
|
"accelerate",
|
||||||
"launch",
|
"launch",
|
||||||
"-m",
|
"-m",
|
||||||
f"axolotl.cli.{command}",
|
f"axolotl.cli.{command}",
|
||||||
str(config_path),
|
str(config_path),
|
||||||
"--debug-num-examples",
|
|
||||||
"0",
|
|
||||||
]
|
]
|
||||||
assert mock.call_args.kwargs == {"check": True}
|
assert mock.call_args.kwargs == {"check": True}
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ def test_build_command():
|
|||||||
"--batch-size",
|
"--batch-size",
|
||||||
"8",
|
"8",
|
||||||
"--debug",
|
"--debug",
|
||||||
|
"--nouse-fp16",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user