Files
axolotl/tests/test_validation.py
Wing Lian 2bb0b78975 Attention mask and position id fixes for packing (#285)
* fix attention mask with packing

* set position ids and use block diagonal attn mask

* fix expand mask for multiple batch items, make sure we pad position_ids

* don't move masks to cpu

* use multi pack dataloader w random sampler

* add position_ids back

* more fixes for dataloader integration

* estimate total tokens, fix field loop

* more fixes, position_ids seems broken

* more fixes for sample packing

* use distributed sampler, avoid accelerate prepare

* use accelerator prepare for dataloader

* fix for position_ids w packing

* Update src/axolotl/utils/dataloader.py

* validation for sample packing and doc

* more fixes for 4k and optimizations

* optimized expand mask fn

* better handling of variance in multipack dataloader length and trainer hanging when it runs out of data

* fix rounding of len of batches to int

* better handling so that all devices have the same dataloader len

* fix step calc for packing

* pass sample packing efficiency to training args

* add a test for the mask expansion for sequence packing

* only process eval dataset for packing if not None

* don't split batches when packing

* weighted CE losses

* weighted CEL fixes

* limit packing to sequences of max seq len

* seq_len_multiple for packing

* make sure the chunk size is an int

* sample_packing_seq_len_multiplier config

* use cumulative seq len with var len flash attn v2 w packing

* properly calculate max len

* fix flash-attn, xformers, packing, support chatml

* fix chatml system prompt for openorca, legacy tokenizer opts

* add chatml

* add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test

* fix test and pylint checks

* more packing and dataset optimizations and fixes

* filter w multiple cpus

* more fixes and optimizations

* fixes and go back to distributed sampler since batch sampler won't work

* fix counts by accounting for num devices

* fix steps calculation

* previous accelerate is still most performant

* add numba to requirements.

* use custom distributed checks

* fix sampler to prevent overfit w new epochs

* let's not cleanup the cached datasets

* calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier

* speed optimizations and set accelerate fsdp env vars

* optimize dataset concatenation?

* more optimizations for dataset handling

* fix import for annotation

* manual pre-commit fixes

* another sum optimization and bug fix for calc steps

* fix packing estimations

* fix formatting

* pylint problems

* add back flash attention branch for handling unpacked sequences separately

* Address PR feedback

* add optional sample packing config params to readme
2023-08-12 15:14:56 -04:00

340 lines
8.6 KiB
Python

"""Module for testing the validation module"""
import logging
import unittest
from typing import Optional
import pytest
from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config
class ValidationTest(unittest.TestCase):
    """
    Test the validation module
    """

    # Filled in per test by the autouse ``inject_fixtures`` fixture; pytest
    # cannot pass fixtures as arguments to unittest.TestCase test methods.
    _caplog: Optional[pytest.LogCaptureFixture] = None

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        """Expose pytest's ``caplog`` fixture to the unittest-style tests."""
        self._caplog = caplog

    def _logged(self, text: str) -> bool:
        """Return True if any captured log record's message contains ``text``."""
        return any(text in record.message for record in self._caplog.records)

    def test_load_4bit_deprecate(self):
        """The deprecated ``load_4bit`` option is rejected with a ValueError."""
        cfg = DictDefault(
            {
                "load_4bit": True,
            }
        )

        with pytest.raises(ValueError):
            validate_config(cfg)

    def test_batch_size_unused_warning(self):
        """Setting ``batch_size`` logs a warning discouraging its use."""
        cfg = DictDefault(
            {
                "batch_size": 32,
            }
        )

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
        # Search every record rather than indexing records[0]: any other
        # warning logged first would otherwise make this check fail spuriously.
        assert self._logged("batch_size is not recommended")

    def test_qlora(self):
        """QLoRA rejects 8-bit, GPTQ and disabled 4-bit loading; 4-bit passes."""
        base_cfg = DictDefault(
            {
                "adapter": "qlora",
            }
        )

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_8bit": True,
            }
        )
        with pytest.raises(ValueError, match=r".*8bit.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "gptq": True,
            }
        )
        with pytest.raises(ValueError, match=r".*gptq.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_4bit": False,
            }
        )
        with pytest.raises(ValueError, match=r".*4bit.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_4bit": True,
            }
        )
        validate_config(cfg)

    def test_qlora_merge(self):
        """Merging a QLoRA adapter rejects 8-bit, GPTQ and 4-bit loading."""
        base_cfg = DictDefault(
            {
                "adapter": "qlora",
                "merge_lora": True,
            }
        )

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_8bit": True,
            }
        )
        with pytest.raises(ValueError, match=r".*8bit.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "gptq": True,
            }
        )
        with pytest.raises(ValueError, match=r".*gptq.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_4bit": True,
            }
        )
        with pytest.raises(ValueError, match=r".*4bit.*"):
            validate_config(cfg)

    def test_hf_use_auth_token(self):
        """``push_dataset_to_hub`` is only accepted with ``hf_use_auth_token``."""
        cfg = DictDefault(
            {
                "push_dataset_to_hub": "namespace/repo",
            }
        )
        with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"):
            validate_config(cfg)

        cfg = DictDefault(
            {
                "push_dataset_to_hub": "namespace/repo",
                "hf_use_auth_token": True,
            }
        )
        validate_config(cfg)

    def test_gradient_accumulations_or_batch_size(self):
        """Either setting alone is valid; setting both is rejected."""
        cfg = DictDefault(
            {
                "gradient_accumulation_steps": 1,
                "batch_size": 1,
            }
        )
        with pytest.raises(
            ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
        ):
            validate_config(cfg)

        cfg = DictDefault(
            {
                "batch_size": 1,
            }
        )
        validate_config(cfg)

        cfg = DictDefault(
            {
                "gradient_accumulation_steps": 1,
            }
        )
        validate_config(cfg)

    def test_falcon_fsdp(self):
        """FSDP is rejected for falcon models whatever the name's casing."""
        regex_exp = r".*FSDP is not supported for falcon models.*"

        # Check for lower-case
        cfg = DictDefault(
            {
                "base_model": "tiiuae/falcon-7b",
                "fsdp": ["full_shard", "auto_wrap"],
            }
        )
        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        # Check for upper-case
        cfg = DictDefault(
            {
                "base_model": "Falcon-7b",
                "fsdp": ["full_shard", "auto_wrap"],
            }
        )
        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        # Without FSDP a falcon model is accepted.
        cfg = DictDefault(
            {
                "base_model": "tiiuae/falcon-7b",
            }
        )
        validate_config(cfg)

    def test_mpt_gradient_checkpointing(self):
        """Gradient checkpointing is rejected for MPT models."""
        # The pattern previously ended in "models*" (i.e. "model" plus zero or
        # more "s"); ".*" is what was intended.
        regex_exp = r".*gradient_checkpointing is not supported for MPT models.*"

        cfg = DictDefault(
            {
                "base_model": "mosaicml/mpt-7b",
                "gradient_checkpointing": True,
            }
        )
        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

    def test_flash_optimum(self):
        """BetterTransformer (flash_optimum) warnings and AMP rejection."""
        cfg = DictDefault(
            {
                "flash_optimum": True,
                "adapter": "lora",
            }
        )
        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
        assert self._logged(
            "BetterTransformers probably doesn't work with PEFT adapters"
        )

        # Drop records from the previous call: its config also lacks
        # bf16/fp16 and may already have logged the message checked below,
        # which would let the assertion pass without this call logging it.
        self._caplog.clear()
        cfg = DictDefault(
            {
                "flash_optimum": True,
            }
        )
        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
        assert self._logged("probably set bfloat16 or float16")

        regex_exp = r".*AMP is not supported.*"

        cfg = DictDefault(
            {
                "flash_optimum": True,
                "fp16": True,
            }
        )
        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        cfg = DictDefault(
            {
                "flash_optimum": True,
                "bf16": True,
            }
        )
        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

    def test_adamw_hyperparams(self):
        """Adam hyperparameters warn unless an adamw optimizer is selected."""
        warning = "adamw hyperparameters found, but no adamw optimizer set"

        cfg = DictDefault(
            {
                "optimizer": None,
                "adam_epsilon": 0.0001,
            }
        )
        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
        assert self._logged(warning)

        # Each check below must only see records from its own call; the
        # previous call logged the very message we look for.
        self._caplog.clear()
        cfg = DictDefault(
            {
                "optimizer": "adafactor",
                "adam_beta1": 0.0001,
            }
        )
        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
        assert self._logged(warning)

        self._caplog.clear()
        cfg = DictDefault(
            {
                "optimizer": "adamw_bnb_8bit",
                "adam_beta1": 0.9,
                "adam_beta2": 0.99,
                "adam_epsilon": 0.0001,
            }
        )
        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
        assert not self._logged(warning)

        self._caplog.clear()
        cfg = DictDefault(
            {
                "optimizer": "adafactor",
            }
        )
        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
        assert not self._logged(warning)

    def test_packing(self):
        """max_packed_sequence_len warns alone and conflicts with sample_packing."""
        cfg = DictDefault(
            {
                "max_packed_sequence_len": 2048,
            }
        )
        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
        assert self._logged(
            "max_packed_sequence_len will be deprecated in favor of sample_packing"
        )

        cfg = DictDefault(
            {
                "max_packed_sequence_len": 2048,
                "sample_packing": True,
            }
        )
        regex_exp = (
            r".*set only one of max_packed_sequence_len \(deprecated soon\)"
            r" or sample_packing.*"
        )
        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)