Compare commits

2 Commits: attn-imple ... dynamic-sf

| Author | SHA1 | Date |
|---|---|---|
|  | 208f8b253f |  |
|  | 75ad1a9932 |  |
```diff
@@ -153,9 +153,12 @@ class PatchManager:
             from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn

             if self.cfg.chunked_cross_entropy_num_chunks:
-                patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
+                patch_chunked_ce_loss_fn(
+                    self.cfg.chunked_cross_entropy_num_chunks,
+                    use_dft=self.cfg.use_dynamic_finetuning,
+                )
             else:
-                patch_chunked_ce_loss_fn()
+                patch_chunked_ce_loss_fn(use_dft=self.cfg.use_dynamic_finetuning)

     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
```
```diff
@@ -16,10 +16,16 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
     For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
     """

-    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
+    def __init__(
+        self,
+        num_output_chunks: int = 8,
+        ignore_index: int = -100,
+        use_dft: bool = False,
+    ):
         super().__init__()
         self.num_output_chunks = num_output_chunks
         self.ignore_index = ignore_index
+        self.use_dft = use_dft

     def compute_cross_entropy(
         self,
```
```diff
@@ -30,10 +36,30 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
         """
         Upcast logits to fp32 and compute cross entropy loss.
         """
-        return F.cross_entropy(
-            logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
-        )
+        ce_loss = F.cross_entropy(
+            logits.float(), labels, ignore_index=self.ignore_index, reduction="none"
+        )
+
+        if self.use_dft:
+            # Compute probabilities and gather the ones corresponding to labels
+            with torch.no_grad():  # Stop gradient
+                probs = torch.softmax(logits.float(), dim=-1)
+                # Create mask for valid tokens (not ignore_index)
+                valid_mask = labels != self.ignore_index
+                # Gather probabilities for the correct tokens
+                label_probs = probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
+                # Apply mask to only scale valid tokens
+                label_probs = label_probs * valid_mask
+                # Avoid multiplication by 0 for ignored tokens
+                label_probs = torch.where(
+                    valid_mask, label_probs, torch.ones_like(label_probs)
+                )
+
+            # Scale the loss by the probability (DFT)
+            ce_loss = ce_loss * label_probs
+
+        return ce_loss.sum()

     def forward(
         self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
     ) -> torch.Tensor:
```
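This hunk is the core of the change: with `use_dft` enabled, each token's cross-entropy term is rescaled by the model's own detached probability of the correct token, so the effective per-token loss is `-sg(p_y) * log p_y`. Below is a minimal standalone sketch of that scaling, not the patched class itself; note it clamps ignored label positions to a valid index before the `gather`, since `ignore_index` (`-100`) is out of bounds as a gather index, which the PR's code path presumably handles upstream.

```python
# Minimal standalone sketch of the DFT scaling above (not the patched class).
# Each token's CE term -log p(y) is reweighted by the detached probability
# sg(p(y)): confident tokens keep most of their gradient, low-probability
# tokens are down-weighted.
import torch
import torch.nn.functional as F


def dft_ce_loss(
    logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100
) -> torch.Tensor:
    ce = F.cross_entropy(
        logits.float(), labels, ignore_index=ignore_index, reduction="none"
    )
    with torch.no_grad():  # stop-gradient on the scaling factor
        probs = torch.softmax(logits.float(), dim=-1)
        valid = labels != ignore_index
        # Clamp ignored positions to a valid index for gather, then neutralize
        # them with a scale of 1.0 (their CE term is already 0).
        safe_labels = torch.where(valid, labels, torch.zeros_like(labels))
        label_probs = probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
        label_probs = torch.where(valid, label_probs, torch.ones_like(label_probs))
    return (ce * label_probs).sum()


logits = torch.randn(4, 10, requires_grad=True)
labels = torch.tensor([1, 2, -100, 3])
dft_ce_loss(logits, labels).backward()
```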
```diff
@@ -71,16 +97,20 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
         return total_loss / total_elements


-def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
-    loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
+def _build_chunked_ce_loss_fn(
+    num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
+):
+    loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index, use_dft)
     loss_fn_ce.compute_cross_entropy = torch.compile(
         loss_fn_ce.compute_cross_entropy, backend="inductor"
     )
     return loss_fn_ce


-def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
-    loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
+def get_causal_lm_loss(
+    num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
+):
+    loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index, use_dft)

     def chunked_fix_cross_entropy(
         source,
```
```diff
@@ -124,10 +154,14 @@ def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
     return for_causal_lm_chunked_loss


-def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
+def patch_chunked_ce_loss_fn(
+    num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
+):
     import transformers.loss.loss_utils

-    for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
+    for_causal_lm_chunked_loss = get_causal_lm_loss(
+        num_output_chunks, ignore_index, use_dft
+    )
     transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
     transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
         for_causal_lm_chunked_loss
```
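The patch swaps both the module-level `ForCausalLMLoss` and its `LOSS_MAPPING` entry, so any HF causal LM that resolves its loss through `transformers.loss.loss_utils` picks up the chunked (and optionally DFT-scaled) implementation. A usage sketch, assuming the patched axolotl module above is importable:

```python
# Apply the monkeypatch, then verify both entry points were replaced.
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn

patch_chunked_ce_loss_fn(num_output_chunks=8, use_dft=True)

import transformers.loss.loss_utils as loss_utils

# Both entry points now reference the same chunked, DFT-aware loss function.
assert loss_utils.LOSS_MAPPING["ForCausalLM"] is loss_utils.ForCausalLMLoss
```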
```diff
@@ -664,6 +664,13 @@ class AxolotlInputConfig(
         },
     )

+    use_dynamic_finetuning: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to use dynamic fine-tuning for scaled SFT gradients."
+        },
+    )
+
     chunked_cross_entropy: bool | None = Field(
         default=None,
         json_schema_extra={
```
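In a training YAML these two flags would be set together, since DFT piggybacks on the chunked loss path (see the validator added below). A minimal sketch, with all other required config keys omitted:

```python
# Minimal sketch of the two config keys, loaded as axolotl would from YAML.
# All other required keys (base_model, datasets, ...) are omitted here.
import yaml

snippet = """
chunked_cross_entropy: true
use_dynamic_finetuning: true
"""
cfg = yaml.safe_load(snippet)
assert cfg["use_dynamic_finetuning"] and cfg["chunked_cross_entropy"]
```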
```diff
@@ -434,6 +434,18 @@ class TrainingValidationMixin:

         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_ao_optim_fsdp2_offload(cls, data):
+        if data.get("fsdp_config") and data.get("fsdp_config", {}).get(
+            "offload_params"
+        ):
+            if data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]:
+                raise ValueError(
+                    "low bit ao optimizers are not supported with FSDP2 w/ offload_params."
+                )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_use_reentrant_mismatch(cls, data):
```
```diff
@@ -557,6 +569,20 @@ class TrainingValidationMixin:
         return data


+class CELossValidationMixin:
+    """Validation methods related to CE loss configuration."""
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_dft_loss_fn(cls, data):
+        if data.get("use_dynamic_finetuning"):
+            if not data.get("chunked_cross_entropy"):
+                raise ValueError(
+                    "`use_dynamic_finetuning` requires `chunked_cross_entropy`"
+                )
+        return data
+
+
 class LoRAValidationMixin:
     """Validation methods related to LoRA/QLoRA configuration."""

```
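What the new `before`-mode validator enforces, reproduced as a standalone sketch on a plain dict (the raw config mapping is what the mixin receives before field parsing):

```python
# Standalone sketch mirroring CELossValidationMixin.check_dft_loss_fn above.
def check_dft_loss_fn(data: dict) -> dict:
    if data.get("use_dynamic_finetuning") and not data.get("chunked_cross_entropy"):
        raise ValueError("`use_dynamic_finetuning` requires `chunked_cross_entropy`")
    return data


check_dft_loss_fn({"use_dynamic_finetuning": True, "chunked_cross_entropy": True})
# check_dft_loss_fn({"use_dynamic_finetuning": True})  # would raise ValueError
```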
```diff
@@ -1464,6 +1490,7 @@ class ValidationMixin(
     DatasetValidationMixin,
     AttentionValidationMixin,
     TrainingValidationMixin,
+    CELossValidationMixin,
     LoRAValidationMixin,
     RLValidationMixin,
     OptimizationValidationMixin,
```