* Add s2_attn to hijack flash code
* Refactor code to account for s2_attn
* Add test for models utils
* Add ``s2_attention`` option to llama configs
* Add ``s2_attention`` option to README config
* Format code to appease linter
* chore: lint
* Remove xpos and llama-landmark [bad merge]
* Add e2e smoke tests for shifted sparse attention
* Remove stray patch from merge
* Update yml with link to paper for s2_attention/longlora
* Fix assertion check for full fine tune
* Increase sequence len for tests and PR feedback updates
* Reduce context len to 16k for tests
* Reduce context len to 16k for tests
* Reduce batch size for larger context len and update test to check message
* Fix test for message

---------

Co-authored-by: joecummings <jrcummings@devvm050.nha0.facebook.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
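For context, the new option is mutually exclusive with sample packing, which is what the test below verifies. A minimal sketch of a valid config built the same way the test builds its invalid one (keys mirror the test's DictDefault; the base_model value is an illustrative assumption, not taken from this PR):

from axolotl.utils.dict import DictDefault

# Hypothetical minimal config enabling shifted-sparse (LongLoRA) attention.
cfg = DictDefault(
    {
        "s2_attention": True,      # enable shifted-sparse attention
        "sample_packing": False,   # must stay off: load_model raises ValueError otherwise
        "base_model": "NousResearch/Llama-2-7b-hf",  # assumed example checkpoint
        "model_type": "LlamaForCausalLM",
    }
)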
"""Module for testing models utils file."""
|
|
|
|
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from axolotl.utils.dict import DictDefault
|
|
from axolotl.utils.models import load_model
|
|
|
|
|
|
class ModelsUtilsTest(unittest.TestCase):
|
|
"""Testing module for models utils."""
|
|
|
|
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
|
|
cfg = DictDefault(
|
|
{
|
|
"s2_attention": True,
|
|
"sample_packing": True,
|
|
"base_model": "",
|
|
"model_type": "LlamaForCausalLM",
|
|
}
|
|
)
|
|
|
|
# Mock out call to HF hub
|
|
with patch(
|
|
"axolotl.utils.models.load_model_config"
|
|
) as mocked_load_model_config:
|
|
mocked_load_model_config.return_value = {}
|
|
with pytest.raises(ValueError) as exc:
|
|
# Should error before hitting tokenizer, so we pass in an empty str
|
|
load_model(cfg, tokenizer="")
|
|
assert (
|
|
"shifted-sparse attention does not currently support sample packing"
|
|
in str(exc.value)
|
|
)
|
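To run just this check locally, something like the following should work; the file path is an assumption (adjust to wherever this module lives in the repo), and -k matches the "s2_attention" substring in the test name:

pytest tests/test_models.py -k s2_attention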