Compare commits
1 Commits
accelerato
...
3181
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
939023e661 |
@@ -36,4 +36,6 @@ class DPOStrategy:
|
||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||
if cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||
if cfg.dpo_disable_output_fp32 is not None:
|
||||
training_args_kwargs["disable_output_fp32"] = cfg.dpo_disable_output_fp32
|
||||
return training_args_kwargs
|
||||
|
||||
@@ -16,3 +16,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
|
||||
dpo_norm_loss: bool | None = False
|
||||
disable_output_fp32: bool | None = False
|
||||
|
||||
@@ -54,6 +54,7 @@ class PatchManager:
|
||||
# self._apply_flex_attention_patches()
|
||||
self._apply_flash_attention_patches()
|
||||
self._apply_chunked_cross_entropy_patch()
|
||||
self._apply_dpo_disable_output_fp32_patch()
|
||||
self._apply_fsdp_patches()
|
||||
self._apply_adapter_patches()
|
||||
self._apply_model_specific_patches()
|
||||
@@ -107,6 +108,16 @@ class PatchManager:
|
||||
else:
|
||||
patch_chunked_ce_loss_fn()
|
||||
|
||||
def _apply_dpo_disable_output_fp32_patch(self):
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
if self.cfg.rl in {RLType.DPO, RLType.IPO} and self.cfg.dpo_disable_output_fp32:
|
||||
from axolotl.monkeypatch.trainer.dpo_chunked import (
|
||||
patch_dpo_disable_output_fp32,
|
||||
)
|
||||
|
||||
patch_dpo_disable_output_fp32()
|
||||
|
||||
def _apply_fsdp_patches(self):
|
||||
"""Apply patches for FSDP configurations."""
|
||||
if self.cfg.context_parallel_size > 1 or (
|
||||
|
||||
90
src/axolotl/monkeypatch/trainer/dpo_chunked.py
Normal file
90
src/axolotl/monkeypatch/trainer/dpo_chunked.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Monkeypatch helpers to reduce fp32 materialization during DPO training."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from types import MethodType
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
from trl import DPOTrainer
|
||||
|
||||
_PATCHED = False
|
||||
|
||||
|
||||
def _iter_patch_targets(model) -> Iterable[torch.nn.Module]:
|
||||
current = model
|
||||
seen: set[int] = set()
|
||||
while current is not None and id(current) not in seen:
|
||||
seen.add(id(current))
|
||||
yield current
|
||||
current = getattr(current, "module", None)
|
||||
|
||||
|
||||
def _resolve_unwrapped_forward(module):
|
||||
forward = getattr(module, "forward", None)
|
||||
if forward is None:
|
||||
return None
|
||||
|
||||
if hasattr(forward, "__wrapped__"):
|
||||
unwrapped = forward.__wrapped__
|
||||
return MethodType(unwrapped, module)
|
||||
|
||||
original = getattr(module, "_original_forward", None)
|
||||
if original is None:
|
||||
return None
|
||||
|
||||
func = original.__func__ if hasattr(original, "__func__") else original
|
||||
return MethodType(func, module)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _temporarily_disable_output_fp32(model):
|
||||
patched = []
|
||||
for target in _iter_patch_targets(model):
|
||||
replacement = _resolve_unwrapped_forward(target)
|
||||
if replacement is None:
|
||||
continue
|
||||
patched.append((target, target.forward, replacement))
|
||||
|
||||
try:
|
||||
for module, _, replacement in patched:
|
||||
module.forward = replacement
|
||||
yield
|
||||
finally:
|
||||
for module, original_forward, _ in reversed(patched):
|
||||
module.forward = original_forward
|
||||
|
||||
|
||||
def _cast_fp32_outputs(output: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
if not isinstance(output, dict):
|
||||
return output
|
||||
|
||||
for key, value in output.items():
|
||||
if torch.is_tensor(value) and value.dtype in (torch.float16, torch.bfloat16):
|
||||
output[key] = value.float()
|
||||
return output
|
||||
|
||||
|
||||
def patch_dpo_disable_output_fp32():
|
||||
"""Patch TRL's DPOTrainer to skip Accelerate's convert_to_fp32 wrapper when requested."""
|
||||
global _PATCHED
|
||||
if _PATCHED:
|
||||
return
|
||||
|
||||
original_concatenated_forward = DPOTrainer.concatenated_forward
|
||||
|
||||
def patched_concatenated_forward(self, model, batch, is_ref_model: bool = False):
|
||||
if not getattr(self.args, "disable_output_fp32", False):
|
||||
return original_concatenated_forward(
|
||||
self, model, batch, is_ref_model=is_ref_model
|
||||
)
|
||||
|
||||
with _temporarily_disable_output_fp32(model):
|
||||
result = original_concatenated_forward(
|
||||
self, model, batch, is_ref_model=is_ref_model
|
||||
)
|
||||
return _cast_fp32_outputs(result)
|
||||
|
||||
DPOTrainer.concatenated_forward = patched_concatenated_forward
|
||||
_PATCHED = True
|
||||
@@ -160,6 +160,12 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
dpo_use_logits_to_keep: bool | None = None
|
||||
dpo_disable_output_fp32: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Set to true to bypass Accelerate's automatic fp32 upcast in DPO forward passes and rely on chunked computations for lower VRAM usage."
|
||||
},
|
||||
)
|
||||
dpo_label_smoothing: float | None = None
|
||||
dpo_norm_loss: bool | None = None
|
||||
dpo_padding_free: bool | None = None
|
||||
|
||||
Reference in New Issue
Block a user