Compare commits

...

4 Commits

Author SHA1 Message Date
Wing Lian
385736fae1 fix linter issue from merge 2025-01-13 12:55:03 -05:00
Wing Lian
f89e962119 skip over rows in pretraining dataset (#2223)
* skip over rows in pretraining dataset

* update docs
2025-01-13 10:44:45 -05:00
Wing Lian
bc1c9c20e3 assume empty lora dropout means 0.0 and add tests (#2243)
* assume empty lora dropout means 0.0 and add tests

* remove un-necessary arg

* refactor based on pr feedback:

* chore: lint
2025-01-13 10:44:11 -05:00
Wing Lian
dd26cc3c0f add helper to verify the correct model output file exists (#2245)
* add helper to verify the correct model output file exists

* more checks using helper

* chore: lint

* fix import and relora model check

* workaround for trl trainer saves

* remove stray print
2025-01-13 10:43:29 -05:00
33 changed files with 210 additions and 115 deletions

View File

@@ -19,7 +19,14 @@ For pretraining, there is no prompt template or roles. The only required field
Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
```{.yaml filename="config.yaml"}
pretraining_dataset: # hf path only
pretraining_dataset:
- name:
path:
split:
text_column: # column in dataset with the data, usually `text`
type: pretrain
trust_remote_code:
skip: # number of rows of data to skip over from the beginning
...
```

View File

@@ -27,7 +27,6 @@ def add_options_from_dataclass(config_class: Type[Any]):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
if field_type == bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"

View File

@@ -129,6 +129,7 @@ class PretrainingDataset(BaseModel):
type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
data_files: Optional[str] = None
skip: Optional[int] = None
class UserDefinedPrompterType(BaseModel):
@@ -367,6 +368,13 @@ class LoraConfig(BaseModel):
loraplus_lr_embedding = float(loraplus_lr_embedding)
return loraplus_lr_embedding
@model_validator(mode="before")
@classmethod
def validate_lora_dropout(cls, data):
if data.get("adapter") is not None and data.get("lora_dropout") is None:
data["lora_dropout"] = 0.0
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""

View File

@@ -89,11 +89,13 @@ def prepare_dataset(cfg, tokenizer, processor=None):
split = "train"
name = None
data_files = None
skip = 0
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
skip = cfg.pretraining_dataset[0]["skip"]
if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"]
@@ -107,10 +109,14 @@ def prepare_dataset(cfg, tokenizer, processor=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)
train_dataset = wrap_pretraining_dataset(
load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
),
iter_ds,
tokenizer,
cfg,
ds_wrapper_partial,

View File

@@ -2,8 +2,6 @@
Simple end-to-end test for Cut Cross Entropy integration
"""
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
@@ -13,6 +11,8 @@ from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
# pylint: disable=duplicate-code
@@ -67,7 +67,7 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize(
"attention_type",
@@ -95,4 +95,4 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -1,7 +1,6 @@
"""
Simple end-to-end test for Liger integration
"""
from pathlib import Path
from e2e.utils import require_torch_2_4_1
@@ -11,6 +10,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
class LigerIntegrationTestCase:
"""
@@ -60,7 +61,7 @@ class LigerIntegrationTestCase:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
@require_torch_2_4_1
def test_llama_w_flce(self, temp_dir):
@@ -105,4 +106,4 @@ class LigerIntegrationTestCase:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for multipack fft llama using 4d attention masks
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import require_torch_2_3_1, with_temp_dir
from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -67,7 +66,7 @@ class Test4dMultipackLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_torch_lora_packing(self, temp_dir):
@@ -111,4 +110,4 @@ class Test4dMultipackLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import logging
import os
from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available
@@ -15,7 +14,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
from ..utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -82,7 +81,7 @@ class TestFAXentropyLlama:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"

View File

@@ -5,7 +5,6 @@ E2E tests for falcon
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestFalconPatched(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_ft(self, temp_dir):
@@ -109,4 +108,4 @@ class TestFalconPatched(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -73,4 +72,4 @@ class TestFusedLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for llama w/ S2 attn
import logging
import os
import unittest
from pathlib import Path
import pytest
@@ -15,7 +14,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -71,7 +70,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_fft_s2_attn(self, temp_dir):
@@ -111,4 +110,4 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -76,7 +75,7 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
@with_temp_dir
@@ -126,4 +125,4 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_ft_packing(self, temp_dir):
@@ -110,4 +109,4 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for mixtral
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -66,7 +65,7 @@ class TestMixtral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_ft(self, temp_dir):
@@ -108,4 +107,4 @@ class TestMixtral(unittest.TestCase):
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestPhiMultipack(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_qlora_packed(self, temp_dir):
@@ -120,4 +119,4 @@ class TestPhiMultipack(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import logging
import os
import re
import subprocess
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import most_recent_subdir
from ..utils import check_model_output_exists, most_recent_subdir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -83,7 +82,7 @@ class TestResumeLlama:
cli_args = TrainerCliArgs()
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"

View File

@@ -3,7 +3,6 @@ e2e tests for unsloth qlora
"""
import logging
import os
from pathlib import Path
import pytest
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
from ..utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -77,7 +76,7 @@ class TestUnslothQLoRA:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -127,7 +126,7 @@ class TestUnslothQLoRA:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -182,7 +181,7 @@ class TestUnslothQLoRA:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"

View File

@@ -15,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,7 +68,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir
def test_dpo_nll_lora(self, temp_dir):
@@ -113,7 +113,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir
def test_dpo_use_weighting(self, temp_dir):
@@ -158,7 +158,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@pytest.mark.skip("kto_pair no longer supported in trl")
@with_temp_dir
@@ -203,7 +203,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir
def test_ipo_lora(self, temp_dir):
@@ -247,7 +247,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir
def test_orpo_lora(self, temp_dir):
@@ -294,7 +294,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@pytest.mark.skip(reason="Fix the implementation")
@with_temp_dir
@@ -358,4 +358,4 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for llama pretrain
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -62,7 +61,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
@@ -106,7 +105,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"

View File

@@ -5,7 +5,6 @@ E2E tests for falcon
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -71,7 +70,7 @@ class TestFalcon(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_lora_added_vocab(self, temp_dir):
@@ -124,7 +123,7 @@ class TestFalcon(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_ft(self, temp_dir):
@@ -163,4 +162,4 @@ class TestFalcon(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,8 @@ E2E tests for llama
import logging
import os
from pathlib import Path
from e2e.utils import check_model_output_exists
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -60,7 +61,7 @@ class TestLlama:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
@@ -103,7 +104,7 @@ class TestLlama:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
def test_batch_flattening(self, temp_dir):
# pylint: disable=duplicate-code
@@ -142,4 +143,4 @@ class TestLlama:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for llama pretrain
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -64,4 +63,4 @@ class TestPretrainLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,7 +67,7 @@ class TestLlamaVision(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
@@ -113,4 +112,4 @@ class TestLlamaVision(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -65,4 +64,4 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
import pytest
@@ -15,7 +14,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -65,4 +64,4 @@ class TestMamba(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
@@ -15,7 +14,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_ft(self, temp_dir):
@@ -112,4 +111,4 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for mixtral
import logging
import os
import unittest
from pathlib import Path
import torch
from transformers.utils import is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -79,7 +78,7 @@ class TestMixtral(unittest.TestCase):
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_qlora_wo_fa2(self, temp_dir):
@@ -133,7 +132,7 @@ class TestMixtral(unittest.TestCase):
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_16bit_lora_w_fa2(self, temp_dir):
@@ -190,7 +189,7 @@ class TestMixtral(unittest.TestCase):
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_16bit_lora_wo_fa2(self, temp_dir):
@@ -247,7 +246,7 @@ class TestMixtral(unittest.TestCase):
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_ft(self, temp_dir):
@@ -287,4 +286,4 @@ class TestMixtral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for custom optimizers using Llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import require_torch_2_5_1, with_temp_dir
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -65,7 +64,7 @@ class TestCustomOptimizers(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@require_torch_2_5_1
@@ -109,7 +108,7 @@ class TestCustomOptimizers(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):
@@ -145,4 +144,4 @@ class TestCustomOptimizers(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -67,7 +66,7 @@ class TestPhi(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_phi_qlora(self, temp_dir):
@@ -116,4 +115,4 @@ class TestPhi(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -78,10 +78,10 @@ class TestReLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
assert (
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
).exists()
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
).exists(), "Relora model checkpoint not found"
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"

View File

@@ -5,7 +5,6 @@ E2E tests for reward model lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -71,4 +70,4 @@ class TestRewardModelLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -14,6 +14,8 @@ import torch
from packaging import version
from tbparse import SummaryReader
from axolotl.utils.dict import DictDefault
def with_temp_dir(test_func):
@wraps(test_func)
@@ -93,3 +95,27 @@ def check_tensorboard(
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
assert df.value.values[-1] < lt_val, assertion_err
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
"""
helper function to check if a model output file exists after training
checks based on adapter or not and if safetensors saves are enabled or not
"""
if cfg.save_safetensors:
if not cfg.adapter:
assert (Path(temp_dir) / "model.safetensors").exists()
else:
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
else:
# check for both, b/c in trl, it often defaults to saving safetensors
if not cfg.adapter:
assert (Path(temp_dir) / "pytorch_model.bin").exists() or (
Path(temp_dir) / "model.safetensors"
).exists()
else:
assert (Path(temp_dir) / "adapter_model.bin").exists() or (
Path(temp_dir) / "adapter_model.safetensors"
).exists()

69
tests/test_lora.py Normal file
View File

@@ -0,0 +1,69 @@
"""
tests for loading loras
"""
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# pylint: disable=duplicate-code
minimal_config = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
class TestLoRALoad:
"""
Test class for loading LoRA weights
"""
def test_load_lora_weights(self):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"sequence_len": 1024,
}
| minimal_config
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer)
def test_load_lora_weights_empty_dropout(self):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": None,
"lora_target_linear": True,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"sequence_len": 1024,
}
| minimal_config
)
cfg = validate_config(cfg)
normalize_config(cfg)
assert cfg.lora_dropout == 0.0
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer)