Compare commits


1 commit    main...3181

Author          SHA1          Message             Date
Dan Saunders    939023e661    chunked DPO loss    2025-09-24 17:43:06 -04:00
5 changed files with 110 additions and 0 deletions


@@ -36,4 +36,6 @@ class DPOStrategy:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
if cfg.dpo_disable_output_fp32 is not None:
training_args_kwargs["disable_output_fp32"] = cfg.dpo_disable_output_fp32
return training_args_kwargs


@@ -16,3 +16,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
dpo_norm_loss: bool | None = False
disable_output_fp32: bool | None = False
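
The new disable_output_fp32 field is what connects the YAML flag to the runtime check: dpo_disable_output_fp32 from the user config is copied onto the DPO training args (first hunk above), and the monkeypatch further down only ever reads it back with getattr(self.args, "disable_output_fp32", False). A minimal sketch of that flow, with a toy stand-in for the training-args dataclass (ToyDPOArgs is illustrative, not axolotl's or TRL's API):

from dataclasses import dataclass


@dataclass
class ToyDPOArgs:  # stand-in for the DPOConfig subclass above
    disable_output_fp32: bool | None = False


# What the DPOStrategy hunk does with cfg.dpo_disable_output_fp32:
args = ToyDPOArgs(disable_output_fp32=True)

# What the patched concatenated_forward checks before skipping the upcast:
if getattr(args, "disable_output_fp32", False):
    print("skipping Accelerate's fp32 output upcast")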


@@ -54,6 +54,7 @@ class PatchManager:
        # self._apply_flex_attention_patches()
        self._apply_flash_attention_patches()
        self._apply_chunked_cross_entropy_patch()
        self._apply_dpo_disable_output_fp32_patch()
        self._apply_fsdp_patches()
        self._apply_adapter_patches()
        self._apply_model_specific_patches()
@@ -107,6 +108,16 @@ class PatchManager:
            else:
                patch_chunked_ce_loss_fn()

    def _apply_dpo_disable_output_fp32_patch(self):
        from axolotl.utils.schemas.enums import RLType

        if self.cfg.rl in {RLType.DPO, RLType.IPO} and self.cfg.dpo_disable_output_fp32:
            from axolotl.monkeypatch.trainer.dpo_chunked import (
                patch_dpo_disable_output_fp32,
            )

            patch_dpo_disable_output_fp32()

    def _apply_fsdp_patches(self):
        """Apply patches for FSDP configurations."""
        if self.cfg.context_parallel_size > 1 or (
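
The patch itself (new file below) swaps out DPOTrainer.concatenated_forward at the class level, so applying it once here in PatchManager, before any trainer exists, is sufficient: instances resolve the method on the class at call time. A toy illustration of that mechanism (names are illustrative, not TRL's API):

class ToyTrainer:  # stand-in for trl.DPOTrainer
    def concatenated_forward(self, model, batch):
        return "original (fp32 outputs)"


def patched_concatenated_forward(self, model, batch):
    return "patched (bf16 outputs)"


# Patch the class once, up front, as _apply_dpo_disable_output_fp32_patch does.
ToyTrainer.concatenated_forward = patched_concatenated_forward

trainer = ToyTrainer()  # constructed later, still picks up the patch
print(trainer.concatenated_forward(None, None))  # patched (bf16 outputs)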


@@ -0,0 +1,90 @@
"""Monkeypatch helpers to reduce fp32 materialization during DPO training."""
from __future__ import annotations
from contextlib import contextmanager
from types import MethodType
from typing import Iterable
import torch
from trl import DPOTrainer
_PATCHED = False
def _iter_patch_targets(model) -> Iterable[torch.nn.Module]:
current = model
seen: set[int] = set()
while current is not None and id(current) not in seen:
seen.add(id(current))
yield current
current = getattr(current, "module", None)
def _resolve_unwrapped_forward(module):
forward = getattr(module, "forward", None)
if forward is None:
return None
if hasattr(forward, "__wrapped__"):
unwrapped = forward.__wrapped__
return MethodType(unwrapped, module)
original = getattr(module, "_original_forward", None)
if original is None:
return None
func = original.__func__ if hasattr(original, "__func__") else original
return MethodType(func, module)
@contextmanager
def _temporarily_disable_output_fp32(model):
patched = []
for target in _iter_patch_targets(model):
replacement = _resolve_unwrapped_forward(target)
if replacement is None:
continue
patched.append((target, target.forward, replacement))
try:
for module, _, replacement in patched:
module.forward = replacement
yield
finally:
for module, original_forward, _ in reversed(patched):
module.forward = original_forward
def _cast_fp32_outputs(output: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
if not isinstance(output, dict):
return output
for key, value in output.items():
if torch.is_tensor(value) and value.dtype in (torch.float16, torch.bfloat16):
output[key] = value.float()
return output
def patch_dpo_disable_output_fp32():
"""Patch TRL's DPOTrainer to skip Accelerate's convert_to_fp32 wrapper when requested."""
global _PATCHED
if _PATCHED:
return
original_concatenated_forward = DPOTrainer.concatenated_forward
def patched_concatenated_forward(self, model, batch, is_ref_model: bool = False):
if not getattr(self.args, "disable_output_fp32", False):
return original_concatenated_forward(
self, model, batch, is_ref_model=is_ref_model
)
with _temporarily_disable_output_fp32(model):
result = original_concatenated_forward(
self, model, batch, is_ref_model=is_ref_model
)
return _cast_fp32_outputs(result)
DPOTrainer.concatenated_forward = patched_concatenated_forward
_PATCHED = True
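
To make the mechanism above concrete, here is a self-contained toy sketch of the wrap/unwrap pattern it relies on: an Accelerate-style mixed-precision wrapper upcasts every output tensor of forward to fp32, and temporarily restoring the stored original forward keeps the large logits tensor in bf16, with only the small reduced outputs cast back afterwards. ToyModel and wrap_outputs_fp32 are illustrative assumptions, not Accelerate's actual code:

import functools
from types import MethodType

import torch


class ToyModel(torch.nn.Module):
    """Toy stand-in for a causal LM whose logits dominate activation memory."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 32000, dtype=torch.bfloat16)

    def forward(self, x):
        return {"logits": self.proj(x)}


def wrap_outputs_fp32(model):
    """Mimic Accelerate's prep: remember the original forward, wrap it to upcast outputs."""
    original = model.forward
    model._original_forward = original  # the attribute the patch falls back to

    @functools.wraps(original.__func__)
    def fp32_forward(self, *args, **kwargs):
        out = original(*args, **kwargs)
        return {k: v.float() for k, v in out.items()}

    model.forward = MethodType(fp32_forward, model)


model = ToyModel()
wrap_outputs_fp32(model)
x = torch.randn(4, 8, dtype=torch.bfloat16)

print(model(x)["logits"].dtype)  # torch.float32: the full-vocab logits got upcast

# Swap the stored original back in, as _temporarily_disable_output_fp32 does.
model.forward = MethodType(model._original_forward.__func__, model)
print(model(x)["logits"].dtype)  # torch.bfloat16: half the memory for the logits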


@@ -160,6 +160,12 @@ class AxolotlInputConfig(
        },
    )
    dpo_use_logits_to_keep: bool | None = None
    dpo_disable_output_fp32: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Set to true to bypass Accelerate's automatic fp32 upcast in DPO forward passes and rely on chunked computations for lower VRAM usage."
        },
    )
    dpo_label_smoothing: float | None = None
    dpo_norm_loss: bool | None = None
    dpo_padding_free: bool | None = None
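
For a rough sense of what the flag saves, a back-of-the-envelope sketch under assumed shapes (4 preference pairs per device, 2048-token sequences, a 32k vocabulary; real numbers depend on model and batch geometry):

# Assumed shapes: 4 preference pairs per device, 2048-token sequences, 32k vocab.
batch_pairs, seq_len, vocab = 4, 2048, 32_000

# TRL's DPO concatenated forward stacks chosen and rejected sequences, so the
# logits tensor is [2 * batch_pairs, seq_len, vocab].
elements = 2 * batch_pairs * seq_len * vocab

bf16_gib = elements * 2 / 1024**3  # 2 bytes per bf16 element
fp32_gib = elements * 4 / 1024**3  # 4 bytes per fp32 element
print(f"logits in bf16: {bf16_gib:.2f} GiB, upcast to fp32: {fp32_gib:.2f} GiB")
# ~0.98 GiB vs ~1.95 GiB per policy forward, before the reference-model pass.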