using field validator instead of model validator

This commit is contained in:
Dan Saunders
2025-03-17 00:28:45 +00:00
parent 1cced52719
commit d187f1f8e2

View File

@@ -1104,16 +1104,19 @@ class AxolotlInputConfig(
return data return data
@model_validator(mode="before") @field_validator("sequence_parallel_degree", mode="before")
@classmethod @classmethod
def check_sequence_parallel_config(cls, data): def check_sequence_parallel_config(cls, value, info):
if data.get("sequence_parallel_degree", 1) > 1: if not value:
if not data.get("flash_attention"): value = 1
if value > 1:
if not info.data.get("flash_attention"):
raise ValueError( raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1" "flash_attention: true must be set with sequence_parallel_degree > 1"
) )
return data return value
class AxolotlConfigWCapabilities(AxolotlInputConfig): class AxolotlConfigWCapabilities(AxolotlInputConfig):