using field validator instead of model validator
This commit is contained in:
@@ -1104,16 +1104,19 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@field_validator("sequence_parallel_degree", mode="before")
|
||||
@classmethod
|
||||
def check_sequence_parallel_config(cls, data):
|
||||
if data.get("sequence_parallel_degree", 1) > 1:
|
||||
if not data.get("flash_attention"):
|
||||
def check_sequence_parallel_config(cls, value, info):
|
||||
if not value:
|
||||
value = 1
|
||||
|
||||
if value > 1:
|
||||
if not info.data.get("flash_attention"):
|
||||
raise ValueError(
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
)
|
||||
|
||||
return data
|
||||
return value
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
|
||||
Reference in New Issue
Block a user