Compare commits
1 commit
activeblue...3181

| Author | SHA1 | Date |
|---|---|---|
|  | 939023e661 |  |

```diff
@@ -36,4 +36,6 @@ class DPOStrategy:
             training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
         if cfg.dpo_use_logits_to_keep is not None:
             training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
+        if cfg.dpo_disable_output_fp32 is not None:
+            training_args_kwargs["disable_output_fp32"] = cfg.dpo_disable_output_fp32
         return training_args_kwargs
```

```diff
@@ -16,3 +16,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
     """
 
     dpo_norm_loss: bool | None = False
+    disable_output_fp32: bool | None = False
```

```diff
@@ -54,6 +54,7 @@ class PatchManager:
         # self._apply_flex_attention_patches()
         self._apply_flash_attention_patches()
         self._apply_chunked_cross_entropy_patch()
+        self._apply_dpo_disable_output_fp32_patch()
         self._apply_fsdp_patches()
         self._apply_adapter_patches()
         self._apply_model_specific_patches()
@@ -107,6 +108,16 @@ class PatchManager:
         else:
             patch_chunked_ce_loss_fn()
 
+    def _apply_dpo_disable_output_fp32_patch(self):
+        from axolotl.utils.schemas.enums import RLType
+
+        if self.cfg.rl in {RLType.DPO, RLType.IPO} and self.cfg.dpo_disable_output_fp32:
+            from axolotl.monkeypatch.trainer.dpo_chunked import (
+                patch_dpo_disable_output_fp32,
+            )
+
+            patch_dpo_disable_output_fp32()
+
     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
         if self.cfg.context_parallel_size > 1 or (
```

src/axolotl/monkeypatch/trainer/dpo_chunked.py (new file, 90 lines)

```diff
@@ -0,0 +1,90 @@
+"""Monkeypatch helpers to reduce fp32 materialization during DPO training."""
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from types import MethodType
+from typing import Iterable
+
+import torch
+from trl import DPOTrainer
+
+_PATCHED = False
+
+
+def _iter_patch_targets(model) -> Iterable[torch.nn.Module]:
+    current = model
+    seen: set[int] = set()
+    while current is not None and id(current) not in seen:
+        seen.add(id(current))
+        yield current
+        current = getattr(current, "module", None)
+
+
+def _resolve_unwrapped_forward(module):
+    forward = getattr(module, "forward", None)
+    if forward is None:
+        return None
+
+    if hasattr(forward, "__wrapped__"):
+        unwrapped = forward.__wrapped__
+        return MethodType(unwrapped, module)
+
+    original = getattr(module, "_original_forward", None)
+    if original is None:
+        return None
+
+    func = original.__func__ if hasattr(original, "__func__") else original
+    return MethodType(func, module)
+
+
+@contextmanager
+def _temporarily_disable_output_fp32(model):
+    patched = []
+    for target in _iter_patch_targets(model):
+        replacement = _resolve_unwrapped_forward(target)
+        if replacement is None:
+            continue
+        patched.append((target, target.forward, replacement))
+
+    try:
+        for module, _, replacement in patched:
+            module.forward = replacement
+        yield
+    finally:
+        for module, original_forward, _ in reversed(patched):
+            module.forward = original_forward
+
+
+def _cast_fp32_outputs(output: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+    if not isinstance(output, dict):
+        return output
+
+    for key, value in output.items():
+        if torch.is_tensor(value) and value.dtype in (torch.float16, torch.bfloat16):
+            output[key] = value.float()
+    return output
+
+
+def patch_dpo_disable_output_fp32():
+    """Patch TRL's DPOTrainer to skip Accelerate's convert_to_fp32 wrapper when requested."""
+    global _PATCHED
+    if _PATCHED:
+        return
+
+    original_concatenated_forward = DPOTrainer.concatenated_forward
+
+    def patched_concatenated_forward(self, model, batch, is_ref_model: bool = False):
+        if not getattr(self.args, "disable_output_fp32", False):
+            return original_concatenated_forward(
+                self, model, batch, is_ref_model=is_ref_model
+            )
+
+        with _temporarily_disable_output_fp32(model):
+            result = original_concatenated_forward(
+                self, model, batch, is_ref_model=is_ref_model
+            )
+        return _cast_fp32_outputs(result)
+
+    DPOTrainer.concatenated_forward = patched_concatenated_forward
+    _PATCHED = True
```
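
A note on the technique in this new module (not part of the diff itself): when Accelerate prepares a model under mixed precision it rebinds `forward` to a wrapper that upcasts outputs to fp32, keeping the original callable reachable via `functools.wraps` (`__wrapped__`) or an `_original_forward` attribute. The context manager above temporarily rebinds the unwrapped function so the chunked DPO loss sees half-precision logits, then restores the wrapper and re-casts only the final outputs. Below is a minimal stand-alone sketch of that swap-and-restore pattern; `ToyModule` and `wrap_fp32` are made-up stand-ins (the latter for Accelerate's `convert_outputs_to_fp32`), not library APIs.

```python
# Stand-alone sketch (not Axolotl/TRL code): temporarily rebinding a module's
# `forward` to the callable stored on `__wrapped__` skips an fp32-upcasting
# wrapper, and the wrapper is restored afterwards.
import functools
from contextlib import contextmanager
from types import MethodType

import torch


class ToyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2  # stays in the input dtype (e.g. bf16)


def wrap_fp32(fn):
    # Rough stand-in for Accelerate's convert_outputs_to_fp32: upcast outputs
    # and expose the original function via functools.wraps -> __wrapped__.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs).float()

    return wrapper


@contextmanager
def bypass_fp32(module: torch.nn.Module):
    # Same idea as _temporarily_disable_output_fp32 above, for a single module.
    original = module.forward
    unwrapped = getattr(original, "__wrapped__", None)
    if unwrapped is not None:
        module.forward = MethodType(unwrapped, module)
    try:
        yield
    finally:
        module.forward = original  # always put the fp32 wrapper back


model = ToyModule()
# Mimic what accelerator.prepare() does for mixed precision: wrap the unbound
# forward and rebind it on the instance.
model.forward = MethodType(wrap_fp32(ToyModule.forward), model)

x = torch.ones(2, dtype=torch.bfloat16)
print(model.forward(x).dtype)      # torch.float32 (wrapper upcasts)
with bypass_fp32(model):
    print(model.forward(x).dtype)  # torch.bfloat16 (wrapper bypassed)
print(model.forward(x).dtype)      # torch.float32 (wrapper restored)
```

The restore in the `finally` block matters: `concatenated_forward` can raise mid-batch (for example on OOM), and the wrapper has to be put back so later forward passes still return fp32 as the rest of TRL expects.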

```diff
@@ -160,6 +160,12 @@ class AxolotlInputConfig(
         },
     )
     dpo_use_logits_to_keep: bool | None = None
+    dpo_disable_output_fp32: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Set to true to bypass Accelerate's automatic fp32 upcast in DPO forward passes and rely on chunked computations for lower VRAM usage."
+        },
+    )
     dpo_label_smoothing: float | None = None
     dpo_norm_loss: bool | None = None
     dpo_padding_free: bool | None = None
```
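
As a usage note (not part of the diff): in an Axolotl config the new flag would presumably sit alongside the other `dpo_*` options. The sketch below writes such a config as the Python dict its YAML would parse into and mirrors the `DPOStrategy` plumbing from the first hunk; only `dpo_disable_output_fp32` comes from this change, the neighbouring keys are illustrative.

```python
# Hypothetical excerpt of an Axolotl DPO config, expressed as the Python dict
# the YAML would parse into. Only dpo_disable_output_fp32 is introduced by
# this diff; the other keys are just for context.
cfg = {
    "rl": "dpo",
    "dpo_use_logits_to_keep": True,
    "dpo_disable_output_fp32": True,  # skip Accelerate's fp32 upcast in DPO forward passes
}

# Mirrors the DPOStrategy plumbing in the first hunk: the kwarg is only
# forwarded to the DPO training args when the flag is explicitly set.
training_args_kwargs = {}
if cfg.get("dpo_disable_output_fp32") is not None:
    training_args_kwargs["disable_output_fp32"] = cfg["dpo_disable_output_fp32"]

print(training_args_kwargs)  # {'disable_output_fp32': True}
```

Because the schema default is `None`, leaving the key out keeps the existing behaviour, with Accelerate's fp32 upcast left in place.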