Distributed/ND-Parallel (#2977)

This commit is contained in:
salman
2025-07-31 20:25:02 +01:00
committed by GitHub
parent 7b68dfafd7
commit 294c7fe7a6
49 changed files with 712 additions and 835 deletions

View File

@@ -16,8 +16,6 @@
Module for handling LIGER input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
@@ -30,13 +28,13 @@ class LigerArgs(BaseModel):
Input args for LIGER.
"""
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None
liger_rope: bool | None = None
liger_rms_norm: bool | None = None
liger_layer_norm: bool | None = None
liger_swiglu: bool | None = None
liger_glu_activation: bool | None = None
liger_cross_entropy: bool | None = None
liger_fused_linear_cross_entropy: bool | None = None
@model_validator(mode="before")
@classmethod
@@ -66,3 +64,20 @@ class LigerArgs(BaseModel):
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
)
return data
@model_validator(mode="before")
@classmethod
def check_liger_rms_norm_tensor_parallel(cls, data):
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1:
raise ValueError(
"`liger_rms_norm` is incompatible with tensor parallelism, "
"see https://github.com/linkedin/Liger-Kernel/issues/826"
)
return data
@model_validator(mode="after")
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
# TODO @SalmanMohammadi this is a larger fix - investigate
if self.tensor_parallel_size > 1 and self.liger_fused_linear_cross_entropy:
raise ValueError("Tensor parallelism is not compatible with liger losses.")
return self