diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 3aa79c484..5e160e692 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -36,4 +36,6 @@ class DPOStrategy: training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss if cfg.dpo_use_logits_to_keep is not None: training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep + if cfg.dpo_use_liger_kernel is not None: + training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel return training_args_kwargs diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index c9b087ea3..bd6a61177 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -173,6 +173,12 @@ class AxolotlInputConfig( dpo_use_logits_to_keep: bool | None = None dpo_label_smoothing: float | None = None dpo_norm_loss: bool | None = None + + dpo_use_liger_kernel: bool | None = Field( + default=None, + json_schema_extra={"description": "Whether to use Liger kernel for DPO loss."}, + ) + dpo_padding_free: bool | None = None dpo_generate_during_eval: bool | None = None