TiledMLP support for FSDP2 (#2950)

* make TiledMLP work with FSDP

* cleanup/gc at start of training to prevent large VRAM spike

* chore: lint

* generic function for non-deepspeed training

* unify patch to fix imports

* update readme for ALST and add examples

* make deepspeed attribute on params check more robust

* update with new info from PR review
This commit is contained in:
Wing Lian
2025-07-25 07:15:03 -04:00
committed by GitHub
parent 460e0f9ed9
commit f7ea140838
13 changed files with 330 additions and 26 deletions

View File

@@ -57,8 +57,12 @@ class LigerArgs(BaseModel):
@model_validator(mode="before")
@classmethod
def check_tiled_mlp_conflict(cls, data):
if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True:
if (
data.get("liger_glu_activation") is True
and data.get("tiled_mlp") is True
and not data.get("tiled_mlp_use_original_mlp")
):
raise ValueError(
"You cannot have both `liger_glu_activation` and `tiled_mlp` set."
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
)
return data