Data loader refactor (#2707)
* data loading refactor (wip)
* updates
* progress
* pytest
* pytest fix
* lint
* zero_first -> filelock, more simplifications
* small simplification
* import change
* nit
* lint
* simplify dedup
* couldn't resist
* review comments WIP
* continued wip
* minor changes
* fix; remove contrived test
* further refactor
* set default seed in pydantic config
* lint
* continued simplification
* lint
* renaming and nits
* filelock tests
* fix
* fix
* lint
* remove nullable arg
* remove unnecessary code
* moving dataset save fn to shared module
* remove debug print
* matching var naming
* fn name change
* coderabbit comments
* naming nit
* fix test
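The change these commits converge on: dataset loading no longer threads CLI arguments through, so `load_datasets` drops its `cli_args` parameter. The before/after below is lifted from the pattern repeated across the test diffs that follow (the `cfg` setup is elided):

# before
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

# after
dataset_meta = load_datasets(cfg=cfg)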
@@ -12,7 +12,7 @@ from axolotl.common.datasets import load_datasets
 from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.config import normalize_config
-from axolotl.utils.data.rl import load_prepare_preference_datasets
+from axolotl.utils.data import prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.schemas.enums import RLType

@@ -451,15 +451,19 @@ def rand_reward_func(prompts, completions) -> list[float]:
     # Only use mock for the commented out configs
     if dataset_name is not None:
         with patch(
-            "axolotl.utils.data.rl.load_dataset_w_config"
+            "axolotl.utils.data.rl.load_dataset_with_config"
         ) as mock_load_dataset:
             mock_load_dataset.return_value = request.getfixturevalue(
                 dataset_name
             )
-            train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
+            train_dataset, eval_dataset = prepare_preference_datasets(
+                cfg, tokenizer
+            )
     else:
         # Load actual datasets for orpo_cfg and kto_cfg
-        train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
+        train_dataset, eval_dataset = prepare_preference_datasets(
+            cfg, tokenizer
+        )

     builder.train_dataset = train_dataset
     builder.eval_dataset = eval_dataset
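Two renames recur in this hunk and below: the RL data entry point `load_prepare_preference_datasets(cfg)` becomes `prepare_preference_datasets(cfg, tokenizer)` with the tokenizer passed explicitly, and the mock target `load_dataset_w_config` becomes `load_dataset_with_config`. A sketch of the new call site, as the updated tests use it:

tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)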
@@ -4,7 +4,6 @@ Simple end-to-end test for Cut Cross Entropy integration

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils import get_pytorch_version

@@ -59,8 +58,7 @@ class TestCutCrossEntropyIntegration:
         cfg = validate_config(cfg)
         prepare_plugins(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         major, minor, _ = get_pytorch_version()
         if (major, minor) < (2, 4):

@@ -105,8 +103,7 @@ class TestCutCrossEntropyIntegration:
         cfg = validate_config(cfg)
         prepare_plugins(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         major, minor, _ = get_pytorch_version()
         if (major, minor) < (2, 4):

@@ -134,8 +131,7 @@ class TestCutCrossEntropyIntegration:
         cfg = validate_config(cfg)
         prepare_plugins(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         major, minor, _ = get_pytorch_version()
         if (major, minor) < (2, 4):

@@ -5,7 +5,6 @@ e2e tests to make sure all the hooks are fired on the plugin
 import os
 from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.integrations.base import BasePlugin
 from axolotl.train import train

@@ -160,8 +159,7 @@ class TestPluginHooks:
         cfg = validate_config(cfg)
         prepare_plugins(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -6,7 +6,6 @@ from pathlib import Path

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, prepare_plugins, validate_config

@@ -84,8 +83,7 @@ class TestKnowledgeDistillation:
         cfg = validate_config(cfg)
         prepare_plugins(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "model.safetensors").exists()

@@ -115,8 +113,7 @@ class TestKnowledgeDistillation:
         cfg = validate_config(cfg)
         prepare_plugins(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.safetensors").exists()

@@ -2,7 +2,6 @@
 Simple end-to-end test for Liger integration
 """

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, prepare_plugins, validate_config

@@ -57,8 +56,7 @@ class LigerIntegrationTestCase:
         cfg = validate_config(cfg)
         prepare_plugins(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -104,8 +102,7 @@ class LigerIntegrationTestCase:
         cfg = validate_config(cfg)
         prepare_plugins(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -6,7 +6,6 @@ from pathlib import Path

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, prepare_plugins, validate_config

@@ -88,8 +87,7 @@ class TestLLMCompressorIntegration:
         prepare_plugins(cfg)
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         try:
             train(cfg=cfg, dataset_meta=dataset_meta)

@@ -105,7 +105,7 @@ def start_vllm(
             print(f"{i}: VLLM server failed to start: {str(exc)}")

         # also check if the process.pid is still running
-        if not process.poll() is None:
+        if process.poll() is not None:
             break

         time.sleep(period_seconds)
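The `start_vllm` change is a readability fix with identical semantics: `subprocess.Popen.poll()` returns `None` while the child is still running and the exit code once it has terminated, so `process.poll() is not None` means the server process died and the wait loop should stop. A minimal illustration of the standard-library behavior being relied on:

import subprocess

proc = subprocess.Popen(["sleep", "1"])
assert proc.poll() is None   # still running: poll() gives None
proc.wait()
assert proc.poll() == 0      # exited: poll() gives the return code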
tests/e2e/multigpu/test_locking.py (new file, 192 lines)
@@ -0,0 +1,192 @@
"""Tests for FileLockLoader class."""
|
||||
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.data.lock import FileLockLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestFileLockLoader:
|
||||
"""Class with tests for FileLockLoader."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir(self):
|
||||
"""Create a temporary directory for testing."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
yield Path(tmp_dir)
|
||||
|
||||
@pytest.fixture
|
||||
def cfg(self, temp_dir):
|
||||
"""Create a test configuration."""
|
||||
return DictDefault({"dataset_prepared_path": str(temp_dir)})
|
||||
|
||||
@pytest.fixture
|
||||
def loader(self, cfg):
|
||||
"""Create a FileLockLoader instance for testing."""
|
||||
return FileLockLoader(cfg)
|
||||
|
||||
def test_load_first_process(self, loader):
|
||||
"""Test load() when no ready flag exists (first process)."""
|
||||
mock_load_fn = Mock(return_value="test_data")
|
||||
|
||||
result = loader.load(mock_load_fn)
|
||||
|
||||
# Should call the load function
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "test_data"
|
||||
|
||||
# Should create the ready flag
|
||||
assert loader.ready_flag_path.exists()
|
||||
|
||||
def test_load_subsequent_process(self, loader):
|
||||
"""Test load() when ready flag already exists (subsequent process)."""
|
||||
# Create ready flag first
|
||||
loader.ready_flag_path.touch()
|
||||
|
||||
mock_load_fn = Mock(return_value="loaded_data")
|
||||
|
||||
result = loader.load(mock_load_fn)
|
||||
|
||||
# Should still call load function (to load the prepared data)
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "loaded_data"
|
||||
|
||||
def test_load_concurrent_processes(self, cfg):
|
||||
"""Test that concurrent processes coordinate correctly."""
|
||||
results = []
|
||||
call_count = 0
|
||||
|
||||
def slow_load_fn():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
time.sleep(0.1) # Simulate slow loading
|
||||
return f"data_{call_count}"
|
||||
|
||||
def worker():
|
||||
loader = FileLockLoader(cfg)
|
||||
result = loader.load(slow_load_fn)
|
||||
results.append(result)
|
||||
|
||||
# Start multiple threads simultaneously
|
||||
threads = [threading.Thread(target=worker) for _ in range(3)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Only one thread should have done the initial loading
|
||||
# All should return data, but the load function should be called
|
||||
# once by the first process and once by each subsequent process
|
||||
assert len(results) == 3
|
||||
assert all(result.startswith("data_") for result in results)
|
||||
|
||||
@patch("time.sleep")
|
||||
def test_load_waiting_for_ready_flag(self, mock_sleep, loader):
|
||||
"""Test that processes wait for the ready flag to appear."""
|
||||
mock_load_fn = Mock(return_value="waiting_data")
|
||||
mock_ready_flag_path = Mock()
|
||||
exists_call_count = 0
|
||||
|
||||
def mock_exists():
|
||||
nonlocal exists_call_count
|
||||
exists_call_count += 1
|
||||
|
||||
if exists_call_count == 1:
|
||||
# First check: ready flag exists (not first process)
|
||||
return True
|
||||
if exists_call_count <= 3:
|
||||
# While loop checks: flag doesn't exist yet
|
||||
return False
|
||||
return True
|
||||
|
||||
mock_ready_flag_path.exists.side_effect = mock_exists
|
||||
|
||||
# Replace the ready_flag_path with our mock
|
||||
original_path = loader.ready_flag_path
|
||||
loader.ready_flag_path = mock_ready_flag_path
|
||||
|
||||
try:
|
||||
result = loader.load(mock_load_fn)
|
||||
finally:
|
||||
# Restore original path
|
||||
loader.ready_flag_path = original_path
|
||||
|
||||
# Should have slept twice while waiting
|
||||
assert mock_sleep.call_count == 2
|
||||
mock_sleep.assert_called_with(1)
|
||||
|
||||
# Should eventually call load function
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "waiting_data"
|
||||
|
||||
def test_complete_workflow_with_cleanup(self, loader):
|
||||
"""Test the complete load -> cleanup workflow."""
|
||||
mock_load_fn = Mock(return_value="test_data")
|
||||
|
||||
# First process calls load (this should set up counter)
|
||||
result = loader.load(mock_load_fn)
|
||||
assert result == "test_data"
|
||||
assert loader.ready_flag_path.exists()
|
||||
assert loader.counter_path.exists()
|
||||
|
||||
# Cleanup should remove everything since there's only one process
|
||||
loader.cleanup()
|
||||
assert not loader.ready_flag_path.exists()
|
||||
assert not loader.counter_path.exists()
|
||||
|
||||
def test_multiple_processes_workflow(self, loader):
|
||||
"""Test workflow with multiple processes."""
|
||||
# Simulate multiple processes by manually setting up counter
|
||||
loader.ready_flag_path.touch()
|
||||
loader.counter_path.write_text("3") # 3 processes
|
||||
|
||||
# First process cleanup
|
||||
loader.cleanup()
|
||||
assert loader.ready_flag_path.exists()
|
||||
assert loader.counter_path.read_text().strip() == "2"
|
||||
|
||||
# Second process cleanup
|
||||
loader.cleanup()
|
||||
assert loader.ready_flag_path.exists()
|
||||
assert loader.counter_path.read_text().strip() == "1"
|
||||
|
||||
# Last process cleanup
|
||||
loader.cleanup()
|
||||
assert not loader.ready_flag_path.exists()
|
||||
assert not loader.counter_path.exists()
|
||||
|
||||
def test_load_exception_handling(self, loader):
|
||||
"""Test behavior when load_fn raises an exception."""
|
||||
|
||||
def failing_load_fn():
|
||||
raise ValueError("Load failed")
|
||||
|
||||
with pytest.raises(ValueError, match="Load failed"):
|
||||
loader.load(failing_load_fn)
|
||||
|
||||
# Ready flag should not be created on failure
|
||||
assert not loader.ready_flag_path.exists()
|
||||
|
||||
def test_file_lock_called(self, loader):
|
||||
"""Test that FileLock is properly used."""
|
||||
mock_load_fn = Mock(return_value="locked_data")
|
||||
|
||||
with patch("axolotl.utils.data.lock.FileLock") as mock_filelock:
|
||||
mock_context = MagicMock()
|
||||
mock_filelock.return_value.__enter__ = Mock(return_value=mock_context)
|
||||
mock_filelock.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
loader.load(mock_load_fn)
|
||||
|
||||
# Verify FileLock was called with correct path
|
||||
mock_filelock.assert_called_once_with(str(loader.lock_file_path))
|
||||
|
||||
# Verify context manager was used
|
||||
mock_filelock.return_value.__enter__.assert_called_once()
|
||||
mock_filelock.return_value.__exit__.assert_called_once()
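Taken together, these tests pin down the coordination contract that replaces the old `zero_first` approach: the first process to acquire the file lock runs the expensive load and touches a ready flag; later processes wait for the flag and then load the already-prepared data; a reference counter makes the last `cleanup()` remove the flag and counter files. Below is a minimal sketch of a loader that satisfies these tests. It is illustrative only: the real implementation lives in `axolotl.utils.data.lock`, and the file names, wait interval, and counter handling here are assumptions.

import time
from pathlib import Path

from filelock import FileLock


class FileLockLoaderSketch:
    """Illustrative only; mirrors the behavior the tests above assert."""

    def __init__(self, cfg):
        base = Path(cfg.dataset_prepared_path)  # file names below are guesses
        self.lock_file_path = base / "datasets.lock"
        self.ready_flag_path = base / "datasets.ready"
        self.counter_path = base / "datasets.count"

    def load(self, load_fn):
        with FileLock(str(self.lock_file_path)):
            first_process = not self.ready_flag_path.exists()
            self._bump_counter(+1)
            if first_process:
                data = load_fn()              # prepare the datasets exactly once
                self.ready_flag_path.touch()  # only after load_fn succeeds
                return data
        # Subsequent processes: wait until the first one finishes, then call
        # load_fn themselves to read the already-prepared datasets.
        while not self.ready_flag_path.exists():
            time.sleep(1)
        return load_fn()

    def cleanup(self):
        with FileLock(str(self.lock_file_path)):
            remaining = self._bump_counter(-1)
            if remaining <= 0:  # last process removes the coordination files
                self.ready_flag_path.unlink(missing_ok=True)
                self.counter_path.unlink(missing_ok=True)

    def _bump_counter(self, delta):
        count = int(self.counter_path.read_text()) if self.counter_path.exists() else 0
        count += delta
        self.counter_path.write_text(str(count))
        return count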
@@ -4,7 +4,6 @@ E2E tests for multipack fft llama using 4d attention masks

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -60,8 +59,7 @@ class Test4dMultipackLlama(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -108,8 +106,7 @@ class Test4dMultipackLlama(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -6,7 +6,6 @@ import pytest
 import transformers
 from torch.utils.checkpoint import checkpoint

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -75,8 +74,7 @@ class TestActivationCheckpointing:

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -5,7 +5,6 @@ E2E tests for lora llama
 import pytest
 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -73,8 +72,7 @@ class TestFAXentropyLlama:
         cfg = validate_config(cfg)
         normalize_config(cfg)

-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
@@ -6,7 +6,6 @@ import unittest

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -63,8 +62,7 @@ class TestFalconPatched(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -105,8 +103,7 @@ class TestFalconPatched(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -7,7 +7,6 @@ import unittest
 import pytest
 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -62,8 +61,7 @@ class TestFusedLlama(unittest.TestCase):
         cfg.fp16 = True
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -6,7 +6,6 @@ import unittest

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -64,8 +63,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -107,8 +105,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -7,7 +7,6 @@ import unittest
 import pytest
 from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -65,8 +64,7 @@ class TestLoraLlama(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -114,8 +112,7 @@ class TestLoraLlama(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -60,8 +59,7 @@ class TestMistral(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -102,8 +100,7 @@ class TestMistral(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -4,7 +4,6 @@ E2E tests for mixtral

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -57,8 +56,7 @@ class TestMixtral(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -96,8 +94,7 @@ class TestMixtral(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -4,7 +4,6 @@ E2E tests for lora llama

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -60,8 +59,7 @@ class TestPhiMultipack(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -112,8 +110,7 @@ class TestPhiMultipack(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -7,7 +7,6 @@ import subprocess

 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -67,8 +66,7 @@ class TestResumeLlama:
         cfg.fp16 = True
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)

@@ -78,7 +76,6 @@ class TestResumeLlama:
             }
         )
         normalize_config(resume_cfg)
-        cli_args = TrainerCliArgs()

         train(cfg=resume_cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
@@ -4,7 +4,6 @@ e2e tests for unsloth qlora

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -68,8 +67,7 @@ class TestUnslothQLoRA:

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -119,8 +117,7 @@ class TestUnslothQLoRA:

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -175,8 +172,7 @@ class TestUnslothQLoRA:

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -6,7 +6,6 @@ import unittest

 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -59,8 +58,7 @@ class TestPackedFlex(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)

@@ -5,7 +5,6 @@ E2E tests for relora llama
 import unittest
 from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -71,8 +70,7 @@ class TestReLoraLlama(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)

@@ -6,7 +6,6 @@ from pathlib import Path

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -72,8 +71,7 @@ class TestDeepseekV3:
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.safetensors").exists()

@@ -122,8 +120,7 @@ class TestDeepseekV3:
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "model.safetensors").exists()
@@ -1,6 +1,4 @@
-"""
-E2E tests for lora llama
-"""
+"""E2E tests for lora llama"""

 import unittest
 from pathlib import Path

@@ -4,7 +4,6 @@ E2E tests for llama pretrain

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -54,8 +53,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -99,8 +97,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -6,7 +6,6 @@ import unittest

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -66,8 +65,7 @@ class TestFalcon(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -122,8 +120,7 @@ class TestFalcon(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -164,8 +161,7 @@ class TestFalcon(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
@@ -6,7 +6,6 @@ from pathlib import Path

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -69,8 +68,7 @@ class TestGemma2:
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.safetensors").exists()

@@ -121,8 +119,7 @@ class TestGemma2:
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "model.safetensors").exists()

@@ -6,7 +6,6 @@ from pathlib import Path

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -68,8 +67,7 @@ class TestGemma3Text:
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.safetensors").exists()

@@ -119,8 +117,7 @@ class TestGemma3Text:
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "model.safetensors").exists()

@@ -2,7 +2,6 @@
 E2E tests for llama
 """

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -51,8 +50,7 @@ class TestLlama:

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -99,8 +97,7 @@ class TestLlama:

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -144,8 +141,7 @@ class TestLlama:

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -185,8 +181,7 @@ class TestLlama:

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
@@ -1,10 +1,7 @@
-"""
-E2E tests for llama pretrain
-"""
+"""E2E tests for llama pretrain"""

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -14,9 +11,7 @@ from .utils import check_model_output_exists, check_tensorboard


 class TestPretrainLlama:
-    """
-    Test case for Llama models w pretraining
-    """
+    """Test case for Llama models w pretraining"""

     @pytest.mark.parametrize(
         "sample_packing",

@@ -66,8 +61,7 @@ class TestPretrainLlama:

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -4,7 +4,6 @@ E2E tests for lora llama

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -60,8 +59,7 @@ class TestLlamaVision(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -106,8 +104,7 @@ class TestLlamaVision(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -4,7 +4,6 @@ E2E tests for lora llama

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -55,8 +54,7 @@ class TestLoraLlama(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
@@ -6,7 +6,6 @@ import unittest

 import pytest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -57,8 +56,7 @@ class TestMamba(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -6,7 +6,6 @@ import unittest

 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -61,8 +60,7 @@ class TestMistral(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -106,8 +104,7 @@ class TestMistral(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -7,7 +7,6 @@ import unittest
 import torch
 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -67,8 +66,7 @@ class TestMixtral(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (

@@ -123,8 +121,7 @@ class TestMixtral(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (

@@ -182,8 +179,7 @@ class TestMixtral(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (

@@ -241,8 +237,7 @@ class TestMixtral(unittest.TestCase):
             cfg.bf16 = True
         else:
             cfg.fp16 = True
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (

@@ -287,8 +282,7 @@ class TestMixtral(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
@@ -4,7 +4,6 @@ E2E tests for custom optimizers using Llama

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -61,8 +60,7 @@ class TestCustomOptimizers(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -107,8 +105,7 @@ class TestCustomOptimizers(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -154,8 +151,7 @@ class TestCustomOptimizers(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -194,8 +190,7 @@ class TestCustomOptimizers(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -242,8 +237,7 @@ class TestCustomOptimizers(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -6,7 +6,6 @@ import unittest

 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -58,8 +57,7 @@ class TestPackedLlama(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
@@ -4,7 +4,6 @@ E2E tests for lora llama

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -58,8 +57,7 @@ class TestPhi(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -108,8 +106,7 @@ class TestPhi(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

@@ -4,7 +4,6 @@ E2E tests for process reward model w/ lora llama

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -54,8 +53,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_tensorboard(

@@ -5,7 +5,6 @@ E2E tests for QAT
 import unittest
 from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -64,8 +63,7 @@ class TestQATLlama(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)

@@ -4,7 +4,6 @@ E2E tests for reward model lora llama

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -63,8 +62,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
         )
         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_tensorboard(

@@ -4,7 +4,6 @@ E2E tests for custom schedulers using Llama

 import unittest

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config

@@ -57,8 +56,7 @@ class TestCustomSchedulers(unittest.TestCase):

         cfg = validate_config(cfg)
         normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
@@ -6,8 +6,9 @@ import unittest

 import pytest

+from axolotl.loaders.tokenizer import load_tokenizer
 from axolotl.prompt_strategies.dpo import load as load_dpo
-from axolotl.utils.data.rl import load_prepare_preference_datasets
+from axolotl.utils.data.rl import prepare_preference_datasets
 from axolotl.utils.dict import DictDefault

 from tests.hf_offline_utils import enable_hf_offline

@@ -55,7 +56,8 @@ class TestDPOChatml:
         # test that dpo.load works
         load_dpo("chatml", cfg)
         # now actually load the datasets with the strategy
-        train_ds, _ = load_prepare_preference_datasets(cfg)
+        tokenizer = load_tokenizer(cfg)
+        train_ds, _ = prepare_preference_datasets(cfg, tokenizer)
         assert train_ds[0]["prompt"].startswith("<|im_start|>")
         assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
         assert "chosen" in train_ds[0]

@@ -1,10 +1,9 @@
|
||||
"""
|
||||
Test dataset loading under various conditions.
|
||||
"""
|
||||
"""Test dataset loading under various conditions."""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -12,8 +11,9 @@ from datasets import Dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.constants import (
|
||||
@@ -28,7 +28,9 @@ class TestDatasetPreparation:
|
||||
"""Test a configured dataloader."""
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
|
||||
def tokenizer(
|
||||
self, tokenizer_huggyllama
|
||||
) -> Generator[PreTrainedTokenizer, Any, Any]:
|
||||
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
|
||||
yield tokenizer_huggyllama
|
||||
|
||||
@@ -63,7 +65,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -107,7 +112,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -136,7 +144,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -145,7 +156,7 @@ class TestDatasetPreparation:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
|
||||
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
||||
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
||||
tmp_ds_dir.mkdir()
|
||||
@@ -171,7 +182,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -206,7 +220,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -235,7 +252,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -264,7 +284,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -286,7 +309,8 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" not in train_dataset.features
@@ -318,7 +342,10 @@ class TestDatasetPreparation:
}
)

dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -342,13 +369,16 @@ class TestDatasetPreparation:
)

# pylint: disable=duplicate-code
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
with patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
)

train_dataset, _ = load_prepare_preference_datasets(cfg)
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)

assert len(train_dataset) == 1800
assert "conversation" not in train_dataset.features
@@ -393,16 +423,18 @@ class TestDatasetPreparation:
)

with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
"axolotl.utils.data.shared.load_dataset_with_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
)

dataset, _ = load_tokenized_prepared_datasets(
tokenizer, cfg, prepared_path
)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH",
str(prepared_path),
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -437,7 +469,10 @@ class TestDatasetPreparation:
}
)

dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

assert len(dataset) == 2000
assert "input_ids" in dataset.features

@@ -5,7 +5,6 @@ Additionally, this test suite includes tests for functions that indirectly call
`deduplicate_and_log_datasets` during the execution of the preprocess command.
"""

import hashlib
import unittest
from unittest.mock import patch

@@ -14,8 +13,7 @@ from datasets import Dataset

from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
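The import hunk above captures the refactor's renames: `prepare_dataset` becomes `prepare_datasets`, and `load_prepare_preference_datasets` leaves the `rl` module as `prepare_preference_datasets`. A hedged sketch of the renamed SFT entry point as the later hunks call it (the processor wiring and the two discarded tuple members are assumptions, not spelled out in this diff):

from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_datasets

def build_sft_splits(cfg):
    tokenizer = load_tokenizer(cfg)
    # Assumed wiring: the dedup tests below pass a processor keyword
    # through; the load_processor signature here is an assumption.
    processor = load_processor(cfg, tokenizer)
    # Per the hunks below, prepare_datasets returns a 4-tuple whose last
    # two members the tests discard.
    train_dataset, eval_dataset, _, _ = prepare_datasets(
        cfg, tokenizer, processor=processor
    )
    return train_dataset, eval_dataset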

@@ -71,36 +69,14 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
self.expected_dataset = Dataset.from_dict(self.expected_data)

def test_deduplication(self):
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=self.dataset, dataset_name="eval"
)

verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
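This hunk collapses the `train_dataset`/`eval_dataset` keywords of `deduplicate_and_log_datasets` into a single `dataset` argument with an optional `dataset_name`, and the function now returns two values instead of three. A minimal sketch of the single-dataset form (the second return value is discarded, as in the tests):

from datasets import Dataset

from axolotl.utils.data.utils import deduplicate_and_log_datasets

ds = Dataset.from_dict({"text": ["a", "a", "b"]})

# Single-dataset form: the deduplicated dataset plus a second value the
# tests ignore; dataset_name only labels the log output.
deduped, _ = deduplicate_and_log_datasets(dataset=ds, dataset_name="train")
assert len(deduped) == 2  # exact duplicates of "a" collapse to one row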

def test_datasets_are_none(self):
# Test when both datasets are None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=None
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")

def test_only_train_is_none(self):
# Test when only train_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=self.dataset
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")

def test_only_eval_is_none(self):
# Test when only eval_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=self.dataset, eval_dataset=None
)
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")

def test_exact_duplicates(self):
# Test when datasets are exact duplicates
duplicate_data = {
@@ -115,8 +91,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset = Dataset.from_dict(expected_data)

# Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=dataset, dataset_name="eval"
)

verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -139,8 +117,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset = Dataset.from_dict(expected_data)

# Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=dataset, dataset_name="eval"
)

verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -169,8 +149,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset_eval = Dataset.from_dict(expected_data_eval)

# Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=dataset, eval_dataset=dataset
train_dataset, eval_dataset = deduplicate_and_log_datasets(
dataset=dataset, other_dataset=dataset
)

verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
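For the paired case, the old `train_dataset=`/`eval_dataset=` keywords become `dataset=` and `other_dataset=`, with results returned in the same order. A minimal sketch of the updated call shape (inputs illustrative; cross-dataset semantics beyond what these tests assert are not shown in the diff):

from datasets import Dataset

from axolotl.utils.data.utils import deduplicate_and_log_datasets

train_raw = Dataset.from_dict({"text": ["a", "a", "b"]})
eval_raw = Dataset.from_dict({"text": ["b", "c", "c"]})

# Paired form: both splits are deduplicated and returned in order, as
# the train/eval tests in this hunk expect.
train_ds, eval_ds = deduplicate_and_log_datasets(
    dataset=train_raw, other_dataset=eval_raw
)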
@@ -206,8 +186,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset_eval = Dataset.from_dict(expected_data_eval)

# Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=dataset_train, eval_dataset=dataset_eval
train_dataset, eval_dataset = deduplicate_and_log_datasets(
dataset=dataset_train, other_dataset=dataset_eval
)

verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
@@ -245,7 +225,9 @@ class TestDeduplicateRLDataset:

# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
@@ -255,7 +237,8 @@ class TestDeduplicateRLDataset:
]
mock_load_tokenizer.return_value = tokenizer_huggyllama

train_dataset, _ = load_prepare_preference_datasets(cfg)
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)

# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -269,7 +252,9 @@ class TestDeduplicateRLDataset:
):
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
@@ -279,9 +264,10 @@ class TestDeduplicateRLDataset:
]
mock_load_tokenizer.return_value = tokenizer_huggyllama

cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(cfg)
cfg.dataset_exact_deduplication = False
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)

# Verify that the dataset retains duplicates
assert (
@@ -335,7 +321,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)

# Prepare dataset using the prepare_dataset function
train_dataset, _, _, _ = prepare_dataset(
train_dataset, _, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -362,7 +348,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)

# Prepare dataset using the prepare_dataset function
_, eval_dataset, _, _ = prepare_dataset(
_, eval_dataset, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -389,7 +375,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)

# Prepare dataset using the prepare_dataset function
train_dataset, eval_dataset, _, _ = prepare_dataset(
train_dataset, eval_dataset, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -428,41 +414,8 @@ class TestWrongCollisions(unittest.TestCase):
self.eval_dataset = Dataset.from_dict(self.eval_data)
self.dataset = Dataset.from_dict(self.dataset_data)

@patch(
"axolotl.utils.data.utils.sha256",
side_effect=lambda x: (
hashlib.sha256("forced_collision_hash".encode("utf-8")).hexdigest()
if "sample 5" in x
else hashlib.sha256(x.encode("utf-8")).hexdigest()
),
)
def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
)
self.assertEqual(
len(dedup_train),
2,
"train dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
2,
"Eval dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
len(self.eval_dataset),
"The output eval dataset should have the same number of rows as the input eval dataset.",
)
self.assertEqual(
str(dedup_eval),
str(self.eval_dataset),
"The string representation of the output eval dataset should be identical to the input eval dataset.",
)

def test_deduplication_dataset_only(self):
_, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset)
dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
self.assertEqual(
len(dedup_dataset), 3, "Dataset should have all original values"
)