Compare commits

...

2 Commits

Author     SHA1        Message                                             Date
Wing Lian  208f8b253f  add validation for DFT                              2026-01-13 09:33:04 -05:00
Wing Lian  75ad1a9932  use dynamic finetuning with chunked cross entropy   2026-01-13 09:33:04 -05:00
4 changed files with 82 additions and 11 deletions

View File

@@ -153,9 +153,12 @@ class PatchManager:
             from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn

             if self.cfg.chunked_cross_entropy_num_chunks:
-                patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
+                patch_chunked_ce_loss_fn(
+                    self.cfg.chunked_cross_entropy_num_chunks,
+                    use_dft=self.cfg.use_dynamic_finetuning,
+                )
             else:
-                patch_chunked_ce_loss_fn()
+                patch_chunked_ce_loss_fn(use_dft=self.cfg.use_dynamic_finetuning)

     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""

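The PatchManager change threads the new `use_dynamic_finetuning` flag through to the chunked cross-entropy patch. A minimal sketch of driving the same entry point directly (the flag values here are illustrative, not defaults):

    from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn

    # Mirrors the branch above: the chunk count is passed positionally only when set
    num_chunks = None  # e.g. cfg.chunked_cross_entropy_num_chunks
    if num_chunks:
        patch_chunked_ce_loss_fn(num_chunks, use_dft=True)
    else:
        patch_chunked_ce_loss_fn(use_dft=True)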
View File

@@ -16,10 +16,16 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
     For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
     """

-    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
+    def __init__(
+        self,
+        num_output_chunks: int = 8,
+        ignore_index: int = -100,
+        use_dft: bool = False,
+    ):
         super().__init__()
         self.num_output_chunks = num_output_chunks
         self.ignore_index = ignore_index
+        self.use_dft = use_dft

     def compute_cross_entropy(
         self,
@@ -30,10 +36,30 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
         """
         Upcast logits to fp32 and compute cross entropy loss.
         """
-        return F.cross_entropy(
-            logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
-        )
+        ce_loss = F.cross_entropy(
+            logits.float(), labels, ignore_index=self.ignore_index, reduction="none"
+        )
+
+        if self.use_dft:
+            # Compute probabilities and gather the ones corresponding to labels
+            with torch.no_grad():  # Stop gradient
+                probs = torch.softmax(logits.float(), dim=-1)
+                # Create mask for valid tokens (not ignore_index)
+                valid_mask = labels != self.ignore_index
+                # Gather probabilities for the correct tokens
+                label_probs = probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
+                # Apply mask to only scale valid tokens
+                label_probs = label_probs * valid_mask
+                # Avoid multiplication by 0 for ignored tokens
+                label_probs = torch.where(
+                    valid_mask, label_probs, torch.ones_like(label_probs)
+                )
+            # Scale the loss by the probability (DFT)
+            ce_loss = ce_loss * label_probs
+
+        return ce_loss.sum()

     def forward(
         self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
     ) -> torch.Tensor:
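With `reduction="none"`, `compute_cross_entropy` now keeps a per-token loss, and when `use_dft` is set each term is rescaled by the detached probability of its target token: per token the loss becomes -sg(p_y) * log(p_y), with sg(.) denoting stop-gradient. Tokens the model already predicts confidently keep most of their weight; unlikely tokens are down-weighted. A self-contained sketch of the same computation (the clamp before `gather` is an added safeguard to keep ignored labels in range; shapes and values are illustrative):

    import torch
    import torch.nn.functional as F

    IGNORE_INDEX = -100

    def dft_ce_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Per-token CE; positions labeled IGNORE_INDEX contribute 0
        ce = F.cross_entropy(
            logits.float(), labels, ignore_index=IGNORE_INDEX, reduction="none"
        )
        with torch.no_grad():  # stop-gradient: the DFT weight acts as a constant
            probs = torch.softmax(logits.float(), dim=-1)
            valid = labels != IGNORE_INDEX
            safe_labels = labels.clamp_min(0)  # keep gather indices in range
            p_y = probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
            p_y = torch.where(valid, p_y, torch.ones_like(p_y))
        return (ce * p_y).sum()

    # Toy check: 3 positions over a 5-token vocabulary, middle position ignored
    logits = torch.randn(3, 5, requires_grad=True)
    labels = torch.tensor([2, IGNORE_INDEX, 4])
    dft_ce_loss(logits, labels).backward()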
@@ -71,16 +97,20 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
         return total_loss / total_elements


-def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
-    loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
+def _build_chunked_ce_loss_fn(
+    num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
+):
+    loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index, use_dft)
     loss_fn_ce.compute_cross_entropy = torch.compile(
         loss_fn_ce.compute_cross_entropy, backend="inductor"
     )
     return loss_fn_ce


-def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
-    loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
+def get_causal_lm_loss(
+    num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
+):
+    loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index, use_dft)

     def chunked_fix_cross_entropy(
         source,
@@ -124,10 +154,14 @@ def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
     return for_causal_lm_chunked_loss


-def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
+def patch_chunked_ce_loss_fn(
+    num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
+):
     import transformers.loss.loss_utils

-    for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
+    for_causal_lm_chunked_loss = get_causal_lm_loss(
+        num_output_chunks, ignore_index, use_dft
+    )
     transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
     transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
         for_causal_lm_chunked_loss

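Because the patch replaces both the module attribute and the `LOSS_MAPPING` entry, any model whose loss dispatches through "ForCausalLM" picks up the chunked (and, if enabled, DFT-scaled) loss. A quick way to confirm the swap took effect, using only names visible in the diff above:

    from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn

    patch_chunked_ce_loss_fn(num_output_chunks=4, use_dft=True)

    import transformers.loss.loss_utils as loss_utils

    # The attribute and the dispatch table should now point at the same closure
    assert loss_utils.ForCausalLMLoss is loss_utils.LOSS_MAPPING["ForCausalLM"]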
View File

@@ -664,6 +664,13 @@ class AxolotlInputConfig(
         },
     )

+    use_dynamic_finetuning: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to use dynamic fine-tuning for scaled SFT gradients."
+        },
+    )
+
     chunked_cross_entropy: bool | None = Field(
         default=None,
         json_schema_extra={

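In a user config the new flag sits alongside the existing chunked CE option; the equivalent key/value pairs, rendered as a Python mapping (other required config keys omitted):

    cfg_overrides = {
        "use_dynamic_finetuning": True,  # scale SFT gradients by target-token probability
        "chunked_cross_entropy": True,   # required by the validator added below
    }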
View File

@@ -434,6 +434,18 @@ class TrainingValidationMixin:
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_ao_optim_fsdp2_offload(cls, data):
+        if data.get("fsdp_config") and data.get("fsdp_config", {}).get(
+            "offload_params"
+        ):
+            if data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]:
+                raise ValueError(
+                    "low-bit ao optimizers are not supported with FSDP2 w/ offload_params."
+                )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_use_reentrant_mismatch(cls, data):
@@ -557,6 +569,20 @@ class TrainingValidationMixin:
         return data


+class CELossValidationMixin:
+    """Validation methods related to CE loss configuration."""
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_dft_loss_fn(cls, data):
+        if data.get("use_dynamic_finetuning"):
+            if not data.get("chunked_cross_entropy"):
+                raise ValueError(
+                    "`use_dynamic_finetuning` requires `chunked_cross_entropy`"
+                )
+        return data
+
+
 class LoRAValidationMixin:
     """Validation methods related to LoRA/QLoRA configuration."""
@@ -1464,6 +1490,7 @@ class ValidationMixin(
     DatasetValidationMixin,
     AttentionValidationMixin,
     TrainingValidationMixin,
+    CELossValidationMixin,
     LoRAValidationMixin,
     RLValidationMixin,
     OptimizationValidationMixin,
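A standalone restatement of what `check_dft_loss_fn` enforces, runnable without the pydantic model (the helper name `_check_dft` is hypothetical):

    def _check_dft(data: dict) -> dict:
        # DFT rides on the chunked CE patch, so it cannot be enabled on its own
        if data.get("use_dynamic_finetuning") and not data.get("chunked_cross_entropy"):
            raise ValueError("`use_dynamic_finetuning` requires `chunked_cross_entropy`")
        return data

    _check_dft({"use_dynamic_finetuning": True, "chunked_cross_entropy": True})  # passes

    try:
        _check_dft({"use_dynamic_finetuning": True})
    except ValueError as err:
        print(err)  # `use_dynamic_finetuning` requires `chunked_cross_entropy`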