feat: add arg to enable dft in liger (#3125)
* feat: add arg to enable dft in liger * feat: add tests use_token_scaling * fix: test * fix: move check to args
This commit is contained in:
@@ -18,6 +18,9 @@ liger_rms_norm: true
|
|||||||
liger_glu_activation: true
|
liger_glu_activation: true
|
||||||
liger_layer_norm: true
|
liger_layer_norm: true
|
||||||
liger_fused_linear_cross_entropy: true
|
liger_fused_linear_cross_entropy: true
|
||||||
|
|
||||||
|
# FLCE-specific
|
||||||
|
liger_use_token_scaling: true
|
||||||
```
|
```
|
||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
Module for handling LIGER input arguments.
|
Module for handling LIGER input arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -35,6 +35,15 @@ class LigerArgs(BaseModel):
|
|||||||
liger_glu_activation: bool | None = None
|
liger_glu_activation: bool | None = None
|
||||||
liger_cross_entropy: bool | None = None
|
liger_cross_entropy: bool | None = None
|
||||||
liger_fused_linear_cross_entropy: bool | None = None
|
liger_fused_linear_cross_entropy: bool | None = None
|
||||||
|
liger_use_token_scaling: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": (
|
||||||
|
"Enables use_token_scaling in fused_linear_cross_entropy. "
|
||||||
|
"When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -75,6 +84,18 @@ class LigerArgs(BaseModel):
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_liger_use_token_scaling_flce(cls, data):
|
||||||
|
if data.get("liger_use_token_scaling") and not data.get(
|
||||||
|
"liger_fused_linear_cross_entropy"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled."
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
|
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
|
||||||
# TODO @SalmanMohammadi this is a larger fix - investigate
|
# TODO @SalmanMohammadi this is a larger fix - investigate
|
||||||
|
|||||||
@@ -48,6 +48,33 @@ class LigerPlugin(BasePlugin):
|
|||||||
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
|
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.liger_use_token_scaling:
|
||||||
|
# Patch FLCE to set token_scaling=True for function and class API
|
||||||
|
from liger_kernel.transformers import functional
|
||||||
|
from liger_kernel.transformers.fused_linear_cross_entropy import (
|
||||||
|
LigerFusedLinearCrossEntropyLoss,
|
||||||
|
)
|
||||||
|
|
||||||
|
old_liger_fused_linear_cross_entropy = (
|
||||||
|
functional.liger_fused_linear_cross_entropy
|
||||||
|
)
|
||||||
|
|
||||||
|
def patched_liger_fused_linear_cross_entropy(*args, **kwargs):
|
||||||
|
kwargs["use_token_scaling"] = True
|
||||||
|
return old_liger_fused_linear_cross_entropy(*args, **kwargs)
|
||||||
|
|
||||||
|
functional.liger_fused_linear_cross_entropy = (
|
||||||
|
patched_liger_fused_linear_cross_entropy
|
||||||
|
)
|
||||||
|
|
||||||
|
old_init = LigerFusedLinearCrossEntropyLoss.__init__
|
||||||
|
|
||||||
|
def patched_init(self, *args, **kwargs):
|
||||||
|
kwargs["use_token_scaling"] = True
|
||||||
|
return old_init(self, *args, **kwargs)
|
||||||
|
|
||||||
|
LigerFusedLinearCrossEntropyLoss.__init__ = patched_init
|
||||||
|
|
||||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
Simple end-to-end test for Liger integration
|
Simple end-to-end test for Liger integration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
@@ -62,7 +63,11 @@ class LigerIntegrationTestCase:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@require_torch_2_4_1
|
@require_torch_2_4_1
|
||||||
def test_llama_w_flce(self, temp_dir):
|
@pytest.mark.parametrize(
|
||||||
|
"liger_use_token_scaling",
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
def test_llama_w_flce(self, temp_dir, liger_use_token_scaling):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
@@ -74,6 +79,7 @@ class LigerIntegrationTestCase:
|
|||||||
"liger_glu_activation": True,
|
"liger_glu_activation": True,
|
||||||
"liger_cross_entropy": False,
|
"liger_cross_entropy": False,
|
||||||
"liger_fused_linear_cross_entropy": True,
|
"liger_fused_linear_cross_entropy": True,
|
||||||
|
"liger_use_token_scaling": liger_use_token_scaling,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
|
|||||||
@@ -75,3 +75,19 @@ class TestValidation:
|
|||||||
):
|
):
|
||||||
prepare_plugins(test_cfg)
|
prepare_plugins(test_cfg)
|
||||||
validate_config(test_cfg)
|
validate_config(test_cfg)
|
||||||
|
|
||||||
|
def test_use_token_scaling_require_flce(self, minimal_liger_cfg):
|
||||||
|
test_cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"liger_fused_linear_cross_entropy": False,
|
||||||
|
"liger_use_token_scaling": True,
|
||||||
|
}
|
||||||
|
| minimal_liger_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled.",
|
||||||
|
):
|
||||||
|
prepare_plugins(test_cfg)
|
||||||
|
validate_config(test_cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user