feat: add arg to enable dft in liger (#3125)
* feat: add arg to enable dft in liger * feat: add tests use_token_scaling * fix: test * fix: move check to args
This commit is contained in:
@@ -75,3 +75,19 @@ class TestValidation:
|
||||
):
|
||||
prepare_plugins(test_cfg)
|
||||
validate_config(test_cfg)
|
||||
|
||||
def test_use_token_scaling_require_flce(self, minimal_liger_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"liger_fused_linear_cross_entropy": False,
|
||||
"liger_use_token_scaling": True,
|
||||
}
|
||||
| minimal_liger_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled.",
|
||||
):
|
||||
prepare_plugins(test_cfg)
|
||||
validate_config(test_cfg)
|
||||
|
||||
Reference in New Issue
Block a user