upgrade liger to 0.4.0 (#1973)
* upgrade liger to 0.3.1 * update docs and example * skip duplicate code check * Update src/axolotl/integrations/liger/args.py Co-authored-by: NanoCode012 <nano@axolotl.ai> * Update README.md Co-authored-by: NanoCode012 <nano@axolotl.ai> * add logging * chore: lint * add test case * upgrade liger and transformers * also upgrade accelerate * use kwargs to support patch release * make sure prepared path is empty for test * use transfromers 4.46.1 since 4.46.2 breaks fsdp --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
0
tests/integrations/__init__.py
Normal file
0
tests/integrations/__init__.py
Normal file
80
tests/integrations/liger.py
Normal file
80
tests/integrations/liger.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
config validation tests for swiglu args
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="minimal_base_cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
||||
"learning_rate": 0.000001,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class BaseValidation:
|
||||
"""
|
||||
Base validation module to setup the log capture
|
||||
"""
|
||||
|
||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class TestValidation(BaseValidation):
|
||||
"""
|
||||
Test the validation module for liger
|
||||
"""
|
||||
|
||||
def test_deprecated_swiglu(self, minimal_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"liger_swiglu": False,
|
||||
}
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
updated_cfg = validate_config(test_cfg)
|
||||
assert (
|
||||
"The 'liger_swiglu' argument is deprecated"
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
assert updated_cfg.liger_swiglu is None
|
||||
assert updated_cfg.liger_glu_activations is False
|
||||
|
||||
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"liger_swiglu": False,
|
||||
"liger_glu_activations": True,
|
||||
}
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
|
||||
):
|
||||
validate_config(test_cfg)
|
||||
Reference in New Issue
Block a user