feat: add arg to enable dft in liger (#3125)
* feat: add arg to enable dft in liger * feat: add tests use_token_scaling * fix: test * fix: move check to args
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
Simple end-to-end test for Liger integration
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
@@ -62,7 +63,11 @@ class LigerIntegrationTestCase:
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@require_torch_2_4_1
|
||||
def test_llama_w_flce(self, temp_dir):
|
||||
@pytest.mark.parametrize(
|
||||
"liger_use_token_scaling",
|
||||
[True, False],
|
||||
)
|
||||
def test_llama_w_flce(self, temp_dir, liger_use_token_scaling):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -74,6 +79,7 @@ class LigerIntegrationTestCase:
|
||||
"liger_glu_activation": True,
|
||||
"liger_cross_entropy": False,
|
||||
"liger_fused_linear_cross_entropy": True,
|
||||
"liger_use_token_scaling": liger_use_token_scaling,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
|
||||
@@ -75,3 +75,19 @@ class TestValidation:
|
||||
):
|
||||
prepare_plugins(test_cfg)
|
||||
validate_config(test_cfg)
|
||||
|
||||
def test_use_token_scaling_require_flce(self, minimal_liger_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"liger_fused_linear_cross_entropy": False,
|
||||
"liger_use_token_scaling": True,
|
||||
}
|
||||
| minimal_liger_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled.",
|
||||
):
|
||||
prepare_plugins(test_cfg)
|
||||
validate_config(test_cfg)
|
||||
|
||||
Reference in New Issue
Block a user