Compare commits


2 Commits

Author  SHA1  Message  Date
Wing Lian  208f8b253f  add validation for DFT  2026-01-13 09:33:04 -05:00
Wing Lian  75ad1a9932  use dynamic finetuning with chunked cross entropy  2026-01-13 09:33:04 -05:00
18 changed files with 114 additions and 268 deletions
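For orientation, both DFT implementations touched below — the new `use_dft` branch in the chunked cross-entropy patch and the deleted standalone `dft.py` — compute the same per-token objective: cross entropy scaled by a stop-gradient copy of the model's probability of the target token,

$$
\mathcal{L}_{\mathrm{DFT}} \;=\; -\sum_{t}\operatorname{sg}\!\left[p_\theta(y_t \mid y_{<t})\right]\,\log p_\theta(y_t \mid y_{<t}),
$$

where $\operatorname{sg}[\cdot]$ is stop-gradient (`.detach()` in the code) and the sum runs over tokens whose label is not `ignore_index`.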

View File

@@ -15,11 +15,6 @@
<!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. -->
## AI Usage Disclaimer
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
## Screenshots (if appropriate)
## Types of changes

View File

@@ -21,8 +21,6 @@ jobs:
timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
@@ -34,7 +32,6 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -42,7 +39,6 @@ jobs:
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -50,7 +46,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -58,7 +53,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -85,7 +79,7 @@ jobs:
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
if: ${{ github.event_name != 'pull_request' && secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -96,7 +90,7 @@ jobs:
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
platforms: linux/amd64,linux/arm64
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
@@ -111,8 +105,6 @@ jobs:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
@@ -124,7 +116,6 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -132,7 +123,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -140,7 +130,6 @@ jobs:
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -148,7 +137,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -160,7 +148,6 @@ jobs:
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -171,7 +158,6 @@ jobs:
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -20,26 +20,22 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
platforms: "linux/amd64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -65,7 +61,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
platforms: linux/amd64,linux/arm64
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
@@ -92,26 +88,22 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
platforms: "linux/amd64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -136,7 +128,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
platforms: linux/amd64,linux/arm64
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
@@ -157,11 +149,11 @@ jobs:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.8.0
axolotl_extras:
is_latest: true
- cuda: 130
cuda_version: 13.0.0
is_latest:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:

View File

@@ -6,7 +6,6 @@ ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -21,17 +20,13 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge
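
Net effect of this hunk: the `TARGETARCH` branch and its arm64-specific `BASE_EXTRAS` (which omitted `deepspeed`) are gone; every build is back to a single `pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,...]`, with `AXOLOTL_EXTRAS` appended in brackets when non-empty.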

View File

@@ -2,7 +2,6 @@ ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
@@ -32,35 +31,20 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$TARGETARCH" = "amd64" ]; then \
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
fi
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$TARGETARCH" = "amd64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi \
elif [ "$TARGETARCH" = "arm64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
fi \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi \
;; \
esac
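
Likewise, the `TARGETARCH`-gated wheel selection is collapsed: for PyTorch 2.9.x only the x86_64 `flash_attn` wheels from the v0.5.4 prebuild release are installed, and the aarch64 wheels from the v0.6.4 release are removed.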

View File

@@ -1,53 +0,0 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma-3-1b-fft-dft
sequence_len: 2048
use_dynamic_finetuning: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048

View File

@@ -2,7 +2,6 @@ base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
@@ -33,8 +32,8 @@ sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_linear: true
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -31,7 +31,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:

View File

@@ -373,11 +373,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple
if self.cfg.use_dynamic_finetuning:
from axolotl.monkeypatch.loss.dft import dft_loss
trainer_kwargs["compute_loss_func"] = dft_loss
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
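
(The Trainer-level `compute_loss_func` hook removed here is superseded by threading `use_dft` through the chunked cross-entropy patch, shown in the next two files.)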

View File

@@ -153,9 +153,12 @@ class PatchManager:
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
if self.cfg.chunked_cross_entropy_num_chunks:
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
patch_chunked_ce_loss_fn(
self.cfg.chunked_cross_entropy_num_chunks,
use_dft=self.cfg.use_dynamic_finetuning,
)
else:
patch_chunked_ce_loss_fn()
patch_chunked_ce_loss_fn(use_dft=self.cfg.use_dynamic_finetuning)
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""

View File

@@ -5,7 +5,6 @@ from typing import Type
import addict
import torch
import transformers
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault
@@ -154,9 +153,6 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
This function determines the appropriate model config source, loads it, applies any
necessary overrides, and validates it for compatibility with the `axolotl` config.
If `cfg.cls_model_config` is set, a custom config class from transformers will be
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -178,13 +174,8 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
if cfg.num_labels:
# num_labels is used to initialize classifier models
config_kwargs["num_labels"] = cfg.num_labels
config_cls = AutoConfig
if cfg.cls_model_config:
config_cls = getattr(transformers, cfg.cls_model_config)
try:
model_config = config_cls.from_pretrained(
model_config = AutoConfig.from_pretrained(
model_config_name,
trust_remote_code=trust_remote_code,
**config_kwargs,

View File

@@ -16,10 +16,16 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
"""
def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
def __init__(
self,
num_output_chunks: int = 8,
ignore_index: int = -100,
use_dft: bool = False,
):
super().__init__()
self.num_output_chunks = num_output_chunks
self.ignore_index = ignore_index
self.use_dft = use_dft
def compute_cross_entropy(
self,
@@ -30,10 +36,30 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
"""
Upcast logits to fp32 and compute cross entropy loss.
"""
return F.cross_entropy(
logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
ce_loss = F.cross_entropy(
logits.float(), labels, ignore_index=self.ignore_index, reduction="none"
)
if self.use_dft:
# Compute probabilities and gather the ones corresponding to labels
with torch.no_grad(): # Stop gradient
probs = torch.softmax(logits.float(), dim=-1)
# Create mask for valid tokens (not ignore_index)
valid_mask = labels != self.ignore_index
# Gather probabilities for the correct tokens
label_probs = probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
# Apply mask to only scale valid tokens
label_probs = label_probs * valid_mask
# Avoid multiplication by 0 for ignored tokens
label_probs = torch.where(
valid_mask, label_probs, torch.ones_like(label_probs)
)
# Scale the loss by the probability (DFT)
ce_loss = ce_loss * label_probs
return ce_loss.sum()
def forward(
self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
) -> torch.Tensor:
@@ -71,16 +97,20 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
return total_loss / total_elements
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
def _build_chunked_ce_loss_fn(
num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index, use_dft)
loss_fn_ce.compute_cross_entropy = torch.compile(
loss_fn_ce.compute_cross_entropy, backend="inductor"
)
return loss_fn_ce
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
def get_causal_lm_loss(
num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index, use_dft)
def chunked_fix_cross_entropy(
source,
@@ -124,10 +154,14 @@ def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
return for_causal_lm_chunked_loss
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
def patch_chunked_ce_loss_fn(
num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
import transformers.loss.loss_utils
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
for_causal_lm_chunked_loss = get_causal_lm_loss(
num_output_chunks, ignore_index, use_dft
)
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
for_causal_lm_chunked_loss
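
A minimal sketch (synthetic tensors; shapes and tolerances assumed, nothing here is repo code) to convince yourself that the `use_dft` scaling above and the `-exp(logprobs).detach() * logprobs` form in the deleted `dft.py` below are the same loss with the same gradients:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)  # (num_tokens, vocab_size)
labels = torch.randint(0, 10, (4,))

# Form 1: chunked CE path above — per-token CE scaled by detached label probability
ce = F.cross_entropy(logits, labels, reduction="none")
with torch.no_grad():  # stop-gradient on the scaling factor
    label_probs = (
        torch.softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    )
loss1 = (ce * label_probs).sum()

# Form 2: deleted dft.py — -exp(logp).detach() * logp
logp = F.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
loss2 = (-logp.exp().detach() * logp).sum()

assert torch.allclose(loss1, loss2, atol=1e-5)

# Gradients match too, since the probability factor carries no gradient either way
g1 = torch.autograd.grad(loss1, logits, retain_graph=True)[0]
g2 = torch.autograd.grad(loss2, logits)[0]
assert torch.allclose(g1, g2, atol=1e-5)
```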

View File

@@ -1,98 +0,0 @@
"""Dynamic Fine-Tuning (DFT) loss implementation"""
from typing import Optional
import torch
import torch.nn.functional as F
def selective_log_softmax(logits, index):
"""Memory-efficient log_softmax -> gather"""
if logits.dtype in [torch.float32, torch.float64]:
selected_logits = torch.gather(
logits, dim=-1, index=index.unsqueeze(-1)
).squeeze(-1)
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values
else:
per_token_logps = []
for row_logits, row_labels in zip(logits, index, strict=True):
row_logps = F.log_softmax(row_logits, dim=-1)
row_per_token_logps = row_logps.gather(
dim=-1, index=row_labels.unsqueeze(-1)
).squeeze(-1)
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
return per_token_logps
def get_dft_loss(ignore_index: int = -100):
"""Creates DFT loss function"""
def for_causal_lm_dft_loss(
logits,
labels,
vocab_size: int = None,
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
shift_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""DFT loss: -exp(logprobs).detach() * logprobs"""
if shift_labels is None:
# Shift so that tokens < n predict n
labels = F.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.to(logits.device)
# Create loss mask
loss_mask = shift_labels != ignore_index
shift_labels_masked = shift_labels.clone()
shift_labels_masked[~loss_mask] = 0
# Compute log probabilities
logprobs = selective_log_softmax(logits, shift_labels_masked)
# DFT loss: -exp(logprobs).detach() * logprobs
per_token_loss = -logprobs.exp().detach() * logprobs
# Sum over valid tokens and normalize
if num_items_in_batch is None:
num_items_in_batch = loss_mask.sum()
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
return loss
return for_causal_lm_dft_loss
def dft_loss(outputs, labels, num_items_in_batch=None):
"""DFT loss compatible with Trainer.compute_loss_func signature.
This function is designed to be passed to Trainer's compute_loss_func parameter.
"""
ignore_index = -100
# Shift labels for causal LM
labels = F.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.to(outputs.logits.device)
# Create loss mask
loss_mask = shift_labels != ignore_index
shift_labels_masked = shift_labels.clone()
shift_labels_masked[~loss_mask] = 0
# Compute log probabilities
logprobs = selective_log_softmax(outputs.logits, shift_labels_masked)
# DFT loss: -exp(logprobs).detach() * logprobs
per_token_loss = -logprobs.exp().detach() * logprobs
# Sum over valid tokens and normalize
if num_items_in_batch is None:
num_items_in_batch = loss_mask.sum()
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
return loss
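
The equivalence holds because the probability factor is detached in both forms: each valid token contributes the gradient

$$
\nabla_\theta \mathcal{L}_t \;=\; -\,p_\theta(y_t \mid y_{<t})\,\nabla_\theta \log p_\theta(y_t \mid y_{<t}),
$$

i.e. the ordinary cross-entropy gradient reweighted by the model's confidence in the target token.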

View File

@@ -664,6 +664,13 @@ class AxolotlInputConfig(
},
)
use_dynamic_finetuning: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use dynamic fine-tuning for scaled SFT gradients."
},
)
chunked_cross_entropy: bool | None = Field(
default=None,
json_schema_extra={
@@ -676,10 +683,6 @@ class AxolotlInputConfig(
"description": "Number of chunks to use for chunked cross entropy loss"
},
)
use_dynamic_finetuning: bool | None = Field(
default=None,
json_schema_extra={"description": "Enable Dynamic Fine-Tuning loss (DFT)"},
)
tiled_mlp: bool | None = Field(
default=None,

View File

@@ -25,12 +25,7 @@ class ModelInputConfig(BaseModel):
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
},
)
cls_model_config: str | None = Field(
default=None,
json_schema_extra={
"description": "transformers config class (e.g., 'LlamaConfig', 'MistralConfig'). Defaults to AutoConfig."
},
)
cls_model_config: str | None = None
tokenizer_config: str | None = Field(
default=None,
json_schema_extra={

View File

@@ -434,6 +434,18 @@ class TrainingValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_ao_optim_fsdp2_offload(cls, data):
if data.get("fsdp_config") and data.get("fsdp_config", {}).get(
"offload_params"
):
if data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]:
raise ValueError(
"low bit ao optimizers is not supported with FSDP2 w/ offload_params."
)
return data
@model_validator(mode="before")
@classmethod
def check_use_reentrant_mismatch(cls, data):
@@ -557,6 +569,20 @@ class TrainingValidationMixin:
return data
class CELossValidationMixin:
"""Validation methods related to CE loss configuration."""
@model_validator(mode="before")
@classmethod
def check_dft_loss_fn(cls, data):
if data.get("use_dynamic_finetuning"):
if not data.get("chunked_cross_entropy"):
raise ValueError(
"`use_dynamic_finetuning` requires `chunked_cross_entropy`"
)
return data
class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""
@@ -1464,6 +1490,7 @@ class ValidationMixin(
DatasetValidationMixin,
AttentionValidationMixin,
TrainingValidationMixin,
CELossValidationMixin,
LoRAValidationMixin,
RLValidationMixin,
OptimizationValidationMixin,
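
With `check_dft_loss_fn` in place, enabling DFT without chunked cross entropy fails fast at config validation. A minimal hypothetical config fragment (field names from this compare; values illustrative):

```yaml
# DFT piggybacks on the patched chunked CE loss, so both flags are needed
use_dynamic_finetuning: true
chunked_cross_entropy: true
chunked_cross_entropy_num_chunks: 8  # optional; the patch defaults to 8 chunks
```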