Compare commits
3 Commits
version-de
...
dft
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a0115493d | ||
|
|
7a4f33802d | ||
|
|
170dca9bb9 |
3
.github/workflows/multi-gpu-e2e.yml
vendored
3
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -47,8 +47,7 @@ jobs:
|
|||||||
cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
axolotl_extras:
|
axolotl_extras: fbgemm-gpu
|
||||||
# axolotl_extras: fbgemm-gpu
|
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
|
|||||||
4
.github/workflows/pypi.yml
vendored
4
.github/workflows/pypi.yml
vendored
@@ -48,9 +48,9 @@ jobs:
|
|||||||
id: tag
|
id: tag
|
||||||
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
|
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
|
||||||
|
|
||||||
- name: Update version in VERSION file
|
- name: Update version in setup.py
|
||||||
run: |
|
run: |
|
||||||
echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION
|
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
|
||||||
|
|
||||||
- name: Build a source dist
|
- name: Build a source dist
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
53
examples/gemma3/gemma-3-1b-fft-dft.yml
Normal file
53
examples/gemma3/gemma-3-1b-fft-dft.yml
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
base_model: google/gemma-3-1b-it
|
||||||
|
|
||||||
|
model_type: Gemma3ForCausalLM
|
||||||
|
cls_model_config: Gemma3TextConfig
|
||||||
|
|
||||||
|
# gemma3 doesn't seem to play nice with ddp
|
||||||
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
|
chat_template: gemma3
|
||||||
|
eot_tokens:
|
||||||
|
- <end_of_turn>
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/gemma-3-1b-fft-dft
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
|
||||||
|
use_dynamic_finetuning: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 5e-5
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 2
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
@@ -24,9 +24,6 @@ Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
|||||||
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
|
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
|
||||||
include-package-data = true
|
include-package-data = true
|
||||||
|
|
||||||
[tool.setuptools.dynamic]
|
|
||||||
version = { file = "VERSION" }
|
|
||||||
|
|
||||||
[tool.setuptools.cmdclass]
|
[tool.setuptools.cmdclass]
|
||||||
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
|
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
|
||||||
|
|
||||||
|
|||||||
@@ -11,11 +11,11 @@ liger-kernel==0.6.4
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
huggingface_hub>=0.36.0
|
huggingface_hub>=0.36.0
|
||||||
peft>=0.18.1
|
peft>=0.18.0
|
||||||
tokenizers>=0.22.1
|
tokenizers>=0.22.1
|
||||||
transformers==4.57.6
|
transformers==4.57.1
|
||||||
accelerate==1.12.0
|
accelerate==1.12.0
|
||||||
datasets==4.5.0
|
datasets==4.4.2
|
||||||
deepspeed>=0.18.3
|
deepspeed>=0.18.3
|
||||||
trl==0.25.1
|
trl==0.25.1
|
||||||
hf_xet==1.2.0
|
hf_xet==1.2.0
|
||||||
|
|||||||
31
setup.py
31
setup.py
@@ -1,5 +1,6 @@
|
|||||||
"""setup.py for axolotl"""
|
"""setup.py for axolotl"""
|
||||||
|
|
||||||
|
import ast
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import re
|
import re
|
||||||
@@ -25,7 +26,6 @@ def parse_requirements(extras_require_map):
|
|||||||
_install_requires.append(line)
|
_install_requires.append(line)
|
||||||
try:
|
try:
|
||||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||||
install_xformers = platform.machine() != "aarch64"
|
|
||||||
if "Darwin" in platform.system():
|
if "Darwin" in platform.system():
|
||||||
# skip packages not compatible with OSX
|
# skip packages not compatible with OSX
|
||||||
skip_packages = [
|
skip_packages = [
|
||||||
@@ -62,49 +62,31 @@ def parse_requirements(extras_require_map):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
torch_parts = torch_version.split("+")
|
|
||||||
if len(torch_parts) == 2:
|
|
||||||
torch_cuda_version = torch_parts[1]
|
|
||||||
_dependency_links.append(
|
|
||||||
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (major, minor) >= (2, 9):
|
if (major, minor) >= (2, 9):
|
||||||
extras_require_map.pop("fbgemm-gpu")
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
extras_require_map["fbgemm-gpu"] = [
|
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
|
||||||
"fbgemm-gpu==1.4.0",
|
|
||||||
"fbgemm-gpu-genai==1.4.2",
|
|
||||||
]
|
|
||||||
extras_require_map["vllm"] = ["vllm==0.11.1"]
|
extras_require_map["vllm"] = ["vllm==0.11.1"]
|
||||||
if not install_xformers:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
elif (major, minor) >= (2, 8):
|
elif (major, minor) >= (2, 8):
|
||||||
extras_require_map.pop("fbgemm-gpu")
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
|
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
|
||||||
extras_require_map["vllm"] = ["vllm==0.11.0"]
|
extras_require_map["vllm"] = ["vllm==0.11.0"]
|
||||||
if not install_xformers:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
elif (major, minor) >= (2, 7):
|
elif (major, minor) >= (2, 7):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
if install_xformers:
|
|
||||||
_install_requires.append("xformers==0.0.30")
|
_install_requires.append("xformers==0.0.30")
|
||||||
# vllm 0.9.x is incompatible with latest transformers
|
# vllm 0.9.x is incompatible with latest transformers
|
||||||
extras_require_map.pop("vllm")
|
extras_require_map.pop("vllm")
|
||||||
else:
|
else:
|
||||||
if install_xformers:
|
|
||||||
_install_requires.append("xformers==0.0.31")
|
_install_requires.append("xformers==0.0.31")
|
||||||
extras_require_map["vllm"] = ["vllm==0.10.1"]
|
extras_require_map["vllm"] = ["vllm==0.10.1"]
|
||||||
elif (major, minor) >= (2, 6):
|
elif (major, minor) >= (2, 6):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
if install_xformers:
|
|
||||||
_install_requires.append("xformers==0.0.29.post3")
|
_install_requires.append("xformers==0.0.29.post3")
|
||||||
# since we only support 2.6.0+cu126
|
# since we only support 2.6.0+cu126
|
||||||
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
||||||
extras_require_map.pop("vllm")
|
extras_require_map.pop("vllm")
|
||||||
elif (major, minor) >= (2, 5):
|
elif (major, minor) >= (2, 5):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
if install_xformers:
|
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.append("xformers==0.0.28.post2")
|
_install_requires.append("xformers==0.0.28.post2")
|
||||||
else:
|
else:
|
||||||
@@ -112,7 +94,6 @@ def parse_requirements(extras_require_map):
|
|||||||
extras_require_map.pop("vllm")
|
extras_require_map.pop("vllm")
|
||||||
elif (major, minor) >= (2, 4):
|
elif (major, minor) >= (2, 4):
|
||||||
extras_require_map.pop("vllm")
|
extras_require_map.pop("vllm")
|
||||||
if install_xformers:
|
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.27")
|
_install_requires.append("xformers>=0.0.27")
|
||||||
@@ -129,11 +110,15 @@ def parse_requirements(extras_require_map):
|
|||||||
|
|
||||||
def get_package_version():
|
def get_package_version():
|
||||||
with open(
|
with open(
|
||||||
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
|
Path(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
/ "src"
|
||||||
|
/ "axolotl"
|
||||||
|
/ "__init__.py",
|
||||||
"r",
|
"r",
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
) as fin:
|
) as fin:
|
||||||
version_ = fin.read().strip()
|
version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
|
||||||
|
version_ = ast.literal_eval(version_match.group(1))
|
||||||
return version_
|
return version_
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,7 @@
|
|||||||
"""Axolotl - Train and fine-tune large language models"""
|
"""Axolotl - Train and fine-tune large language models"""
|
||||||
|
|
||||||
import pkgutil
|
import pkgutil
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
|
||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||||
|
|
||||||
try:
|
__version__ = "0.13.0.dev"
|
||||||
__version__ = version("axolotl")
|
|
||||||
except PackageNotFoundError:
|
|
||||||
__version__ = "unknown"
|
|
||||||
|
|||||||
@@ -373,6 +373,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||||
|
|
||||||
|
if self.cfg.use_dynamic_finetuning:
|
||||||
|
from axolotl.monkeypatch.loss.dft import dft_loss
|
||||||
|
|
||||||
|
trainer_kwargs["compute_loss_func"] = dft_loss
|
||||||
|
|
||||||
trainer_cls = self._get_trainer_cls()
|
trainer_cls = self._get_trainer_cls()
|
||||||
|
|
||||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||||
|
|||||||
98
src/axolotl/monkeypatch/loss/dft.py
Normal file
98
src/axolotl/monkeypatch/loss/dft.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Dynamic Fine-Tuning (DFT) loss implementation"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def selective_log_softmax(logits, index):
|
||||||
|
"""Memory-efficient log_softmax -> gather"""
|
||||||
|
if logits.dtype in [torch.float32, torch.float64]:
|
||||||
|
selected_logits = torch.gather(
|
||||||
|
logits, dim=-1, index=index.unsqueeze(-1)
|
||||||
|
).squeeze(-1)
|
||||||
|
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
||||||
|
per_token_logps = selected_logits - logsumexp_values
|
||||||
|
else:
|
||||||
|
per_token_logps = []
|
||||||
|
for row_logits, row_labels in zip(logits, index, strict=True):
|
||||||
|
row_logps = F.log_softmax(row_logits, dim=-1)
|
||||||
|
row_per_token_logps = row_logps.gather(
|
||||||
|
dim=-1, index=row_labels.unsqueeze(-1)
|
||||||
|
).squeeze(-1)
|
||||||
|
per_token_logps.append(row_per_token_logps)
|
||||||
|
per_token_logps = torch.stack(per_token_logps)
|
||||||
|
return per_token_logps
|
||||||
|
|
||||||
|
|
||||||
|
def get_dft_loss(ignore_index: int = -100):
|
||||||
|
"""Creates DFT loss function"""
|
||||||
|
|
||||||
|
def for_causal_lm_dft_loss(
|
||||||
|
logits,
|
||||||
|
labels,
|
||||||
|
vocab_size: int = None,
|
||||||
|
num_items_in_batch: Optional[int] = None,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
shift_labels: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""DFT loss: -exp(logprobs).detach() * logprobs"""
|
||||||
|
if shift_labels is None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
labels = F.pad(labels, (0, 1), value=ignore_index)
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
|
||||||
|
shift_labels = shift_labels.to(logits.device)
|
||||||
|
|
||||||
|
# Create loss mask
|
||||||
|
loss_mask = shift_labels != ignore_index
|
||||||
|
shift_labels_masked = shift_labels.clone()
|
||||||
|
shift_labels_masked[~loss_mask] = 0
|
||||||
|
|
||||||
|
# Compute log probabilities
|
||||||
|
logprobs = selective_log_softmax(logits, shift_labels_masked)
|
||||||
|
|
||||||
|
# DFT loss: -exp(logprobs).detach() * logprobs
|
||||||
|
per_token_loss = -logprobs.exp().detach() * logprobs
|
||||||
|
|
||||||
|
# Sum over valid tokens and normalize
|
||||||
|
if num_items_in_batch is None:
|
||||||
|
num_items_in_batch = loss_mask.sum()
|
||||||
|
|
||||||
|
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
|
||||||
|
return loss
|
||||||
|
|
||||||
|
return for_causal_lm_dft_loss
|
||||||
|
|
||||||
|
|
||||||
|
def dft_loss(outputs, labels, num_items_in_batch=None):
|
||||||
|
"""DFT loss compatible with Trainer.compute_loss_func signature.
|
||||||
|
|
||||||
|
This function is designed to be passed to Trainer's compute_loss_func parameter.
|
||||||
|
"""
|
||||||
|
ignore_index = -100
|
||||||
|
|
||||||
|
# Shift labels for causal LM
|
||||||
|
labels = F.pad(labels, (0, 1), value=ignore_index)
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
shift_labels = shift_labels.to(outputs.logits.device)
|
||||||
|
|
||||||
|
# Create loss mask
|
||||||
|
loss_mask = shift_labels != ignore_index
|
||||||
|
shift_labels_masked = shift_labels.clone()
|
||||||
|
shift_labels_masked[~loss_mask] = 0
|
||||||
|
|
||||||
|
# Compute log probabilities
|
||||||
|
logprobs = selective_log_softmax(outputs.logits, shift_labels_masked)
|
||||||
|
|
||||||
|
# DFT loss: -exp(logprobs).detach() * logprobs
|
||||||
|
per_token_loss = -logprobs.exp().detach() * logprobs
|
||||||
|
|
||||||
|
# Sum over valid tokens and normalize
|
||||||
|
if num_items_in_batch is None:
|
||||||
|
num_items_in_batch = loss_mask.sum()
|
||||||
|
|
||||||
|
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
|
||||||
|
return loss
|
||||||
@@ -676,6 +676,10 @@ class AxolotlInputConfig(
|
|||||||
"description": "Number of chunks to use for chunked cross entropy loss"
|
"description": "Number of chunks to use for chunked cross entropy loss"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
use_dynamic_finetuning: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Enable Dynamic Fine-Tuning loss (DFT)"},
|
||||||
|
)
|
||||||
|
|
||||||
tiled_mlp: bool | None = Field(
|
tiled_mlp: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
Reference in New Issue
Block a user