diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 8cd9aacf5..4b40d4085 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -28,7 +28,7 @@ class DPOStrategy: training_args_kwargs["max_completion_length"] = None training_args_kwargs["max_length"] = cfg.sequence_len training_args_kwargs["max_prompt_length"] = cfg.sequence_len - training_args_kwargs["generate_during_eval"] = cfg.use_wandb + training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval if cfg.dpo_use_weighting is not None: training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting if cfg.dpo_padding_free is not None: diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 1530fabe0..44a9a4f06 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -146,6 +146,7 @@ class AxolotlInputConfig( dpo_label_smoothing: float | None = None dpo_norm_loss: bool | None = None dpo_padding_free: bool | None = None + dpo_generate_during_eval: bool | None = None datasets: ( Annotated[