From 939023e6616644103343c30381eff5b6aeeda086 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 24 Sep 2025 17:43:06 -0400 Subject: [PATCH] chunked DPO loss --- src/axolotl/core/trainers/dpo/__init__.py | 2 + src/axolotl/core/trainers/dpo/args.py | 1 + src/axolotl/loaders/patch_manager.py | 11 +++ .../monkeypatch/trainer/dpo_chunked.py | 90 +++++++++++++++++++ src/axolotl/utils/schemas/config.py | 6 ++ 5 files changed, 110 insertions(+) create mode 100644 src/axolotl/monkeypatch/trainer/dpo_chunked.py diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 3aa79c484..1a59c5c13 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -36,4 +36,6 @@ class DPOStrategy: training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss if cfg.dpo_use_logits_to_keep is not None: training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep + if cfg.dpo_disable_output_fp32 is not None: + training_args_kwargs["disable_output_fp32"] = cfg.dpo_disable_output_fp32 return training_args_kwargs diff --git a/src/axolotl/core/trainers/dpo/args.py b/src/axolotl/core/trainers/dpo/args.py index b1e53236e..4c412118a 100644 --- a/src/axolotl/core/trainers/dpo/args.py +++ b/src/axolotl/core/trainers/dpo/args.py @@ -16,3 +16,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): """ dpo_norm_loss: bool | None = False + disable_output_fp32: bool | None = False diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 3d4b7b96b..addd01f4b 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -54,6 +54,7 @@ class PatchManager: # self._apply_flex_attention_patches() self._apply_flash_attention_patches() self._apply_chunked_cross_entropy_patch() + self._apply_dpo_disable_output_fp32_patch() self._apply_fsdp_patches() self._apply_adapter_patches() self._apply_model_specific_patches() @@ -107,6 +108,16 @@ class PatchManager: else: patch_chunked_ce_loss_fn() + def _apply_dpo_disable_output_fp32_patch(self): + from axolotl.utils.schemas.enums import RLType + + if self.cfg.rl in {RLType.DPO, RLType.IPO} and self.cfg.dpo_disable_output_fp32: + from axolotl.monkeypatch.trainer.dpo_chunked import ( + patch_dpo_disable_output_fp32, + ) + + patch_dpo_disable_output_fp32() + def _apply_fsdp_patches(self): """Apply patches for FSDP configurations.""" if self.cfg.context_parallel_size > 1 or ( diff --git a/src/axolotl/monkeypatch/trainer/dpo_chunked.py b/src/axolotl/monkeypatch/trainer/dpo_chunked.py new file mode 100644 index 000000000..9e79f41fc --- /dev/null +++ b/src/axolotl/monkeypatch/trainer/dpo_chunked.py @@ -0,0 +1,90 @@ +"""Monkeypatch helpers to reduce fp32 materialization during DPO training.""" + +from __future__ import annotations + +from contextlib import contextmanager +from types import MethodType +from typing import Iterable + +import torch +from trl import DPOTrainer + +_PATCHED = False + + +def _iter_patch_targets(model) -> Iterable[torch.nn.Module]: + current = model + seen: set[int] = set() + while current is not None and id(current) not in seen: + seen.add(id(current)) + yield current + current = getattr(current, "module", None) + + +def _resolve_unwrapped_forward(module): + forward = getattr(module, "forward", None) + if forward is None: + return None + + if hasattr(forward, "__wrapped__"): + unwrapped = forward.__wrapped__ + return MethodType(unwrapped, module) + + original = getattr(module, "_original_forward", None) + if original is None: + return None + + func = original.__func__ if hasattr(original, "__func__") else original + return MethodType(func, module) + + +@contextmanager +def _temporarily_disable_output_fp32(model): + patched = [] + for target in _iter_patch_targets(model): + replacement = _resolve_unwrapped_forward(target) + if replacement is None: + continue + patched.append((target, target.forward, replacement)) + + try: + for module, _, replacement in patched: + module.forward = replacement + yield + finally: + for module, original_forward, _ in reversed(patched): + module.forward = original_forward + + +def _cast_fp32_outputs(output: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + if not isinstance(output, dict): + return output + + for key, value in output.items(): + if torch.is_tensor(value) and value.dtype in (torch.float16, torch.bfloat16): + output[key] = value.float() + return output + + +def patch_dpo_disable_output_fp32(): + """Patch TRL's DPOTrainer to skip Accelerate's convert_to_fp32 wrapper when requested.""" + global _PATCHED + if _PATCHED: + return + + original_concatenated_forward = DPOTrainer.concatenated_forward + + def patched_concatenated_forward(self, model, batch, is_ref_model: bool = False): + if not getattr(self.args, "disable_output_fp32", False): + return original_concatenated_forward( + self, model, batch, is_ref_model=is_ref_model + ) + + with _temporarily_disable_output_fp32(model): + result = original_concatenated_forward( + self, model, batch, is_ref_model=is_ref_model + ) + return _cast_fp32_outputs(result) + + DPOTrainer.concatenated_forward = patched_concatenated_forward + _PATCHED = True diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 0177b19f6..18b689205 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -160,6 +160,12 @@ class AxolotlInputConfig( }, ) dpo_use_logits_to_keep: bool | None = None + dpo_disable_output_fp32: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Set to true to bypass Accelerate's automatic fp32 upcast in DPO forward passes and rely on chunked computations for lower VRAM usage." + }, + ) dpo_label_smoothing: float | None = None dpo_norm_loss: bool | None = None dpo_padding_free: bool | None = None