feat: add arg to enable dft in liger (#3125)
* feat: add arg to enable dft in liger * feat: add tests use_token_scaling * fix: test * fix: move check to args
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
Module for handling LIGER input arguments.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -35,6 +35,15 @@ class LigerArgs(BaseModel):
|
||||
liger_glu_activation: bool | None = None
|
||||
liger_cross_entropy: bool | None = None
|
||||
liger_fused_linear_cross_entropy: bool | None = None
|
||||
liger_use_token_scaling: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Enables use_token_scaling in fused_linear_cross_entropy. "
|
||||
"When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -75,6 +84,18 @@ class LigerArgs(BaseModel):
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_liger_use_token_scaling_flce(cls, data):
|
||||
if data.get("liger_use_token_scaling") and not data.get(
|
||||
"liger_fused_linear_cross_entropy"
|
||||
):
|
||||
raise ValueError(
|
||||
"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
|
||||
# TODO @SalmanMohammadi this is a larger fix - investigate
|
||||
|
||||
Reference in New Issue
Block a user