Compare commits

..

2 Commits

Author SHA1 Message Date
Dan Saunders
4d1553e53f updates 2025-01-27 15:43:51 -05:00
Dan Saunders
f866157b74 initial quartodoc changes 2025-01-27 18:57:45 +00:00
26 changed files with 177 additions and 243 deletions

3
.gitignore vendored
View File

@@ -186,6 +186,3 @@ out/
# vim # vim
*.swp *.swp
# symlinked to axolotl-artifacts in docker containers
outputs

View File

@@ -27,7 +27,6 @@ website:
href: index.qmd href: index.qmd
- section: "How-To Guides" - section: "How-To Guides"
contents: contents:
# TODO Edit folder structure after we have more docs.
- docs/debugging.qmd - docs/debugging.qmd
- docs/multipack.qmd - docs/multipack.qmd
- docs/fsdp_qlora.qmd - docs/fsdp_qlora.qmd
@@ -43,11 +42,24 @@ website:
- section: "Reference" - section: "Reference"
contents: contents:
- docs/config.qmd - docs/config.qmd
- docs/faq.qmd - section: "API Reference"
contents: "{{ api_contents }}"
- text: "FAQ"
href: docs/faq.qmd
format: format:
html: html:
theme: materia theme: materia
css: styles.css css: styles.css
toc: true toc: true
quartodoc:
package: axolotl
parser: google
dir: api
sections:
- title: Core API
desc: Core functionality of Axolotl
metadata-files:
- api/_sidebar.yml

17
_sidebar.yml Normal file
View File

@@ -0,0 +1,17 @@
website:
sidebar:
- collapse-level: 2
contents:
- href: introduction.qmd
text: Introduction
- contents:
- reference/index.qmd
- contents: []
section: axolotl
section: Reference
- href: basics-summary.qmd
text: Basics
id: reference
search: true
style: docked
- id: dummy-sidebar

View File

@@ -0,0 +1,11 @@
# ConstantLengthDataset { #axolotl.ConstantLengthDataset }
```python
ConstantLengthDataset(self, tokenizer, datasets, seq_length=2048)
```
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.

View File

@@ -0,0 +1,19 @@
# TokenizedPromptDataset { #axolotl.TokenizedPromptDataset }
```python
TokenizedPromptDataset(
self,
prompt_tokenizer,
dataset,
process_count=None,
keep_in_memory=False,
**kwargs,
)
```
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
process_count (int): Number of processes to use for tokenizing.
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.

28
api/choose_config.qmd Normal file
View File

@@ -0,0 +1,28 @@
# choose_config { #axolotl.choose_config }
```python
choose_config(path)
```
Helper method for choosing a `axolotl` config YAML file (considering only files
ending with `.yml` or `.yaml`). If more than one config file exists in the passed
`path`, the user is prompted to choose one.
## Parameters {.doc-section .doc-section-parameters}
| Name | Type | Description | Default |
|--------|--------|-----------------------------------------------|------------|
| path | Path | Directory in which config file(s) are stored. | _required_ |
## Returns {.doc-section .doc-section-returns}
| Name | Type | Description |
|--------|--------|----------------------------------------------------------------------------------|
| | str | Path to either (1) the sole YAML file, or (2) if more than one YAML files exist, |
| | str | the user-selected YAML file. |
## Raises {.doc-section .doc-section-raises}
| Name | Type | Description |
|--------|------------|-------------------------------------------------|
| | ValueError | If no YAML files are found in the given `path`. |

5
api/index.qmd Normal file
View File

@@ -0,0 +1,5 @@
# Function reference {.doc .doc-index}
## Core API
Core functionality of Axolotl

21
api/load_cfg.qmd Normal file
View File

@@ -0,0 +1,21 @@
# load_cfg { #axolotl.load_cfg }
```python
load_cfg(config=Path('examples/'), **kwargs)
```
Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup.
## Parameters {.doc-section .doc-section-parameters}
| Name | Type | Description | Default |
|--------|--------------------|--------------------------------------------------------------|---------------------|
| config | Union\[str, Path\] | Path (local or remote) to `axolotl` config YAML file. | `Path('examples/')` |
| kwargs | | Additional keyword arguments to override config file values. | `{}` |
## Returns {.doc-section .doc-section-returns}
| Name | Type | Description |
|--------|-------------|-----------------------------------------------------|
| | DictDefault | `DictDefault` mapping configuration keys to values. |

5
api/validate_config.qmd Normal file
View File

@@ -0,0 +1,5 @@
# validate_config { #axolotl.validate_config }
```python
validate_config(cfg, capabilities=None, env_capabilities=None)
```

View File

@@ -4,6 +4,7 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/

1
objects.json Normal file
View File

@@ -0,0 +1 @@
{"project": "axolotl", "version": "0.0.9999", "count": 0, "items": []}

3
reference/index.qmd Normal file
View File

@@ -0,0 +1,3 @@
# API Reference {.doc .doc-index}
## Core API

View File

@@ -2,3 +2,5 @@ pre-commit
black black
mypy mypy
types-requests types-requests
quartodoc
quarto-cli

View File

@@ -2,6 +2,20 @@
import pkgutil import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package from .cli.config import choose_config, load_cfg, validate_config
from .datasets import ConstantLengthDataset, TokenizedPromptDataset
from .evaluate import evaluate
from .train import train
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.6.0" __version__ = "0.6.0"
__all__ = [
"train",
"evaluate",
"TokenizedPromptDataset",
"ConstantLengthDataset",
"load_cfg",
"choose_config",
"validate_config",
]

View File

@@ -19,7 +19,7 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]: def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
""" """
Evaluates a `transformers` model by first loading the dataset(s) specified in the Evaluates a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes `axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
@@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
return evaluate(cfg=cfg, dataset_meta=dataset_meta) evaluate(cfg=cfg, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

@@ -8,7 +8,6 @@ import click
import axolotl import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.plugins import setup_plugin_commands
from axolotl.cli.utils import ( from axolotl.cli.utils import (
add_options_from_config, add_options_from_config,
add_options_from_dataclass, add_options_from_dataclass,
@@ -223,9 +222,6 @@ def fetch(directory: str, dest: Optional[str]) -> None:
fetch_from_github(f"{directory}/", dest) fetch_from_github(f"{directory}/", dest)
setup_plugin_commands(cli)
def main(): def main():
cli() cli()

View File

@@ -1,36 +0,0 @@
"""Module for adding click CLI commands from axolotl plugins."""
import logging
import click
from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
from axolotl.logging_config import configure_logging
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
configure_logging()
LOG = logging.getLogger(__name__)
def setup_plugin_commands(cli: click.core.Group) -> None:
"""
Setup CLI commands for available plugins.
Args:
cli: Click CLI object to add plugin CLI options to.
"""
try:
from axolotl_diff_transformer.convert_diff_transformer import do_cli
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
do_cli(config=config, **kwargs)
except ImportError as exc:
LOG.debug("axolotl-diff-transformer not found: %s", exc)

View File

@@ -157,8 +157,6 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
if isinstance(value, bool): if isinstance(value, bool):
if value: if value:
cmd.append(f"--{key}") cmd.append(f"--{key}")
else:
cmd.append(f"--no{key}")
else: else:
cmd.extend([f"--{key}", str(value)]) cmd.extend([f"--{key}", str(value)])

View File

@@ -297,7 +297,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
""" """
Training arguments for Causal trainer Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
so it can't be used as a mixin. so it can't be used as a mixin.
""" """

View File

@@ -4,7 +4,7 @@ import csv
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Dict, Optional
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
@@ -26,7 +26,7 @@ LOG = get_logger("axolotl.evaluate")
def evaluate_dataset( def evaluate_dataset(
trainer, dataset, dataset_type: str, flash_optimum: bool = False trainer, dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[dict[str, float]]: ) -> Optional[Dict[str, float]]:
"""Helper function to evaluate a single dataset safely. """Helper function to evaluate a single dataset safely.
Args: Args:
@@ -61,7 +61,7 @@ def evaluate_dataset(
return metrics return metrics
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]: def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
""" """
Evaluate a model on training and validation datasets Evaluate a model on training and validation datasets

View File

@@ -43,12 +43,10 @@ def merge_input_args():
input_args: List[str] = plugin_manager.get_input_args() input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = [] plugin_classes = []
dynamic_input = "" dynamic_input = ""
for plugin_args in input_args: for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1) plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n" dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls) plugin_classes.append(plugin_cls)
if dynamic_input: if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
@@ -64,5 +62,4 @@ def merge_input_args():
"AxolotlConfigWCapabilities" "AxolotlConfigWCapabilities"
] ]
return AxolotlConfigWCapabilities, AxolotlInputConfig return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase

View File

@@ -812,7 +812,6 @@ class ModelLoader:
if self.cfg.is_multimodal: if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained( self.model = self.AutoModelLoader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,

View File

@@ -1,157 +0,0 @@
"""Utilities for YAML files."""
from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union
import yaml
class YAMLOrderTracker:
"""Tracks the order of keys and section breaks in YAML files."""
def __init__(self, yaml_path: str):
self.yaml_path = yaml_path
self.structure, self.needs_break = self._parse_yaml_structure()
def _get_indentation_level(self, line: str) -> int:
"""Get the indentation level of a line."""
return len(line) - len(line.lstrip())
def _parse_yaml_structure(
self,
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
"""Parse the YAML file to extract structure and identify section breaks."""
with open(self.yaml_path, "r", encoding="utf-8") as file:
contents = file.readlines()
structure: OrderedDict = OrderedDict()
needs_break = set() # Track which keys should have a break before them
current_path = []
last_indentation = -1
had_empty_line = False
for line in contents:
# Track empty lines and comments
if not line.strip() or line.strip().startswith("#"):
had_empty_line = True
continue
# Get indentation level and content
indentation = self._get_indentation_level(line)
content = line.strip()
# Skip lines that don't define keys
if ":" not in content:
continue
# Extract key
key = content.split(":")[0].strip()
# If this is a top-level key and we had an empty line, mark it
if indentation == 0:
if had_empty_line:
needs_break.add(key)
had_empty_line = False
# Handle indentation changes
if indentation > last_indentation:
current_path.append(key)
elif indentation < last_indentation:
levels_up = (last_indentation - indentation) // 2
current_path = current_path[:-levels_up]
current_path[-1] = key
else:
if current_path:
current_path[-1] = key
# Update structure
current_dict = structure
for path_key in current_path[:-1]:
if path_key not in current_dict:
current_dict[path_key] = OrderedDict()
current_dict = current_dict[path_key]
if current_path:
if current_path[-1] not in current_dict:
current_dict[current_path[-1]] = OrderedDict()
last_indentation = indentation
return structure, needs_break
class OrderedDumper(yaml.SafeDumper):
"""Custom YAML dumper that maintains dictionary order."""
def represent_none(self, _):
"""Represent None values as empty fields."""
return self.represent_scalar("tag:yaml.org,2002:null", "")
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
"""Custom representer for dictionaries that maintains order."""
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
"""Reorder a dictionary based on a reference structure."""
ordered = OrderedDict()
# First add keys that are in the reference order
for key in reference_structure:
if key in data:
if isinstance(reference_structure[key], dict) and isinstance(
data[key], dict
):
ordered[key] = reorder_dict(data[key], reference_structure[key])
else:
ordered[key] = data[key]
# Then add any remaining keys that weren't in the reference
for key in data:
if key not in ordered:
ordered[key] = data[key]
return ordered
def dump_yaml_preserved_order(
data: Dict, reference_yaml_path: str, output_path: str
) -> None:
"""Dump YAML file while preserving nested order and normalized spacing."""
# Get reference structure and spacing
tracker = YAMLOrderTracker(reference_yaml_path)
# Reorder the data
ordered_data = reorder_dict(data, tracker.structure)
# Register the custom representers
OrderedDumper.add_representer(type(None), represent_none)
OrderedDumper.add_representer(dict, ordered_dict_representer)
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
# First dump to string
yaml_str = yaml.dump(
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
)
# Add spacing according to reference
lines = yaml_str.split("\n")
result_lines: List[str] = []
current_line = 0
while current_line < len(lines):
line = lines[current_line]
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
key = line.split(":")[0].strip()
if key in tracker.needs_break:
# Add single empty line before this key
if result_lines and result_lines[-1] != "":
result_lines.append("")
result_lines.append(line)
current_line += 1
# Write the final result
with open(output_path, "w", encoding="utf-8") as file:
file.write("\n".join(result_lines))

View File

@@ -43,12 +43,14 @@ class BaseCliTest:
result = cli_runner.invoke(cli, [command, str(config_path)]) result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called assert mock.called
assert mock.call_args.args[0][:5] == [ assert mock.call_args.args[0] == [
"accelerate", "accelerate",
"launch", "launch",
"-m", "-m",
f"axolotl.cli.{command}", f"axolotl.cli.{command}",
str(config_path), str(config_path),
"--debug-num-examples",
"0",
] ]
assert mock.call_args.kwargs == {"check": True} assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0 assert result.exit_code == 0

View File

@@ -23,7 +23,6 @@ def test_build_command():
"--batch-size", "--batch-size",
"8", "8",
"--debug", "--debug",
"--nouse-fp16",
] ]