diff --git a/src/axolotl/integrations/cut_cross_entropy/args.py b/src/axolotl/integrations/cut_cross_entropy/args.py index 2729ebe2e..22852479a 100644 --- a/src/axolotl/integrations/cut_cross_entropy/args.py +++ b/src/axolotl/integrations/cut_cross_entropy/args.py @@ -41,3 +41,13 @@ class CutCrossEntropyArgs(BaseModel): ) return data + + @model_validator(mode="before") + @classmethod + def check_chunked_cross_entropy_not_set(cls, data): + if data.get("chunked_cross_entropy"): + raise ValueError( + "Cut Cross Entropy does not support chunked cross entropy. " + "Please set `chunked_cross_entropy` to `False` or disable Cut Cross Entropy." + ) + return data