using field validator instead of model validator
This commit is contained in:
@@ -1104,16 +1104,19 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@field_validator("sequence_parallel_degree", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sequence_parallel_config(cls, data):
|
def check_sequence_parallel_config(cls, value, info):
|
||||||
if data.get("sequence_parallel_degree", 1) > 1:
|
if not value:
|
||||||
if not data.get("flash_attention"):
|
value = 1
|
||||||
|
|
||||||
|
if value > 1:
|
||||||
|
if not info.data.get("flash_attention"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
return data
|
return value
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
|
|||||||
Reference in New Issue
Block a user