Compare commits
2 Commits
diff-trans
...
autodoc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d1553e53f | ||
|
|
f866157b74 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -186,6 +186,3 @@ out/
|
|||||||
|
|
||||||
# vim
|
# vim
|
||||||
*.swp
|
*.swp
|
||||||
|
|
||||||
# symlinked to axolotl-artifacts in docker containers
|
|
||||||
outputs
|
|
||||||
|
|||||||
18
_quarto.yml
18
_quarto.yml
@@ -27,7 +27,6 @@ website:
|
|||||||
href: index.qmd
|
href: index.qmd
|
||||||
- section: "How-To Guides"
|
- section: "How-To Guides"
|
||||||
contents:
|
contents:
|
||||||
# TODO Edit folder structure after we have more docs.
|
|
||||||
- docs/debugging.qmd
|
- docs/debugging.qmd
|
||||||
- docs/multipack.qmd
|
- docs/multipack.qmd
|
||||||
- docs/fsdp_qlora.qmd
|
- docs/fsdp_qlora.qmd
|
||||||
@@ -43,11 +42,24 @@ website:
|
|||||||
- section: "Reference"
|
- section: "Reference"
|
||||||
contents:
|
contents:
|
||||||
- docs/config.qmd
|
- docs/config.qmd
|
||||||
- docs/faq.qmd
|
- section: "API Reference"
|
||||||
|
contents: "{{ api_contents }}"
|
||||||
|
- text: "FAQ"
|
||||||
|
href: docs/faq.qmd
|
||||||
|
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
theme: materia
|
theme: materia
|
||||||
css: styles.css
|
css: styles.css
|
||||||
toc: true
|
toc: true
|
||||||
|
|
||||||
|
quartodoc:
|
||||||
|
package: axolotl
|
||||||
|
parser: google
|
||||||
|
dir: api
|
||||||
|
sections:
|
||||||
|
- title: Core API
|
||||||
|
desc: Core functionality of Axolotl
|
||||||
|
|
||||||
|
metadata-files:
|
||||||
|
- api/_sidebar.yml
|
||||||
|
|||||||
17
_sidebar.yml
Normal file
17
_sidebar.yml
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
website:
|
||||||
|
sidebar:
|
||||||
|
- collapse-level: 2
|
||||||
|
contents:
|
||||||
|
- href: introduction.qmd
|
||||||
|
text: Introduction
|
||||||
|
- contents:
|
||||||
|
- reference/index.qmd
|
||||||
|
- contents: []
|
||||||
|
section: axolotl
|
||||||
|
section: Reference
|
||||||
|
- href: basics-summary.qmd
|
||||||
|
text: Basics
|
||||||
|
id: reference
|
||||||
|
search: true
|
||||||
|
style: docked
|
||||||
|
- id: dummy-sidebar
|
||||||
11
api/ConstantLengthDataset.qmd
Normal file
11
api/ConstantLengthDataset.qmd
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# ConstantLengthDataset { #axolotl.ConstantLengthDataset }
|
||||||
|
|
||||||
|
```python
|
||||||
|
ConstantLengthDataset(self, tokenizer, datasets, seq_length=2048)
|
||||||
|
```
|
||||||
|
|
||||||
|
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
||||||
|
Args:
|
||||||
|
tokenizer (Tokenizer): The processor used for processing the data.
|
||||||
|
datasets (dataset.Dataset): Dataset with text files.
|
||||||
|
seq_length (int): Length of token sequences to return.
|
||||||
19
api/TokenizedPromptDataset.qmd
Normal file
19
api/TokenizedPromptDataset.qmd
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# TokenizedPromptDataset { #axolotl.TokenizedPromptDataset }
|
||||||
|
|
||||||
|
```python
|
||||||
|
TokenizedPromptDataset(
|
||||||
|
self,
|
||||||
|
prompt_tokenizer,
|
||||||
|
dataset,
|
||||||
|
process_count=None,
|
||||||
|
keep_in_memory=False,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Dataset that returns tokenized prompts from a stream of text files.
|
||||||
|
Args:
|
||||||
|
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
|
||||||
|
dataset (dataset.Dataset): Dataset with text files.
|
||||||
|
process_count (int): Number of processes to use for tokenizing.
|
||||||
|
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
|
||||||
28
api/choose_config.qmd
Normal file
28
api/choose_config.qmd
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# choose_config { #axolotl.choose_config }
|
||||||
|
|
||||||
|
```python
|
||||||
|
choose_config(path)
|
||||||
|
```
|
||||||
|
|
||||||
|
Helper method for choosing an `axolotl` config YAML file (considering only files
|
||||||
|
ending with `.yml` or `.yaml`). If more than one config file exists in the passed
|
||||||
|
`path`, the user is prompted to choose one.
|
||||||
|
|
||||||
|
## Parameters {.doc-section .doc-section-parameters}
|
||||||
|
|
||||||
|
| Name | Type | Description | Default |
|
||||||
|
|--------|--------|-----------------------------------------------|------------|
|
||||||
|
| path | Path | Directory in which config file(s) are stored. | _required_ |
|
||||||
|
|
||||||
|
## Returns {.doc-section .doc-section-returns}
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
|--------|--------|----------------------------------------------------------------------------------|
|
||||||
|
| | str | Path to either (1) the sole YAML file, or (2) if more than one YAML file exists, |
|
||||||
|
| | str | the user-selected YAML file. |
|
||||||
|
|
||||||
|
## Raises {.doc-section .doc-section-raises}
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
|--------|------------|-------------------------------------------------|
|
||||||
|
| | ValueError | If no YAML files are found in the given `path`. |
|
||||||
5
api/index.qmd
Normal file
5
api/index.qmd
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Function reference {.doc .doc-index}
|
||||||
|
|
||||||
|
## Core API
|
||||||
|
|
||||||
|
Core functionality of Axolotl
|
||||||
21
api/load_cfg.qmd
Normal file
21
api/load_cfg.qmd
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# load_cfg { #axolotl.load_cfg }
|
||||||
|
|
||||||
|
```python
|
||||||
|
load_cfg(config=Path('examples/'), **kwargs)
|
||||||
|
```
|
||||||
|
|
||||||
|
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
||||||
|
various setup.
|
||||||
|
|
||||||
|
## Parameters {.doc-section .doc-section-parameters}
|
||||||
|
|
||||||
|
| Name | Type | Description | Default |
|
||||||
|
|--------|--------------------|--------------------------------------------------------------|---------------------|
|
||||||
|
| config | Union\[str, Path\] | Path (local or remote) to `axolotl` config YAML file. | `Path('examples/')` |
|
||||||
|
| kwargs | | Additional keyword arguments to override config file values. | `{}` |
|
||||||
|
|
||||||
|
## Returns {.doc-section .doc-section-returns}
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
|--------|-------------|-----------------------------------------------------|
|
||||||
|
| | DictDefault | `DictDefault` mapping configuration keys to values. |
|
||||||
5
api/validate_config.qmd
Normal file
5
api/validate_config.qmd
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# validate_config { #axolotl.validate_config }
|
||||||
|
|
||||||
|
```python
|
||||||
|
validate_config(cfg, capabilities=None, env_capabilities=None)
|
||||||
|
```
|
||||||
@@ -4,6 +4,7 @@ set -e
|
|||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||||
|
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
|
||||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||||
|
|||||||
1
objects.json
Normal file
1
objects.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"project": "axolotl", "version": "0.0.9999", "count": 0, "items": []}
|
||||||
3
reference/index.qmd
Normal file
3
reference/index.qmd
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# API Reference {.doc .doc-index}
|
||||||
|
|
||||||
|
## Core API
|
||||||
@@ -2,3 +2,5 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
|
quartodoc
|
||||||
|
quarto-cli
|
||||||
|
|||||||
@@ -2,6 +2,20 @@
|
|||||||
|
|
||||||
import pkgutil
|
import pkgutil
|
||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
from .cli.config import choose_config, load_cfg, validate_config
|
||||||
|
from .datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||||
|
from .evaluate import evaluate
|
||||||
|
from .train import train
|
||||||
|
|
||||||
|
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||||
__version__ = "0.6.0"
|
__version__ = "0.6.0"
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"train",
|
||||||
|
"evaluate",
|
||||||
|
"TokenizedPromptDataset",
|
||||||
|
"ConstantLengthDataset",
|
||||||
|
"load_cfg",
|
||||||
|
"choose_config",
|
||||||
|
"validate_config",
|
||||||
|
]
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
|
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||||
"""
|
"""
|
||||||
Evaluates a `transformers` model by first loading the dataset(s) specified in the
|
Evaluates a `transformers` model by first loading the dataset(s) specified in the
|
||||||
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
|
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
|
||||||
@@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
|
|||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
return evaluate(cfg=cfg, dataset_meta=dataset_meta)
|
evaluate(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import click
|
|||||||
|
|
||||||
import axolotl
|
import axolotl
|
||||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||||
from axolotl.cli.plugins import setup_plugin_commands
|
|
||||||
from axolotl.cli.utils import (
|
from axolotl.cli.utils import (
|
||||||
add_options_from_config,
|
add_options_from_config,
|
||||||
add_options_from_dataclass,
|
add_options_from_dataclass,
|
||||||
@@ -223,9 +222,6 @@ def fetch(directory: str, dest: Optional[str]) -> None:
|
|||||||
fetch_from_github(f"{directory}/", dest)
|
fetch_from_github(f"{directory}/", dest)
|
||||||
|
|
||||||
|
|
||||||
setup_plugin_commands(cli)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
cli()
|
cli()
|
||||||
|
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
"""Module for adding click CLI commands from axolotl plugins."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import click
|
|
||||||
|
|
||||||
from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_plugin_commands(cli: click.core.Group) -> None:
|
|
||||||
"""
|
|
||||||
Setup CLI commands for available plugins.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cli: Click CLI object to add plugin CLI options to.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from axolotl_diff_transformer.convert_diff_transformer import do_cli
|
|
||||||
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
|
|
||||||
|
|
||||||
@cli.command()
|
|
||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
|
||||||
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
|
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
|
||||||
def convert_diff_transformer(config: str, **kwargs):
|
|
||||||
"""Convert model attention layers to differential attention layers."""
|
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
|
||||||
do_cli(config=config, **kwargs)
|
|
||||||
|
|
||||||
except ImportError as exc:
|
|
||||||
LOG.debug("axolotl-diff-transformer not found: %s", exc)
|
|
||||||
@@ -157,8 +157,6 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
|||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
if value:
|
if value:
|
||||||
cmd.append(f"--{key}")
|
cmd.append(f"--{key}")
|
||||||
else:
|
|
||||||
cmd.append(f"--no{key}")
|
|
||||||
else:
|
else:
|
||||||
cmd.extend([f"--{key}", str(value)])
|
cmd.extend([f"--{key}", str(value)])
|
||||||
|
|
||||||
|
|||||||
@@ -297,7 +297,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
|||||||
"""
|
"""
|
||||||
Training arguments for Causal trainer
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
|
||||||
so it can't be used as a mixin.
|
so it can't be used as a mixin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import csv
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
@@ -26,7 +26,7 @@ LOG = get_logger("axolotl.evaluate")
|
|||||||
|
|
||||||
def evaluate_dataset(
|
def evaluate_dataset(
|
||||||
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
||||||
) -> Optional[dict[str, float]]:
|
) -> Optional[Dict[str, float]]:
|
||||||
"""Helper function to evaluate a single dataset safely.
|
"""Helper function to evaluate a single dataset safely.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -61,7 +61,7 @@ def evaluate_dataset(
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]:
|
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate a model on training and validation datasets
|
Evaluate a model on training and validation datasets
|
||||||
|
|
||||||
|
|||||||
@@ -43,12 +43,10 @@ def merge_input_args():
|
|||||||
input_args: List[str] = plugin_manager.get_input_args()
|
input_args: List[str] = plugin_manager.get_input_args()
|
||||||
plugin_classes = []
|
plugin_classes = []
|
||||||
dynamic_input = ""
|
dynamic_input = ""
|
||||||
|
|
||||||
for plugin_args in input_args:
|
for plugin_args in input_args:
|
||||||
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
|
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
|
||||||
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
|
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
|
||||||
plugin_classes.append(plugin_cls)
|
plugin_classes.append(plugin_cls)
|
||||||
|
|
||||||
if dynamic_input:
|
if dynamic_input:
|
||||||
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
|
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
|
||||||
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
|
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
|
||||||
@@ -64,5 +62,4 @@ def merge_input_args():
|
|||||||
"AxolotlConfigWCapabilities"
|
"AxolotlConfigWCapabilities"
|
||||||
]
|
]
|
||||||
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
||||||
|
|
||||||
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
||||||
|
|||||||
@@ -812,7 +812,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
|
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.AutoModelLoader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
|
|||||||
@@ -1,157 +0,0 @@
|
|||||||
"""Utilities for YAML files."""
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Any, Dict, List, Set, Tuple, Union
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
|
|
||||||
class YAMLOrderTracker:
|
|
||||||
"""Tracks the order of keys and section breaks in YAML files."""
|
|
||||||
|
|
||||||
def __init__(self, yaml_path: str):
|
|
||||||
self.yaml_path = yaml_path
|
|
||||||
self.structure, self.needs_break = self._parse_yaml_structure()
|
|
||||||
|
|
||||||
def _get_indentation_level(self, line: str) -> int:
|
|
||||||
"""Get the indentation level of a line."""
|
|
||||||
return len(line) - len(line.lstrip())
|
|
||||||
|
|
||||||
def _parse_yaml_structure(
|
|
||||||
self,
|
|
||||||
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
|
|
||||||
"""Parse the YAML file to extract structure and identify section breaks."""
|
|
||||||
with open(self.yaml_path, "r", encoding="utf-8") as file:
|
|
||||||
contents = file.readlines()
|
|
||||||
|
|
||||||
structure: OrderedDict = OrderedDict()
|
|
||||||
needs_break = set() # Track which keys should have a break before them
|
|
||||||
current_path = []
|
|
||||||
last_indentation = -1
|
|
||||||
had_empty_line = False
|
|
||||||
|
|
||||||
for line in contents:
|
|
||||||
# Track empty lines and comments
|
|
||||||
if not line.strip() or line.strip().startswith("#"):
|
|
||||||
had_empty_line = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get indentation level and content
|
|
||||||
indentation = self._get_indentation_level(line)
|
|
||||||
content = line.strip()
|
|
||||||
|
|
||||||
# Skip lines that don't define keys
|
|
||||||
if ":" not in content:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Extract key
|
|
||||||
key = content.split(":")[0].strip()
|
|
||||||
|
|
||||||
# If this is a top-level key and we had an empty line, mark it
|
|
||||||
if indentation == 0:
|
|
||||||
if had_empty_line:
|
|
||||||
needs_break.add(key)
|
|
||||||
had_empty_line = False
|
|
||||||
|
|
||||||
# Handle indentation changes
|
|
||||||
if indentation > last_indentation:
|
|
||||||
current_path.append(key)
|
|
||||||
elif indentation < last_indentation:
|
|
||||||
levels_up = (last_indentation - indentation) // 2
|
|
||||||
current_path = current_path[:-levels_up]
|
|
||||||
current_path[-1] = key
|
|
||||||
else:
|
|
||||||
if current_path:
|
|
||||||
current_path[-1] = key
|
|
||||||
|
|
||||||
# Update structure
|
|
||||||
current_dict = structure
|
|
||||||
for path_key in current_path[:-1]:
|
|
||||||
if path_key not in current_dict:
|
|
||||||
current_dict[path_key] = OrderedDict()
|
|
||||||
current_dict = current_dict[path_key]
|
|
||||||
|
|
||||||
if current_path:
|
|
||||||
if current_path[-1] not in current_dict:
|
|
||||||
current_dict[current_path[-1]] = OrderedDict()
|
|
||||||
|
|
||||||
last_indentation = indentation
|
|
||||||
|
|
||||||
return structure, needs_break
|
|
||||||
|
|
||||||
|
|
||||||
class OrderedDumper(yaml.SafeDumper):
|
|
||||||
"""Custom YAML dumper that maintains dictionary order."""
|
|
||||||
|
|
||||||
|
|
||||||
def represent_none(self, _):
|
|
||||||
"""Represent None values as empty fields."""
|
|
||||||
return self.represent_scalar("tag:yaml.org,2002:null", "")
|
|
||||||
|
|
||||||
|
|
||||||
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
|
|
||||||
"""Custom representer for dictionaries that maintains order."""
|
|
||||||
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
|
|
||||||
|
|
||||||
|
|
||||||
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
|
|
||||||
"""Reorder a dictionary based on a reference structure."""
|
|
||||||
ordered = OrderedDict()
|
|
||||||
|
|
||||||
# First add keys that are in the reference order
|
|
||||||
for key in reference_structure:
|
|
||||||
if key in data:
|
|
||||||
if isinstance(reference_structure[key], dict) and isinstance(
|
|
||||||
data[key], dict
|
|
||||||
):
|
|
||||||
ordered[key] = reorder_dict(data[key], reference_structure[key])
|
|
||||||
else:
|
|
||||||
ordered[key] = data[key]
|
|
||||||
|
|
||||||
# Then add any remaining keys that weren't in the reference
|
|
||||||
for key in data:
|
|
||||||
if key not in ordered:
|
|
||||||
ordered[key] = data[key]
|
|
||||||
|
|
||||||
return ordered
|
|
||||||
|
|
||||||
|
|
||||||
def dump_yaml_preserved_order(
|
|
||||||
data: Dict, reference_yaml_path: str, output_path: str
|
|
||||||
) -> None:
|
|
||||||
"""Dump YAML file while preserving nested order and normalized spacing."""
|
|
||||||
# Get reference structure and spacing
|
|
||||||
tracker = YAMLOrderTracker(reference_yaml_path)
|
|
||||||
|
|
||||||
# Reorder the data
|
|
||||||
ordered_data = reorder_dict(data, tracker.structure)
|
|
||||||
|
|
||||||
# Register the custom representers
|
|
||||||
OrderedDumper.add_representer(type(None), represent_none)
|
|
||||||
OrderedDumper.add_representer(dict, ordered_dict_representer)
|
|
||||||
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
|
|
||||||
|
|
||||||
# First dump to string
|
|
||||||
yaml_str = yaml.dump(
|
|
||||||
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add spacing according to reference
|
|
||||||
lines = yaml_str.split("\n")
|
|
||||||
result_lines: List[str] = []
|
|
||||||
current_line = 0
|
|
||||||
|
|
||||||
while current_line < len(lines):
|
|
||||||
line = lines[current_line]
|
|
||||||
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
|
|
||||||
key = line.split(":")[0].strip()
|
|
||||||
if key in tracker.needs_break:
|
|
||||||
# Add single empty line before this key
|
|
||||||
if result_lines and result_lines[-1] != "":
|
|
||||||
result_lines.append("")
|
|
||||||
result_lines.append(line)
|
|
||||||
current_line += 1
|
|
||||||
|
|
||||||
# Write the final result
|
|
||||||
with open(output_path, "w", encoding="utf-8") as file:
|
|
||||||
file.write("\n".join(result_lines))
|
|
||||||
@@ -43,12 +43,14 @@ class BaseCliTest:
|
|||||||
result = cli_runner.invoke(cli, [command, str(config_path)])
|
result = cli_runner.invoke(cli, [command, str(config_path)])
|
||||||
|
|
||||||
assert mock.called
|
assert mock.called
|
||||||
assert mock.call_args.args[0][:5] == [
|
assert mock.call_args.args[0] == [
|
||||||
"accelerate",
|
"accelerate",
|
||||||
"launch",
|
"launch",
|
||||||
"-m",
|
"-m",
|
||||||
f"axolotl.cli.{command}",
|
f"axolotl.cli.{command}",
|
||||||
str(config_path),
|
str(config_path),
|
||||||
|
"--debug-num-examples",
|
||||||
|
"0",
|
||||||
]
|
]
|
||||||
assert mock.call_args.kwargs == {"check": True}
|
assert mock.call_args.kwargs == {"check": True}
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ def test_build_command():
|
|||||||
"--batch-size",
|
"--batch-size",
|
||||||
"8",
|
"8",
|
||||||
"--debug",
|
"--debug",
|
||||||
"--nouse-fp16",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user