feat: add arg to enable dft in liger (#3125)

* feat: add arg to enable dft in liger

* feat: add tests use_token_scaling

* fix: test

* fix: move check to args
This commit is contained in:
NanoCode012
2025-11-10 21:37:47 +07:00
committed by GitHub
parent d0c846fc5e
commit 11eb36585a
5 changed files with 75 additions and 2 deletions

View File

@@ -2,6 +2,7 @@
Simple end-to-end test for Liger integration
"""
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -62,7 +63,11 @@ class LigerIntegrationTestCase:
check_model_output_exists(temp_dir, cfg)
@require_torch_2_4_1
def test_llama_w_flce(self, temp_dir):
@pytest.mark.parametrize(
"liger_use_token_scaling",
[True, False],
)
def test_llama_w_flce(self, temp_dir, liger_use_token_scaling):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -74,6 +79,7 @@ class LigerIntegrationTestCase:
"liger_glu_activation": True,
"liger_cross_entropy": False,
"liger_fused_linear_cross_entropy": True,
"liger_use_token_scaling": liger_use_token_scaling,
"sequence_len": 1024,
"val_set_size": 0.05,
"special_tokens": {