From 170dca9bb92f36cbb8dfd713e8de18dfe5e7eba7 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi <“salman.mohammadi@outlook.com”> Date: Thu, 15 Jan 2026 18:43:31 +0000 Subject: [PATCH] WIP DFT --- examples/gemma3/gemma-3-1b-fft-dft.yml | 55 +++++++++++++++++++ src/axolotl/loaders/patch_manager.py | 7 +++ src/axolotl/monkeypatch/loss/dft.py | 76 ++++++++++++++++++++++++++ src/axolotl/utils/schemas/config.py | 4 ++ 4 files changed, 142 insertions(+) create mode 100644 examples/gemma3/gemma-3-1b-fft-dft.yml create mode 100644 src/axolotl/monkeypatch/loss/dft.py diff --git a/examples/gemma3/gemma-3-1b-fft-dft.yml b/examples/gemma3/gemma-3-1b-fft-dft.yml new file mode 100644 index 000000000..9e274ef18 --- /dev/null +++ b/examples/gemma3/gemma-3-1b-fft-dft.yml @@ -0,0 +1,55 @@ +base_model: google/gemma-3-1b-it + +model_type: Gemma3ForCausalLM +cls_model_config: Gemma3TextConfig + +# gemma3 doesn't seem to play nice with ddp +ddp_find_unused_parameters: true + +chat_template: gemma3 +eot_tokens: + - +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.05 +output_dir: ./outputs/gemma-3-1b-fft-dft + +sequence_len: 2048 +sample_packing: true + +# Enable Dynamic Fine-Tuning loss +use_dynamic_finetuning: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-5 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 2 +saves_per_epoch: 1 +weight_decay: 0.0 diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index b7a53c4d5..d917fd5a6 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -96,6 +96,7 @@ class PatchManager: # self._apply_flex_attention_patches() self._apply_flash_attention_patches() self._apply_chunked_cross_entropy_patch() + self._apply_dft_loss_patch() self._apply_fsdp_patches() self._apply_adapter_patches() self._apply_model_specific_patches() @@ -157,6 +158,12 @@ class PatchManager: else: patch_chunked_ce_loss_fn() + def _apply_dft_loss_patch(self): + if self.cfg.use_dynamic_finetuning: + from axolotl.monkeypatch.loss.dft import patch_dft_loss_fn + + patch_dft_loss_fn() + def _apply_fsdp_patches(self): """Apply patches for FSDP configurations.""" if self.cfg.context_parallel_size > 1 or ( diff --git a/src/axolotl/monkeypatch/loss/dft.py b/src/axolotl/monkeypatch/loss/dft.py new file mode 100644 index 000000000..b16144427 --- /dev/null +++ b/src/axolotl/monkeypatch/loss/dft.py @@ -0,0 +1,76 @@ +"""Dynamic Fine-Tuning (DFT) loss implementation""" + +from typing import Optional + +import torch +import torch.nn.functional as F + + +def selective_log_softmax(logits, index): + """Memory-efficient log_softmax -> gather""" + if logits.dtype in [torch.float32, torch.float64]: + selected_logits = torch.gather( + logits, dim=-1, index=index.unsqueeze(-1) + ).squeeze(-1) + logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + per_token_logps = selected_logits - logsumexp_values + else: + per_token_logps = [] + for row_logits, row_labels in zip(logits, index, strict=True): + row_logps = F.log_softmax(row_logits, dim=-1) + row_per_token_logps = row_logps.gather( + dim=-1, index=row_labels.unsqueeze(-1) + ).squeeze(-1) + per_token_logps.append(row_per_token_logps) + per_token_logps = torch.stack(per_token_logps) + return per_token_logps + + +def get_dft_loss(ignore_index: int = -100): + """Creates DFT loss function""" + + def for_causal_lm_dft_loss( + logits, + labels, + vocab_size: int = None, + num_items_in_batch: Optional[int] = None, + ignore_index: int = -100, + shift_labels: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """DFT loss: -exp(logprobs).detach() * logprobs""" + if shift_labels is None: + # Shift so that tokens < n predict n + labels = F.pad(labels, (0, 1), value=ignore_index) + shift_labels = labels[..., 1:].contiguous() + + shift_labels = shift_labels.to(logits.device) + + # Create loss mask + loss_mask = shift_labels != ignore_index + shift_labels_masked = shift_labels.clone() + shift_labels_masked[~loss_mask] = 0 + + # Compute log probabilities + logprobs = selective_log_softmax(logits, shift_labels_masked) + + # DFT loss: -exp(logprobs).detach() * logprobs + per_token_loss = -logprobs.exp().detach() * logprobs + + # Sum over valid tokens and normalize + if num_items_in_batch is None: + num_items_in_batch = loss_mask.sum() + + loss = (per_token_loss * loss_mask).sum() / num_items_in_batch + return loss + + return for_causal_lm_dft_loss + + +def patch_dft_loss_fn(ignore_index: int = -100): + """Patch transformers to use DFT loss for causal LM""" + import transformers.loss.loss_utils + + for_causal_lm_dft_loss = get_dft_loss(ignore_index) + transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_dft_loss + transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = for_causal_lm_dft_loss diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index da21df7aa..a804f3a56 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -676,6 +676,10 @@ class AxolotlInputConfig( "description": "Number of chunks to use for chunked cross entropy loss" }, ) + use_dynamic_finetuning: bool | None = Field( + default=None, + json_schema_extra={"description": "Enable Dynamic Fine-Tuning loss (DFT)"}, + ) tiled_mlp: bool | None = Field( default=None,