This commit is contained in:
Salman Mohammadi
2026-01-15 18:43:31 +00:00
parent d282f32481
commit 170dca9bb9
4 changed files with 142 additions and 0 deletions

View File

@@ -0,0 +1,55 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Gemma 3 appears to break under DDP unless unused parameters are tracked,
# so ddp_find_unused_parameters is enabled as a workaround.
ddp_find_unused_parameters: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma-3-1b-fft-dft
sequence_len: 2048
sample_packing: true
# Enable Dynamic Fine-Tuning loss
use_dynamic_finetuning: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -96,6 +96,7 @@ class PatchManager:
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_dft_loss_patch()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
@@ -157,6 +158,12 @@ class PatchManager:
else:
patch_chunked_ce_loss_fn()
def _apply_dft_loss_patch(self):
if self.cfg.use_dynamic_finetuning:
from axolotl.monkeypatch.loss.dft import patch_dft_loss_fn
patch_dft_loss_fn()
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
if self.cfg.context_parallel_size > 1 or (

View File

@@ -0,0 +1,76 @@
"""Dynamic Fine-Tuning (DFT) loss implementation"""
from typing import Optional
import torch
import torch.nn.functional as F
def selective_log_softmax(logits, index):
"""Memory-efficient log_softmax -> gather"""
if logits.dtype in [torch.float32, torch.float64]:
selected_logits = torch.gather(
logits, dim=-1, index=index.unsqueeze(-1)
).squeeze(-1)
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values
else:
per_token_logps = []
for row_logits, row_labels in zip(logits, index, strict=True):
row_logps = F.log_softmax(row_logits, dim=-1)
row_per_token_logps = row_logps.gather(
dim=-1, index=row_labels.unsqueeze(-1)
).squeeze(-1)
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
return per_token_logps
def get_dft_loss(ignore_index: int = -100):
    """Create a Dynamic Fine-Tuning (DFT) causal-LM loss function.

    DFT reweights each token's negative log-likelihood by its detached
    probability: ``-exp(logp).detach() * logp``.

    Args:
        ignore_index: label value excluded from the loss (default -100).
            Used as the inner function's default; callers of the returned
            function may still override it per call.

    Returns:
        A callable matching the ``transformers`` ``ForCausalLMLoss``
        signature.
    """
    # BUG FIX: the inner function previously declared its own
    # ``ignore_index: int = -100`` default, shadowing this factory
    # parameter — get_dft_loss(ignore_index=X) silently had no effect.
    factory_ignore_index = ignore_index

    def for_causal_lm_dft_loss(
        logits,
        labels,
        vocab_size: int = None,
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = factory_ignore_index,
        shift_labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """DFT loss: -exp(logprobs).detach() * logprobs"""
        if shift_labels is None:
            # Shift so that tokens < n predict n.
            labels = F.pad(labels, (0, 1), value=ignore_index)
            shift_labels = labels[..., 1:].contiguous()
        shift_labels = shift_labels.to(logits.device)

        # Mask ignored positions; gather needs in-range indices, so masked
        # labels are zeroed here and excluded from the sum below.
        loss_mask = shift_labels != ignore_index
        shift_labels_masked = shift_labels.clone()
        shift_labels_masked[~loss_mask] = 0

        # Per-token log-probabilities of the target tokens.
        logprobs = selective_log_softmax(logits, shift_labels_masked)

        # DFT: weight each token's NLL by its detached probability.
        per_token_loss = -logprobs.exp().detach() * logprobs

        # Sum over valid tokens and normalize. Clamp avoids a 0/0 NaN when
        # every position is masked (returns 0 loss instead).
        if num_items_in_batch is None:
            num_items_in_batch = loss_mask.sum().clamp(min=1)
        loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
        return loss

    return for_causal_lm_dft_loss
def patch_dft_loss_fn(ignore_index: int = -100):
    """Patch ``transformers`` so causal-LM models use the DFT loss.

    Replaces both the ``ForCausalLMLoss`` symbol and its entry in
    ``LOSS_MAPPING`` inside ``transformers.loss.loss_utils``.

    Args:
        ignore_index: label value excluded from the loss (default -100).
    """
    import transformers.loss.loss_utils

    dft_loss = get_dft_loss(ignore_index)
    loss_utils = transformers.loss.loss_utils
    loss_utils.ForCausalLMLoss = dft_loss
    loss_utils.LOSS_MAPPING["ForCausalLM"] = dft_loss

View File

@@ -676,6 +676,10 @@ class AxolotlInputConfig(
"description": "Number of chunks to use for chunked cross entropy loss"
},
)
# When true, a monkeypatch swaps transformers' causal-LM cross-entropy for
# the DFT objective (-exp(logp).detach() * logp). Default None leaves the
# stock loss untouched.
use_dynamic_finetuning: bool | None = Field(
    default=None,
    json_schema_extra={"description": "Enable Dynamic Fine-Tuning loss (DFT)"},
)
tiled_mlp: bool | None = Field(
default=None,