TiledMLP support for FSDP2 (#2950)
* make TiledMLP work with FSDP
* cleanup/gc at start of train to prevent large VRAM spike
* chore: lint
* generic function for non-deepspeed training
* unify patch to fix imports
* update readme for ALST and add examples
* make deepspeed attribute on params check more robust
* update with new info from PR review
This commit is contained in:
@@ -57,8 +57,12 @@ class LigerArgs(BaseModel):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_tiled_mlp_conflict(cls, data):
|
||||
if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True:
|
||||
if (
|
||||
data.get("liger_glu_activation") is True
|
||||
and data.get("tiled_mlp") is True
|
||||
and not data.get("tiled_mlp_use_original_mlp")
|
||||
):
|
||||
raise ValueError(
|
||||
"You cannot have both `liger_glu_activation` and `tiled_mlp` set."
|
||||
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
|
||||
)
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user