CLI cleanup and documentation (#2244)

* CLI init refactor

* fix

* cleanup and (partial) docs

* Adding documentation and continuing cleanup (in progress)

* remove finetune.py script

* continued cleanup and documentation

* pytest fixes

* review comments

* fix

* Fix

* typing fixes

* make sure the batch dataset patcher for multipack is always loaded when handling datasets

* review comments

* fix

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Dan Saunders
2025-01-13 12:55:29 -05:00
committed by GitHub
parent f89e962119
commit 1ed4de73b6
60 changed files with 1269 additions and 1259 deletions

View File

@@ -1,4 +1,5 @@
"""Shared pytest fixtures for cli module."""
import pytest
from click.testing import CliRunner

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch
from axolotl.cli.main import fetch

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI inference command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -15,46 +16,3 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path):
"""Test merge_sharded_fsdp_weights command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command with save_path option"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--save-path",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_path"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI preprocess command."""
import shutil
from pathlib import Path
from unittest.mock import patch

View File

@@ -1,76 +0,0 @@
"""pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
def test_shard_with_accelerate(cli_runner, config_path):
"""Test shard command with accelerate"""
with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.shard",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def test_shard_no_accelerate(cli_runner, config_path):
"""Test shard command without accelerate"""
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"])
assert mock.called
assert result.exit_code == 0
def test_shard_with_model_dir(cli_runner, config_path, tmp_path):
"""Test shard command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
catch_exceptions=False,
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_shard_with_save_dir(cli_runner, config_path):
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--save-dir",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_dir"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name
import json
from unittest.mock import Mock, patch

View File

@@ -4,8 +4,8 @@ Simple end-to-end test for Cut Cross Entropy integration
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
@@ -64,9 +64,9 @@ class TestCutCrossEntropyIntegration:
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize(
@@ -92,7 +92,7 @@ class TestCutCrossEntropyIntegration:
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,8 +4,8 @@ Simple end-to-end test for Liger integration
from e2e.utils import require_torch_2_4_1
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
@@ -60,7 +60,7 @@ class LigerIntegrationTestCase:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@require_torch_2_4_1
@@ -105,5 +105,5 @@ class LigerIntegrationTestCase:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -65,7 +65,7 @@ class Test4dMultipackLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -109,5 +109,5 @@ class Test4dMultipackLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,7 @@ from pathlib import Path
import yaml
from axolotl.cli import load_cfg
from axolotl.cli.config import load_cfg
from axolotl.utils.dict import DictDefault

View File

@@ -8,8 +8,8 @@ import os
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -80,7 +80,7 @@ class TestFAXentropyLlama:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -67,7 +67,7 @@ class TestFalconPatched(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -107,5 +107,5 @@ class TestFalconPatched(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -9,8 +9,8 @@ import unittest
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -71,5 +71,5 @@ class TestFusedLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -8,8 +8,8 @@ import unittest
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -69,7 +69,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -109,5 +109,5 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -9,8 +9,8 @@ import unittest
import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -74,7 +74,7 @@ class TestLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
@@ -124,5 +124,5 @@ class TestLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -108,5 +108,5 @@ class TestMistral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -64,7 +64,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -102,7 +102,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__

View File

@@ -6,7 +6,6 @@ import unittest
import transformers
from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -49,9 +48,8 @@ class TestModelPatches(unittest.TestCase):
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
model, _ = load_model(cfg, tokenizer, inference=False)
assert (
"MixtralFlashAttention2"
@@ -87,9 +85,8 @@ class TestModelPatches(unittest.TestCase):
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=cli_args.inference)
load_model(cfg, tokenizer, inference=False)
assert (
"torch.jit"

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -67,7 +67,7 @@ class TestPhiMultipack(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -118,5 +118,5 @@ class TestPhiMultipack(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -9,8 +9,8 @@ import subprocess
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -71,7 +71,7 @@ class TestResumeLlama:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
resume_cfg = cfg | DictDefault(
{
@@ -81,7 +81,7 @@ class TestResumeLlama:
normalize_config(resume_cfg)
cli_args = TrainerCliArgs()
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=resume_cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")

View File

@@ -6,8 +6,8 @@ import os
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -75,7 +75,7 @@ class TestUnslothQLoRA:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
@@ -125,7 +125,7 @@ class TestUnslothQLoRA:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
@@ -180,7 +180,7 @@ class TestUnslothQLoRA:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(

View File

@@ -9,8 +9,8 @@ from pathlib import Path
import pytest
from axolotl.cli import load_rl_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_preference_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -65,9 +65,9 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir
@@ -110,9 +110,9 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir
@@ -155,9 +155,9 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@pytest.mark.skip("kto_pair no longer supported in trl")
@@ -200,9 +200,9 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir
@@ -244,9 +244,9 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir
@@ -291,9 +291,9 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@pytest.mark.skip(reason="Fix the implementation")
@@ -355,7 +355,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -60,7 +60,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
@@ -104,7 +104,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -69,7 +69,7 @@ class TestFalcon(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -122,7 +122,7 @@ class TestFalcon(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -161,5 +161,5 @@ class TestFalcon(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -7,8 +7,8 @@ import os
from e2e.utils import check_model_output_exists
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -60,7 +60,7 @@ class TestLlama:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir):
@@ -103,7 +103,7 @@ class TestLlama:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_batch_flattening(self, temp_dir):
@@ -142,5 +142,5 @@ class TestLlama:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -62,5 +62,5 @@ class TestPretrainLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -66,7 +66,7 @@ class TestLlamaVision(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -111,5 +111,5 @@ class TestLlamaVision(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -63,5 +63,5 @@ class TestLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -8,8 +8,8 @@ import unittest
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -63,5 +63,5 @@ class TestMamba(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -8,8 +8,8 @@ import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -110,5 +110,5 @@ class TestMistral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -9,8 +9,8 @@ import unittest
import torch
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -73,7 +73,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -127,7 +127,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -184,7 +184,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -285,5 +285,5 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -63,7 +63,7 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -107,7 +107,7 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -143,5 +143,5 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -8,8 +8,8 @@ import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -63,7 +63,7 @@ class TestPackedLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -65,7 +65,7 @@ class TestPhi(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@@ -114,5 +114,5 @@ class TestPhi(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -7,8 +7,8 @@ import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -77,7 +77,7 @@ class TestReLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
assert (
Path(temp_dir) / "checkpoint-100/relora/model.safetensors"

View File

@@ -6,8 +6,8 @@ import logging
import os
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -69,5 +69,5 @@ class TestRewardModelLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)