Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
1cfb8feb2d add iterable argument to preprocess-cli 2025-01-27 14:31:12 -05:00
15 changed files with 18 additions and 216 deletions

3
.gitignore vendored
View File

@@ -186,6 +186,3 @@ out/
# vim # vim
*.swp *.swp
# symlinked to axolotl-artifacts in docker containers
outputs

View File

@@ -4,6 +4,7 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/

View File

@@ -1,6 +1,6 @@
""" """
modal application to run axolotl gpu tests in Modal modal application to run axolotl gpu tests in Modal
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import os import os

View File

@@ -13,6 +13,12 @@ class PreprocessCliArgs:
debug_num_examples: int = field(default=1) debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None) prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True) download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=None,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
},
)
@dataclass @dataclass

View File

@@ -19,7 +19,7 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]: def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
""" """
Evaluates a `transformers` model by first loading the dataset(s) specified in the Evaluates a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes `axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
@@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
return evaluate(cfg=cfg, dataset_meta=dataset_meta) evaluate(cfg=cfg, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

@@ -8,7 +8,6 @@ import click
import axolotl import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.plugins import setup_plugin_commands
from axolotl.cli.utils import ( from axolotl.cli.utils import (
add_options_from_config, add_options_from_config,
add_options_from_dataclass, add_options_from_dataclass,
@@ -223,9 +222,6 @@ def fetch(directory: str, dest: Optional[str]) -> None:
fetch_from_github(f"{directory}/", dest) fetch_from_github(f"{directory}/", dest)
setup_plugin_commands(cli)
def main(): def main():
cli() cli()

View File

@@ -1,36 +0,0 @@
"""Module for adding click CLI commands from axolotl plugins."""
import logging
import click
from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
from axolotl.logging_config import configure_logging
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
configure_logging()
LOG = logging.getLogger(__name__)
def setup_plugin_commands(cli: click.core.Group) -> None:
"""
Setup CLI commands for available plugins.
Args:
cli: Click CLI object to add plugin CLI options to.
"""
try:
from axolotl_diff_transformer.convert_diff_transformer import do_cli
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
do_cli(config=config, **kwargs)
except ImportError as exc:
LOG.debug("axolotl-diff-transformer not found: %s", exc)

View File

@@ -157,8 +157,6 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
if isinstance(value, bool): if isinstance(value, bool):
if value: if value:
cmd.append(f"--{key}") cmd.append(f"--{key}")
else:
cmd.append(f"--no{key}")
else: else:
cmd.extend([f"--{key}", str(value)]) cmd.extend([f"--{key}", str(value)])

View File

@@ -297,7 +297,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
""" """
Training arguments for Causal trainer Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
so it can't be used as a mixin. so it can't be used as a mixin.
""" """

View File

@@ -4,7 +4,7 @@ import csv
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Dict, Optional
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
@@ -26,7 +26,7 @@ LOG = get_logger("axolotl.evaluate")
def evaluate_dataset( def evaluate_dataset(
trainer, dataset, dataset_type: str, flash_optimum: bool = False trainer, dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[dict[str, float]]: ) -> Optional[Dict[str, float]]:
"""Helper function to evaluate a single dataset safely. """Helper function to evaluate a single dataset safely.
Args: Args:
@@ -61,7 +61,7 @@ def evaluate_dataset(
return metrics return metrics
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]: def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
""" """
Evaluate a model on training and validation datasets Evaluate a model on training and validation datasets

View File

@@ -43,12 +43,10 @@ def merge_input_args():
input_args: List[str] = plugin_manager.get_input_args() input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = [] plugin_classes = []
dynamic_input = "" dynamic_input = ""
for plugin_args in input_args: for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1) plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n" dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls) plugin_classes.append(plugin_cls)
if dynamic_input: if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
@@ -64,5 +62,4 @@ def merge_input_args():
"AxolotlConfigWCapabilities" "AxolotlConfigWCapabilities"
] ]
return AxolotlConfigWCapabilities, AxolotlInputConfig return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase

View File

@@ -812,7 +812,6 @@ class ModelLoader:
if self.cfg.is_multimodal: if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained( self.model = self.AutoModelLoader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,

View File

@@ -1,157 +0,0 @@
"""Utilities for YAML files."""
from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union
import yaml
class YAMLOrderTracker:
"""Tracks the order of keys and section breaks in YAML files."""
def __init__(self, yaml_path: str):
self.yaml_path = yaml_path
self.structure, self.needs_break = self._parse_yaml_structure()
def _get_indentation_level(self, line: str) -> int:
"""Get the indentation level of a line."""
return len(line) - len(line.lstrip())
def _parse_yaml_structure(
self,
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
"""Parse the YAML file to extract structure and identify section breaks."""
with open(self.yaml_path, "r", encoding="utf-8") as file:
contents = file.readlines()
structure: OrderedDict = OrderedDict()
needs_break = set() # Track which keys should have a break before them
current_path = []
last_indentation = -1
had_empty_line = False
for line in contents:
# Track empty lines and comments
if not line.strip() or line.strip().startswith("#"):
had_empty_line = True
continue
# Get indentation level and content
indentation = self._get_indentation_level(line)
content = line.strip()
# Skip lines that don't define keys
if ":" not in content:
continue
# Extract key
key = content.split(":")[0].strip()
# If this is a top-level key and we had an empty line, mark it
if indentation == 0:
if had_empty_line:
needs_break.add(key)
had_empty_line = False
# Handle indentation changes
if indentation > last_indentation:
current_path.append(key)
elif indentation < last_indentation:
levels_up = (last_indentation - indentation) // 2
current_path = current_path[:-levels_up]
current_path[-1] = key
else:
if current_path:
current_path[-1] = key
# Update structure
current_dict = structure
for path_key in current_path[:-1]:
if path_key not in current_dict:
current_dict[path_key] = OrderedDict()
current_dict = current_dict[path_key]
if current_path:
if current_path[-1] not in current_dict:
current_dict[current_path[-1]] = OrderedDict()
last_indentation = indentation
return structure, needs_break
class OrderedDumper(yaml.SafeDumper):
"""Custom YAML dumper that maintains dictionary order."""
def represent_none(self, _):
"""Represent None values as empty fields."""
return self.represent_scalar("tag:yaml.org,2002:null", "")
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
"""Custom representer for dictionaries that maintains order."""
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
"""Reorder a dictionary based on a reference structure."""
ordered = OrderedDict()
# First add keys that are in the reference order
for key in reference_structure:
if key in data:
if isinstance(reference_structure[key], dict) and isinstance(
data[key], dict
):
ordered[key] = reorder_dict(data[key], reference_structure[key])
else:
ordered[key] = data[key]
# Then add any remaining keys that weren't in the reference
for key in data:
if key not in ordered:
ordered[key] = data[key]
return ordered
def dump_yaml_preserved_order(
data: Dict, reference_yaml_path: str, output_path: str
) -> None:
"""Dump YAML file while preserving nested order and normalized spacing."""
# Get reference structure and spacing
tracker = YAMLOrderTracker(reference_yaml_path)
# Reorder the data
ordered_data = reorder_dict(data, tracker.structure)
# Register the custom representers
OrderedDumper.add_representer(type(None), represent_none)
OrderedDumper.add_representer(dict, ordered_dict_representer)
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
# First dump to string
yaml_str = yaml.dump(
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
)
# Add spacing according to reference
lines = yaml_str.split("\n")
result_lines: List[str] = []
current_line = 0
while current_line < len(lines):
line = lines[current_line]
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
key = line.split(":")[0].strip()
if key in tracker.needs_break:
# Add single empty line before this key
if result_lines and result_lines[-1] != "":
result_lines.append("")
result_lines.append(line)
current_line += 1
# Write the final result
with open(output_path, "w", encoding="utf-8") as file:
file.write("\n".join(result_lines))

View File

@@ -43,12 +43,14 @@ class BaseCliTest:
result = cli_runner.invoke(cli, [command, str(config_path)]) result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called assert mock.called
assert mock.call_args.args[0][:5] == [ assert mock.call_args.args[0] == [
"accelerate", "accelerate",
"launch", "launch",
"-m", "-m",
f"axolotl.cli.{command}", f"axolotl.cli.{command}",
str(config_path), str(config_path),
"--debug-num-examples",
"0",
] ]
assert mock.call_args.kwargs == {"check": True} assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0 assert result.exit_code == 0

View File

@@ -23,7 +23,6 @@ def test_build_command():
"--batch-size", "--batch-size",
"8", "8",
"--debug", "--debug",
"--nouse-fp16",
] ]