* fix attention mask with packing
* set position ids and use block diagonal attn mask
* fix expand mask for multiple batch items, make sure we pad position_ids
* don't move masks to cpu
* use multi pack dataloader w random sampler
* add position_ids back
* more fixes for dataloader integration
* est total tokens, fix field loop
* more fixes, position_ids seems broken
* more fixes for sample packing
* use distributed sampler, avoid accelerate prepare
* use accelerator prepare for dataloader
* fix for position_ids w packing
* Update src/axolotl/utils/dataloader.py
* validation for sample packing and doc
* more fixes for 4k and optimizations
* optimized expand mask fn
* better handling of variance in multipack dataloader length and trainer hanging when it runs out of data
* fix rounding of len of batches to int
* better handling so that all devices have the same dataloader len
* fix step calc for packing
* pass sample packing efficiency to training args
* add a test for the mask expansion for sequence packing
* only process eval dataset for packing if not None
* don't split batches when packing
* weighted CE losses
* weighted CEL fixes
* limit packing to sequences of max seq len
* seq_len_multiple for packing
* make sure the chunk size is an int
* sample_packing_seq_len_multiplier config
* use cumulative seq len with var len flash attn v2 w packing
* properly calculate max len
* fix flash-attn, xformers, packing, support chatml
* fix chatml system prompt for openorca, legacy tokenizer opts
* add chatml
* add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids (see the sketch after this list), fix prompt test
* fix test and pylint checks
* more packing and dataset optimizations and fixes
* filter w multiple cpus
* more fixes and optimizations
* fixes and go back to distributed sampler since batch sampler won't work
* fix counts by accounting for num devices
* fix steps calculation
* previous accelerate is still most performant
* add numba to requirements.
* use custom distributed checks
* fix sampler to prevent overfit w new epochs
* let's not cleanup the cached datasets
* calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier
* speed optimizations and set accelerate fsdp env vars
* optimize dataset concatenation?
* more optimizations for dataset handling
* fix import for annotation
* manual pre-commit fixes
* another sum optimization and bug fix for calc steps
* fix packing estimations
* fix formatting
* pylint problems
* add back flash attention branch for handling unpacked sequences separately
* Address PR feedback
* add optional sample packing config params to readme
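The change list above mentions building cumulative sequence lengths (cu_seqlens) for variable-length flash attention directly from packed position_ids. As a minimal sketch of the idea only, not the repo's actual helper: the function name, the single packed-row input shape, and the torch usage below are illustrative assumptions.

```python
import torch


def get_cu_seqlens_from_pos_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Derive cu_seqlens for flash-attn varlen kernels from packed position_ids.

    Assumes one packed row where each sub-sequence restarts its positions at 0,
    e.g. [0, 1, 2, 3, 0, 1, 2, 0, 1] -> cu_seqlens [0, 4, 7, 9].
    """
    # every reset to position 0 marks the start of a new packed sub-sequence
    starts = torch.nonzero(position_ids == 0, as_tuple=True)[0]
    total_len = torch.tensor([position_ids.shape[0]], device=position_ids.device)
    # boundaries are the sub-sequence start offsets plus the final end offset
    return torch.cat([starts, total_len]).to(torch.int32)
```

For example, a packed row with position_ids [0, 1, 2, 3, 0, 1, 2, 0, 1] yields cu_seqlens [0, 4, 7, 9], i.e. three packed sub-sequences of lengths 4, 3, and 2.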
"""Module for testing the validation module"""
|
|
|
|
import logging
|
|
import unittest
|
|
from typing import Optional
|
|
|
|
import pytest
|
|
|
|
from axolotl.utils.dict import DictDefault
|
|
from axolotl.utils.validation import validate_config
|
|
|
|
|
|
class ValidationTest(unittest.TestCase):
|
|
"""
|
|
Test the validation module
|
|
"""
|
|
|
|
_caplog: Optional[pytest.LogCaptureFixture] = None
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def inject_fixtures(self, caplog):
|
|
self._caplog = caplog
|
|
|
|
    def test_load_4bit_deprecate(self):
        cfg = DictDefault(
            {
                "load_4bit": True,
            }
        )

        with pytest.raises(ValueError):
            validate_config(cfg)

    def test_batch_size_unused_warning(self):
        cfg = DictDefault(
            {
                "batch_size": 32,
            }
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert "batch_size is not recommended" in self._caplog.records[0].message

    def test_qlora(self):
        base_cfg = DictDefault(
            {
                "adapter": "qlora",
            }
        )

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_8bit": True,
            }
        )

        with pytest.raises(ValueError, match=r".*8bit.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "gptq": True,
            }
        )

        with pytest.raises(ValueError, match=r".*gptq.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_4bit": False,
            }
        )

        with pytest.raises(ValueError, match=r".*4bit.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_4bit": True,
            }
        )

        validate_config(cfg)

    def test_qlora_merge(self):
        base_cfg = DictDefault(
            {
                "adapter": "qlora",
                "merge_lora": True,
            }
        )

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_8bit": True,
            }
        )

        with pytest.raises(ValueError, match=r".*8bit.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "gptq": True,
            }
        )

        with pytest.raises(ValueError, match=r".*gptq.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_4bit": True,
            }
        )

        with pytest.raises(ValueError, match=r".*4bit.*"):
            validate_config(cfg)

    def test_hf_use_auth_token(self):
        cfg = DictDefault(
            {
                "push_dataset_to_hub": "namespace/repo",
            }
        )

        with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"):
            validate_config(cfg)

        cfg = DictDefault(
            {
                "push_dataset_to_hub": "namespace/repo",
                "hf_use_auth_token": True,
            }
        )
        validate_config(cfg)

    def test_gradient_accumulations_or_batch_size(self):
        cfg = DictDefault(
            {
                "gradient_accumulation_steps": 1,
                "batch_size": 1,
            }
        )

        with pytest.raises(
            ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
        ):
            validate_config(cfg)

        cfg = DictDefault(
            {
                "batch_size": 1,
            }
        )

        validate_config(cfg)

        cfg = DictDefault(
            {
                "gradient_accumulation_steps": 1,
            }
        )

        validate_config(cfg)

    def test_falcon_fsdp(self):
        regex_exp = r".*FSDP is not supported for falcon models.*"

        # Check for lower-case
        cfg = DictDefault(
            {
                "base_model": "tiiuae/falcon-7b",
                "fsdp": ["full_shard", "auto_wrap"],
            }
        )

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        # Check for upper-case
        cfg = DictDefault(
            {
                "base_model": "Falcon-7b",
                "fsdp": ["full_shard", "auto_wrap"],
            }
        )

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        cfg = DictDefault(
            {
                "base_model": "tiiuae/falcon-7b",
            }
        )

        validate_config(cfg)

    def test_mpt_gradient_checkpointing(self):
        regex_exp = r".*gradient_checkpointing is not supported for MPT models.*"

        # Check for lower-case
        cfg = DictDefault(
            {
                "base_model": "mosaicml/mpt-7b",
                "gradient_checkpointing": True,
            }
        )

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

    def test_flash_optimum(self):
        cfg = DictDefault(
            {
                "flash_optimum": True,
                "adapter": "lora",
            }
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                "BetterTransformers probably doesn't work with PEFT adapters"
                in record.message
                for record in self._caplog.records
            )

        cfg = DictDefault(
            {
                "flash_optimum": True,
            }
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                "probably set bfloat16 or float16" in record.message
                for record in self._caplog.records
            )

        cfg = DictDefault(
            {
                "flash_optimum": True,
                "fp16": True,
            }
        )
        regex_exp = r".*AMP is not supported.*"

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        cfg = DictDefault(
            {
                "flash_optimum": True,
                "bf16": True,
            }
        )
        regex_exp = r".*AMP is not supported.*"

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

    def test_adamw_hyperparams(self):
        cfg = DictDefault(
            {
                "optimizer": None,
                "adam_epsilon": 0.0001,
            }
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                "adamw hyperparameters found, but no adamw optimizer set"
                in record.message
                for record in self._caplog.records
            )

        cfg = DictDefault(
            {
                "optimizer": "adafactor",
                "adam_beta1": 0.0001,
            }
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                "adamw hyperparameters found, but no adamw optimizer set"
                in record.message
                for record in self._caplog.records
            )

        cfg = DictDefault(
            {
                "optimizer": "adamw_bnb_8bit",
                "adam_beta1": 0.9,
                "adam_beta2": 0.99,
                "adam_epsilon": 0.0001,
            }
        )

        validate_config(cfg)

        cfg = DictDefault(
            {
                "optimizer": "adafactor",
            }
        )

        validate_config(cfg)

    def test_packing(self):
        cfg = DictDefault(
            {
                "max_packed_sequence_len": 2048,
            }
        )
        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                "max_packed_sequence_len will be deprecated in favor of sample_packing"
                in record.message
                for record in self._caplog.records
            )

        cfg = DictDefault(
            {
                "max_packed_sequence_len": 2048,
                "sample_packing": True,
            }
        )
        regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)