Compare commits

..

10 Commits

Dan Saunders
ede973b76c nits 2025-07-28 01:47:40 +00:00
Wing Lian
1d2aa1e467 upgrade to support latest transformers release (#2984)
* upgrade to support latest transformers release

* bump mistral common too

* Fix dependencies
2025-07-27 17:05:12 -04:00
NICOLAS BZRD
430be216d8 add shuffle_before_merging_datasets option to allow independent shuffling of datasets before merging (#2981) [skip ci] 2025-07-27 17:04:56 -04:00
Wing Lian
28804b82e4 don't create a reference model if grpo beta is 0.0 (#2983) [skip ci] 2025-07-27 17:04:42 -04:00
Wing Lian
add3e5076b don't publish to netlify on contributor submissions since it requires auth tokens (#2985) [skip ci]
* don't publish to netlify on contributor submissions since it requires auth tokens

* fix no-tmux build and add contact to motd
2025-07-27 17:04:27 -04:00
NanoCode012
41434f0c28 feat(doc): add all providers to readme (#2972) [skip ci]
* feat(doc): add vastai link

* feat: add cloud providers to readme for more visibility

* add prime intellect, remove Modal as sponsor

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-07-27 17:03:50 -04:00
Wing Lian
f7ea140838 TiledMLP support for FSDP2 (#2950)
* make TiledMLP work with FSDP

* cleanup/gc at start of train to prevent large VRAM spike

* chore: lint

* generic function for non-deepspeed training

* unify patch to fix imports

* update readme for ALST and add examples

* make deepspeed attribute on params check more robust

* update with new info from PR review
2025-07-25 07:15:03 -04:00
Wing Lian
460e0f9ed9 improve handling of file lock when content is empty (#2959) 2025-07-24 16:10:38 -04:00
Wing Lian
e80faea0db garbage collect on the end of the step if we're going to save a checkpoint (#2971) [skip ci] 2025-07-24 16:10:23 -04:00
Wing Lian
0ff2f172ef Act offload lora fix (#2928) [skip ci]
* fix activation offloading with lora

* update w e2e test

* add docs for error
2025-07-24 16:10:04 -04:00
35 changed files with 1063 additions and 159 deletions

View File

@@ -37,14 +37,14 @@ jobs:
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]

View File

@@ -53,6 +53,7 @@ jobs:
- name: Netlify Publish
uses: nwtgck/actions-netlify@v3.0
if: ${{ secrets.NETLIFY_AUTH_TOKEN != '' }}
id: netlify
with:
publish-dir: './_site'
@@ -67,7 +68,7 @@ jobs:
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
- name: Update PR with preview link
if: ${{ steps.netlify.outcome == 'success' }}
if: ${{ steps.netlify.outcome == 'success' && secrets.NETLIFY_AUTH_TOKEN != '' }}
uses: marocchino/sticky-pull-request-comment@v2
with:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -119,14 +119,15 @@ datasets:
## Dataset Processing
| Option | Default | Description |
| ----------------------------- | -------------------------- | --------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `dataset_exact_deduplication` | `true` | Deduplicate datasets |
| Option | Default | Description |
| --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |
| `dataset_exact_deduplication` | `true` | Deduplicate datasets |
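For context, a minimal sketch (illustrative, not Axolotl's internal code; the seed and toy datasets are assumed) of what the `shuffle_before_merging_datasets` option in the table above does:

```python
from datasets import Dataset, concatenate_datasets

seed = 42  # illustrative seed
datasets = [
    Dataset.from_dict({"text": ["a1", "a2", "a3"]}),
    Dataset.from_dict({"text": ["b1", "b2", "b3"]}),
]

# each dataset is shuffled independently, then concatenated in the order listed,
# so dataset-level ordering (e.g. for curriculum learning) is preserved
shuffled = [ds.shuffle(seed=seed) for ds in datasets]
merged = concatenate_datasets(shuffled)
```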
## LoRA Configuration

View File

@@ -25,6 +25,7 @@
## 🎉 Latest Updates
- 2025/07: TiledMLP support for single-GPU and multi-GPU training with DDP, DeepSpeed, and FSDP has been added to enable Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
@@ -79,6 +80,20 @@ docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
#### Cloud Providers
<details>
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=github&utm_medium=developer_community&utm_campaign=template_launch_axolotl&utm_content=readme)
- [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true)
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl)
- [Novita](https://novita.ai/gpus-console?templateId=311)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
</details>
### Your First Fine-tune
```bash
@@ -120,12 +135,6 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
## ❤️ Sponsors
Thank you to our sponsors who help make Axolotl possible:
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run
jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale,
fine-tune large language models, run protein folding simulations, and much more.
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
## 📜 License

View File

@@ -19,5 +19,7 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
# Upload coverage to Codecov if CODECOV_TOKEN is available
if [ -n "$CODECOV_TOKEN" ]; then
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
fi

View File

@@ -9,13 +9,15 @@ ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint-term.sh /root/cloud-entrypoint.sh
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux sudo && \
pip3 install -U --no-cache-dir grpcio ray[default]==2.9.3 && \
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \

View File

@@ -136,3 +136,7 @@ description: Frequently asked questions
> dynamic: false
> mode: max-autotune-no-cudagraphs
> ```
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`**
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.

View File

@@ -124,10 +124,13 @@ For providers supporting Docker:
- Use `axolotlai/axolotl-cloud:main-latest`
- Available on:
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Novita](https://novita.ai/gpus-console?templateId=311)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
- [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true)
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl)
- [Novita](https://novita.ai/gpus-console?templateId=311)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
### Google Colab {#sec-colab}

examples/alst/README.md (new file, 9 lines)
View File

@@ -0,0 +1,9 @@
# Arctic Long Sequence Training (ALST)
Arctic Long Sequence Training (ALST) is a technique for training long-context models using a variety of optimization
techniques. It is a combination of:
- TiledMLP: Leverage tiling over the sequence dimension on MLP layers to reduce memory usage
- Tiled Loss: Using optimized loss functions like Liger-Kernel or Cut Cross Entropy to reduce memory usage
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
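As a rough illustration of the tiling idea (a simplified, forward-only sketch with an assumed toy MLP and shapes, not Axolotl's autograd-aware implementation shown later in this diff), the MLP is applied over sequence-dimension chunks so only one chunk's intermediate activations are alive at a time:

```python
import torch
from torch import nn

def tiled_mlp_forward(mlp: nn.Module, x: torch.Tensor, shards: int = 4) -> torch.Tensor:
    # x: (batch, seq_len, hidden) -- chunk along the sequence dimension so the
    # large intermediate projection exists for only one shard at a time
    x_shards = torch.chunk(x, chunks=shards, dim=1)
    return torch.cat([mlp(shard) for shard in x_shards], dim=1)

mlp = nn.Sequential(nn.Linear(64, 256), nn.SiLU(), nn.Linear(256, 64))
out = tiled_mlp_forward(mlp, torch.randn(2, 1024, 64))
```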

View File

@@ -0,0 +1,53 @@
base_model: meta-llama/Llama-3.1-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: togethercomputer/Long-Data-Collections
type: completion
field: text
data_files:
- pretrain/rp_sub.jsonl.zst
- path: princeton-nlp/TextbookChapters
type: completion
field: chapter
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 500_000
min_sample_len: 200_000
sample_packing: true
tiled_mlp: true
sequence_parallel_degree: 8
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 100
saves_per_epoch: 1
evals_per_epoch: 2
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,59 @@
base_model: meta-llama/Llama-3.1-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: togethercomputer/Long-Data-Collections
type: completion
field: text
data_files:
- pretrain/rp_sub.jsonl.zst
- path: princeton-nlp/TextbookChapters
type: completion
field: chapter
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 500_000
min_sample_len: 200_000
sample_packing: true
tiled_mlp: true
context_parallel_size: 8
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 100
saves_per_epoch: 1
evals_per_epoch: 2
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
fsdp_version: 2
fsdp_config:
offload_params: false # offloading is currently not compatible with SP + torchao optimizer
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -9,7 +9,6 @@ liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
chat_template: llama3
datasets:
- path: mlabonne/FineTome-100k

View File

@@ -15,7 +15,6 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 16
lora_alpha: 32
# Currently, we don't support dropout with our custom Triton kernels

View File

@@ -13,13 +13,13 @@ packaging==23.2
huggingface_hub>=0.33.0
peft==0.16.0
transformers==4.53.2
transformers==4.54.0
tokenizers>=0.21.1
accelerate==1.9.0
datasets==4.0.0
deepspeed>=0.17.0
trl==0.19.1
hf_xet==1.1.2
hf_xet==1.1.5
optimum==1.16.2
hf_transfer
@@ -68,4 +68,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
mistral-common==1.7.0
mistral-common==1.8.3

View File

@@ -13,6 +13,8 @@
Welcome to the axolotl cloud image! If you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
Need help with your post-training workloads? Reach out to us at contact@axolotl.ai for assistance.
```
cd /workspace
rm -rf /workspace/axolotl

View File

@@ -68,9 +68,10 @@ def parse_requirements(extras_require_map):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")
# vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31.post1")
extras_require_map["vllm"] = ["vllm>=0.9.0"]
_install_requires.append("xformers==0.0.31")
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
@@ -84,7 +85,9 @@ def parse_requirements(extras_require_map):
else:
_install_requires.append("xformers>=0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
@@ -114,10 +117,10 @@ def get_package_version():
extras_require = {
"flash-attn": ["flash-attn==2.8.0.post2"],
"flash-attn": ["flash-attn==2.8.2"],
"ring-flash-attn": [
"flash-attn==2.8.0.post2",
"ring-flash-attn>=0.1.5",
"flash-attn==2.8.2",
"ring-flash-attn>=0.1.7",
"yunchang==0.6.0",
],
"deepspeed": [
@@ -151,13 +154,12 @@ extras_require = {
"ray[train]",
],
"vllm": [
"vllm==0.7.2",
"vllm==0.10.0",
],
"llmcompressor": [
"llmcompressor==0.5.1",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require
)

View File

@@ -500,6 +500,7 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs[arg] = getattr(self.cfg, arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
training_args_kwargs["average_tokens_across_devices"] = False
if self.cfg.eval_batch_size:
training_args_kwargs["per_device_eval_batch_size"] = (

View File

@@ -4,13 +4,22 @@ Trainer mixin for activation checkpointing w offloading
import contextlib
from peft import PeftModel
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers import GradientCheckpointingLayer, Trainer
from trl.models.activation_offloading import get_act_offloading_ctx_manager
from trl.models.activation_offloading import (
NoOpManager,
OffloadActivations,
get_act_offloading_ctx_manager,
)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class ActivationOffloadingMixin(Trainer):
@@ -21,9 +30,14 @@ class ActivationOffloadingMixin(Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.args.activation_offloading:
self.activation_offload_context = get_act_offloading_ctx_manager(
self.model, use_streams=True
)
if isinstance(self.model, PeftModel):
self.activation_offload_context = get_lora_act_offloading_ctx_manager(
self.model, use_streams=True
)
else:
self.activation_offload_context = get_act_offloading_ctx_manager(
self.model, use_streams=True
)
else:
self.activation_offload_context = contextlib.nullcontext()
@@ -35,3 +49,169 @@ class ActivationOffloadingMixin(Trainer):
def ac_wrap_hf_model(model: nn.Module, **kwargs):
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
def get_lora_act_offloading_ctx_manager(
model: nn.Module,
use_pin_memory: bool = True,
use_streams: bool = True,
min_offload_size: int = 1024,
max_fwd_stash_size: int = 5,
warn_if_no_head: bool = True,
) -> OffloadActivations:
"""
Returns the activation offloading context manager for the model. All but the last output Linear in every step will
be offloaded.
If activation offloading is enabled, we return the OffloadActivations context manager. If activation offloading is
disabled, we return a NoOpManager context manager.
Args:
model (`nn.Module`):
Model to wrap with the activation offloading context manager.
use_pin_memory (`bool`, *optional*, defaults to `True`):
Whether the offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to
be moved back onto GPU more quickly but is a limited resource.
use_streams (`bool`, *optional*, defaults to `True`):
Whether to use streams for performance optimization where the communications get overlapped with the
computation. Requires a torch build after torch-2.5.0.
min_offload_size (`int`, *optional*, defaults to `1024`):
Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we
do not want to waste bandwidth and resources moving it to CPU and back.
max_fwd_stash_size (`int`, *optional*, defaults to `5`):
Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
runtime.
warn_if_no_head (`bool`, *optional*, defaults to `True`):
Whether to warn if no output head is detected. If set to `False`, no warning will be raised if no output
head is detected.
Returns:
`contextlib.ContextDecorator`:
Activation offloading context manager for the model.
"""
# pylint: disable=unnecessary-dunder-call
activations_handling_ctx = OffloadActivations(
use_pin_memory=use_pin_memory,
use_streams=use_streams,
min_offload_size=min_offload_size,
max_fwd_stash_size=max_fwd_stash_size,
)
# Below is our hack to disable offloading the last output Linear in every
# step, as the cost for offloading the activation and then soon after bringing
# it back is expensive.
output_head_detected = False
noop_ctx = NoOpManager()
# Try to get the actual model if it's wrapped
unwrapped_model = model
if hasattr(unwrapped_model, "module"):
unwrapped_model = unwrapped_model.module
# check for PEFT models
if hasattr(unwrapped_model, "base_model") and hasattr(
unwrapped_model, "peft_config"
):
unwrapped_model = unwrapped_model.base_model
# Check for different types of output heads
if hasattr(unwrapped_model, "output"):
if isinstance(unwrapped_model.output, nn.Module):
unwrapped_model.output.register_forward_pre_hook(
lambda *args: noop_ctx.__enter__()
)
unwrapped_model.output.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)
output_head_detected = True
elif hasattr(unwrapped_model.output, "linear") and isinstance(
unwrapped_model.output.linear, nn.Module
):
unwrapped_model.output.linear.register_forward_pre_hook(
lambda *args: noop_ctx.__enter__()
)
unwrapped_model.output.linear.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)
output_head_detected = True
# Check for HuggingFace model output heads
elif hasattr(unwrapped_model, "lm_head"):
unwrapped_model.lm_head.register_forward_pre_hook(
lambda *args: noop_ctx.__enter__()
)
unwrapped_model.lm_head.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)
output_head_detected = True
# Check for decoder-based models
elif hasattr(unwrapped_model, "decoder"):
decoder = unwrapped_model.decoder
if hasattr(decoder, "output"):
decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
decoder.output.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)
output_head_detected = True
# Some models have lm_head in the decoder
elif hasattr(decoder, "lm_head"):
decoder.lm_head.register_forward_pre_hook(
lambda *args: noop_ctx.__enter__()
)
decoder.lm_head.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)
output_head_detected = True
# Check for transformer models with final layer norm
elif hasattr(unwrapped_model, "final_layer_norm") or hasattr(
unwrapped_model, "ln_f"
):
final_norm = (
getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f
)
final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
final_norm.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)
output_head_detected = True
# Check for models with head module
elif hasattr(unwrapped_model, "head") and isinstance(
unwrapped_model.head, nn.Module
):
unwrapped_model.head.register_forward_pre_hook(
lambda *args: noop_ctx.__enter__()
)
unwrapped_model.head.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)
output_head_detected = True
if not output_head_detected and warn_if_no_head:
LOG.warning(
"During activation offloading, no output head was detected. If your model has an output head, it will be "
"offloaded. This usually greatly slows training, given the large vocabulary size. To change this "
"behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by "
"passing `warn_if_no_head=False`."
)
for name, module in unwrapped_model.named_modules():
# Disable offloading for any Liger modules
if "liger" in name.lower():
module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
module.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)
# disable offloading for any submodules to fix LoRA training
if name.endswith("._checkpoint_wrapped_module"):
for _, sub_module in module.named_modules():
sub_module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
sub_module.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)
return activations_handling_ctx
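A hedged usage sketch of the context manager returned above, entered around a forward pass (assumes a CUDA device with the model and batch already on GPU; `model`, `batch`, and `optimizer` are illustrative placeholders, not names from this diff):

```python
ctx = get_lora_act_offloading_ctx_manager(model, use_streams=True)
with ctx:
    loss = model(**batch).loss  # activations saved in this forward are offloaded to CPU
loss.backward()                 # offloaded activations are fetched back as needed
optimizer.step()
```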

View File

@@ -57,8 +57,12 @@ class LigerArgs(BaseModel):
@model_validator(mode="before")
@classmethod
def check_tiled_mlp_conflict(cls, data):
if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True:
if (
data.get("liger_glu_activation") is True
and data.get("tiled_mlp") is True
and not data.get("tiled_mlp_use_original_mlp")
):
raise ValueError(
"You cannot have both `liger_glu_activation` and `tiled_mlp` set."
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
)
return data

View File

@@ -102,8 +102,8 @@ def matmul_lora(
del W
if A is not None:
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
A, B = A.t().to(dtype), B.t().to(dtype)
out += (X @ A) @ (s * B)
return out.view(batch, seq_len, -1) if reshape else out

View File

@@ -162,6 +162,7 @@ class ModelLoader:
# Build the model
PLUGIN_MANAGER.pre_model_load(self.cfg)
self.patch_manager.apply_post_plugin_pre_model_load_patches()
skip_move_to_device = self._build_model()
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)

View File

@@ -66,6 +66,9 @@ class PatchManager:
self._apply_self_attention_lora_patch()
self._apply_gemma3_conditional_generation_forward_patch()
self._apply_sequence_parallel_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -272,7 +275,9 @@ class PatchManager:
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
from axolotl.monkeypatch.tiled_mlp import (
patch_tiled_mlp,
)
patch_tiled_mlp(
model_type,

View File

@@ -0,0 +1,351 @@
"""
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
"""
import copy
import functools
import sys
import torch
from torch import nn
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def fsdp2_load_full_state_dict(
_accelerator, model: torch.nn.Module, full_sd: dict, offload_to_cpu: bool = False
):
"""
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
parameters from rank 0 to all other ranks. This function modifies the model in-place.
Args:
accelerator (`Accelerator`): The accelerator instance
model (`torch.nn.Module`):
The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
full_sd (`dict`): The full state dict to load, can only be on rank 0
"""
from torch.distributed.tensor import distribute_tensor
LOG.info("Broadcasting full state dict to all ranks...")
import time
start_time = time.time()
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
if hasattr(sharded_meta_param, "device_mesh"):
sharded_param = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
src_data_rank=0,
)
else:
sharded_param = full_tensor
if offload_to_cpu:
sharded_param = sharded_param.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_param)
del full_tensor
full_sd[param_name] = None
model.load_state_dict(sharded_sd, assign=True, strict=True)
end_time = time.time()
LOG.debug(
f"Time taken to load full state dict: {(end_time - start_time):.2f} seconds"
)
log_gpu_memory_usage(LOG, "Memory usage after broadcasting full state dict", 0)
return model
def get_state_dict(self, model, unwrap=True):
"""
Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full
precision.
Args:
model (`torch.nn.Module`):
A PyTorch model sent through [`Accelerator.prepare`]
unwrap (`bool`, *optional*, defaults to `True`):
Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict
Returns:
`dict`: The state dictionary of the model potentially without full precision.
Example:
```python
>>> import torch
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> net = torch.nn.Linear(2, 2)
>>> net = accelerator.prepare(net)
>>> state_dict = accelerator.get_state_dict(net)
```
"""
from accelerate import DistributedType
from accelerate.utils import compare_versions
if self.distributed_type == DistributedType.DEEPSPEED:
zero3_sharding = self.deepspeed_config["zero_optimization"]["stage"] == 3
tp_sharding = (
self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1
)
if zero3_sharding or tp_sharding:
if model.zero_gather_16bit_weights_on_model_save():
if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"):
raise ImportError(
"Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`."
)
state_dict = (
model._consolidated_16bit_state_dict() # pylint: disable=protected-access
if tp_sharding
else model._zero3_consolidated_16bit_state_dict() # pylint: disable=protected-access
)
else:
raise ValueError(
"Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
"To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
"set `zero3_save_16bit_model` to True when using `accelerate config`. "
"To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
)
else:
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
state_dict = clone_tensors_for_torch_save(
self.unwrap_model(model).state_dict()
)
elif self.is_fsdp2:
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
state_dict = {}
sharded_state_dict = model.state_dict()
for param_name, param in sharded_state_dict.items():
if param.is_cpu:
param = param.to(torch.device("cuda"))
param = param.full_tensor()
if torch.distributed.get_rank() == 0:
state_dict[param_name] = param.cpu()
torch.distributed.barrier()
elif self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
full_state_dict_config = FullStateDictConfig(
offload_to_cpu=True, rank0_only=True
)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, full_state_dict_config
):
state_dict = model.state_dict()
else:
if unwrap:
model = self.unwrap_model(model)
state_dict = model.state_dict()
return state_dict
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
from torch.distributed.fsdp import fully_shard
log_bias_dtype_mismatch = False
# Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
if module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to(
module.base_layer.weight.dtype
)
for active_adapter in module.active_adapters:
if module.lora_A:
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
if module.lora_B:
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_A:
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_B:
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
if module.lora_magnitude_vector:
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
return log_bias_dtype_mismatch
def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
"""Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
Args:
accelerator (`Accelerator`): The accelerator instance
model (`torch.nn.Module`): The model to prepare
Returns:
`torch.nn.Module`: Prepared model
"""
from accelerate.utils import get_module_children_bottom_up, is_compiled_module
from accelerate.utils.fsdp_utils import fsdp2_prepare_auto_wrap_policy
from accelerate.utils.modeling import get_non_persistent_buffers
from peft import PeftModel
from peft.tuners.lora import LoraLayer
from torch.distributed.fsdp import (
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
fully_shard,
)
is_type_fsdp = isinstance(model, FSDPModule) or (
is_compiled_module(model)
and isinstance(model._orig_mod, FSDPModule) # pylint: disable=protected-access
)
if is_type_fsdp:
return model
fsdp2_plugin = accelerator.state.fsdp_plugin
original_sd = model.state_dict()
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
# We need the `auto_wrap_policy` original type to create a custom policy function for sharding
# This is because `fully_shard` doesn't support old auto wrap policies, rather we have to imitate the behaviour
if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy:
pass # auto_wrap_policy_type = "transformer"
elif fsdp2_plugin.auto_wrap_policy is size_based_auto_wrap_policy:
pass # auto_wrap_policy_type = "size"
# We set `auto_wrap_policy` to `functools.partial` to avoid creating it again
# This is because `apply_activation_checkpointing` can reuse this function
fsdp2_plugin.set_auto_wrap_policy(model)
if fsdp2_plugin.activation_checkpointing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)
# Apply activation checkpointing before applying `fully_shard`
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
),
auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
)
fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload,
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
}
model_has_params4bit = False
for _, param in model.named_parameters():
# this is a temporary fix whereby models loaded with bnb params cannot be moved from
# GPU to a meta device with FSDP2 because torch operations don't return the original class type
# bypassing the move to meta will still cause the VRAM spike, but at least the model will still load
if param.__class__.__name__ == "Params4bit":
model_has_params4bit = True
break
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
# For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
# If we kept the model on CPU (`cpu_ram_efficient_loading` keeps the model on CPU on all ranks, though non-main ranks only hold `torch.empty` tensors), `fully_shard` would move it to GPU
# Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
# We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
# Also, these buffers aren't getting sharded by default
# We get the FQNs of all non-persistent buffers, to re-register them after
non_persistent_buffer_fqns = get_non_persistent_buffers(
model, recurse=True, fqns=True
)
original_non_persistent_buffers = copy.deepcopy(
{k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
)
# We move the model to meta device, as then sharding happens on meta device
model = model.to(torch.device("meta"))
# We need to re-tie the weights, not exactly sure why, but if we don't do this, references to `lm_head`/`embed_tokens` stay hanging -> more VRAM usage
# We assume `transformers` models have a `tie_weights` method if they support it
if hasattr(model, "tie_weights"):
model.tie_weights()
is_peft_model = isinstance(model, PeftModel)
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
log_bias_dtype_mismatch = False
if auto_wrap_policy is not None:
for module in get_module_children_bottom_up(model)[:-1]:
if is_peft_model and isinstance(module, LoraLayer):
module_log_bias_mismatch = _process_lora_module_for_fsdp(
module, fsdp2_kwargs
)
log_bias_dtype_mismatch |= module_log_bias_mismatch
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
fully_shard(module, **fsdp2_kwargs)
fully_shard(model, **fsdp2_kwargs)
if log_bias_dtype_mismatch:
LOG.warning(
"Bias dtype mismatch detected in LoRA base linear layer. Bias parameters have been cast to weight dtype."
)
if fsdp2_plugin.cpu_ram_efficient_loading:
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
fsdp2_load_full_state_dict(
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
)
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# We re-register the buffers, as they may not be in the state_dict
for fqn, buffer_tensor in original_non_persistent_buffers.items():
buffer_tensor = buffer_tensor.to(accelerator.device)
if "." in fqn:
parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
parent_module = model.get_submodule(parent_fqn)
else:
local_buffer_name = fqn
parent_module = model
parent_module.register_buffer(
local_buffer_name, buffer_tensor, persistent=False
)
# We need to tie the weights again, as the call to `load_full_state_dict` breaks the tie
# Needs to be called both here and above
# removing this call makes the model have a slightly different loss
# removing the call above leads to extra memory usage as explained in the comment above
if hasattr(model, "tie_weights"):
model.tie_weights()
return model
def patch_accelerate_fsdp2():
import accelerate
accelerate.accelerator.fsdp2_prepare_model = fsdp2_prepare_model
accelerate.Accelerator.get_state_dict = get_state_dict
setattr(
sys.modules["accelerate"],
"Accelerator.get_state_dict",
get_state_dict,
)

View File

@@ -18,10 +18,15 @@ import transformers
import transformers.modeling_flash_attention_utils
from ring_flash_attn import ring_flash_attn_func
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size,
is_flash_attn_greater_or_equal,
)
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
try:
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.utils.schemas.enums import RingAttnFunc
@@ -112,7 +117,7 @@ def create_flash_attn_forward_varlen_llama3(
# Handle sliding window
use_sliding_windows = (
_flash_supports_window_size
_flash_supports_window
and sliding_window is not None
and key_states.shape[1] > sliding_window
)

View File

@@ -0,0 +1,11 @@
"""
TiledMLP monkey patches
"""
from .patch import (
patch_tiled_mlp,
)
__all__ = [
"patch_tiled_mlp",
]

View File

@@ -0,0 +1,153 @@
"""
TiledMLP support for DDP, FSDP, and single GPU
"""
import threading
from typing import List
import torch
class TiledMLP(torch.autograd.Function):
"""
TiledMLP implementation using gradient hooks
"""
@staticmethod
def forward(
ctx,
fn,
self,
x,
shards,
compute_params,
) -> torch.Tensor:
ctx.fn = fn
ctx.self = self
ctx.shards = shards
ctx.compute_params = [p for p in compute_params if p.requires_grad]
ctx.save_for_backward(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]
output_unsharded = torch.cat(output_shards, dim=1)
return output_unsharded
@staticmethod
def backward(ctx, *grads) -> torch.Tensor:
fn = ctx.fn
(x,) = ctx.saved_tensors
self = ctx.self
shards = ctx.shards
compute_params = ctx.compute_params
x_requires_grad = x.requires_grad
x = x.detach()
x.requires_grad_(x_requires_grad)
incoming_grad = grads[0]
x_grad = torch.zeros_like(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
# Create a gradient accumulator for parameters
grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
shard_step = x_shards[0].numel()
for i, x_shard in enumerate(x_shards):
x_shard.requires_grad_(x_requires_grad)
shard_offset = i * shard_step
x_shard.grad = (
x_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
incoming_grad_shard = (
incoming_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
# Install hooks for this shard
is_last_shard = i + 1 == shards
grad_accumulator.install_hooks(is_last_shard)
with torch.enable_grad():
output = fn(self, x_shard)
torch.autograd.backward(output, incoming_grad_shard)
# Clean up hooks
grad_accumulator.cleanup()
del grad_accumulator
return (None, None, x_grad, None, None)
class GradientAccumulator:
"""
Manual gradient accumulator for TiledMLP with configurable precision
Accumulates in specified dtype and rescales the gradient at the end
"""
def __init__(
self,
params: List[torch.nn.Parameter],
total_shards: int,
dtype: torch.dtype | None = None,
):
self.params = params
self.total_shards = total_shards
self.grad_accumulation_dtype = dtype or torch.float32
self.accumulated_grads = {}
self.hooks = []
self.lock = threading.Lock()
self.gradient_scale = 1.0 / total_shards
# Initialize accumulated gradients in the specified dtype
for param in self.params:
if param.grad is not None:
self.accumulated_grads[param] = param.grad.to(
self.grad_accumulation_dtype
)
param.grad = None
else:
self.accumulated_grads[param] = torch.zeros_like(
param, dtype=self.grad_accumulation_dtype
)
def install_hooks(self, is_last_shard: bool):
"""Install gradient hooks that accumulate gradients in higher precision"""
def create_hook(param):
def hook(grad):
with self.lock:
grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype)
scaled_grad = grad_to_accum_dtype * self.gradient_scale
if param in self.accumulated_grads:
self.accumulated_grads[param] += scaled_grad
else:
self.accumulated_grads[param] = scaled_grad.clone()
# Only assign the averaged gradient on the last shard
if is_last_shard:
param.grad = self.accumulated_grads[param].to(param.dtype)
return param.grad
return None
return hook
# Install hooks on all parameters
for param in self.params:
if param.requires_grad:
hook = param.register_hook(create_hook(param))
self.hooks.append(hook)
def cleanup(self):
"""Remove all installed hooks"""
for hook in self.hooks:
hook.remove()
self.hooks.clear()
del self.accumulated_grads
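A hedged usage sketch of the `TiledMLP` autograd function defined above, applied to a toy stand-in MLP (the module, shapes, and shard count are assumptions for illustration; in Axolotl the patched MLP `forward` in the next file drives this with micro_batch_size 1, as in the ALST examples):

```python
import torch
from torch import nn

# toy MLP standing in for a transformer MLP block; TiledMLP is defined above
mlp = nn.Sequential(nn.Linear(64, 256), nn.SiLU(), nn.Linear(256, 64))
x = torch.randn(1, 1024, 64, requires_grad=True)  # (batch, seq_len, hidden)

out = TiledMLP.apply(
    lambda module, inp: module(inp),  # fn(self, x_shard), run once per sequence shard
    mlp,                              # self
    x,                                # input, chunked along dim=1
    4,                                # number of shards
    list(mlp.parameters()),           # compute_params whose grads get accumulated
)
out.sum().backward()  # GradientAccumulator averages parameter grads across shards
```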

View File

@@ -12,8 +12,12 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
from deepspeed.runtime.sequence_parallel.ulysses_sp import (
TiledMLP as DeepSpeedTiledMLP,
)
from axolotl.monkeypatch.tiled_mlp.base import TiledMLP
try:
# Dynamically import the module and MLP class
@@ -36,6 +40,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
def tiled_mlp_forward(self, x):
# pylint: disable=protected-access
input_shape = x.shape
seqlen = input_shape[-2]
hidden = input_shape[-1]
@@ -48,14 +53,23 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
else:
num_shards = cfg_num_shards
if not self._compute_params: # pylint: disable=protected-access
self._compute_params = [ # pylint: disable=protected-access
p for p in self.parameters() if p.requires_grad
]
if not self._compute_params:
self._compute_params = [p for p in self.parameters() if p.requires_grad]
compute_params = self._compute_params # pylint: disable=protected-access
compute_params = self._compute_params
if not self._tiled_mlp_dist_impl:
if (
self._compute_params
and any(
hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group")
for p in self._compute_params
)
) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
else:
self._tiled_mlp_dist_impl = TiledMLP
down_res = TiledMLP.apply(
down_res = self._tiled_mlp_dist_impl.apply(
mlp_forward,
self,
x,
@@ -66,6 +80,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
mlp_cls.forward = tiled_mlp_forward
mlp_cls._compute_params = [] # pylint: disable=protected-access
mlp_cls._tiled_mlp_dist_impl = None # pylint: disable=protected-access
LOG.info(
f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
main_process_only=True,

View File

@@ -115,8 +115,11 @@ def setup_reference_model(
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None # explicit setting to None
else:
reference_model: bool = True
if cfg.rl == RLType.GRPO and cfg.trl.beta == 0:
reference_model = False
# load the model again for model_ref/baseline
model_loader = ModelLoader(cfg, tokenizer, reference_model=True)
model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model)
model_ref, _ = model_loader.load()
return model_ref

View File

@@ -27,7 +27,11 @@ from transformers import (
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
IntervalStrategy,
SaveStrategy,
)
from trl.models import unwrap_model_for_generation
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -863,10 +867,16 @@ class GCCallback(TrainerCallback):
torch.cuda.empty_cache()
gc.collect()
def on_train_begin(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
self._gc()
def on_step_begin(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if self.next_gc_on_begin_step == state.global_step:
# pylint: disable=consider-using-in
if self.next_gc_on_begin_step == state.global_step or state.global_step == 0:
self._gc()
def on_step_end(
@@ -879,6 +889,17 @@ class GCCallback(TrainerCallback):
self.next_gc_on_begin_step = state.global_step + 1
elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
self._gc()
elif (
args.save_strategy == SaveStrategy.STEPS
and state.save_steps > 0
and state.global_step % state.save_steps == 0
):
# gc on save steps in case anything is loaded to CPU RAM like offloaded tensors
self._gc()
elif state.global_step >= state.max_steps:
if args.save_strategy == SaveStrategy.STEPS:
# gc on save steps in case anything is loaded to CPU RAM like offloaded tensors
self._gc()
def on_epoch_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument

View File

@@ -46,7 +46,8 @@ class FileLockLoader:
def _increment_counter(self):
"""Safely increment the process counter."""
if self.counter_path.exists():
count = int(self.counter_path.read_text().strip())
counter_content = self.counter_path.read_text().strip()
count = int(counter_content) if counter_content else 0
else:
count = 0
self.counter_path.write_text(str(count + 1))
@@ -54,10 +55,11 @@ class FileLockLoader:
def cleanup(self):
"""Clean up ready flag when last process is done."""
with FileLock(str(self.lock_file_path)):
count = int(self.counter_path.read_text().strip())
counter_content = self.counter_path.read_text().strip()
count = int(counter_content) if counter_content else 0
count -= 1
if count == 0:
if count <= 0:
# Last process cleans everything up
self.ready_flag_path.unlink(missing_ok=True)
self.counter_path.unlink(missing_ok=True)

View File

@@ -543,6 +543,12 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
return ds.shuffle(seed=cfg.seed)
# If enabled, shuffle each dataset independently before merging.
# This allows curriculum learning strategies to be applied at the dataset level.
if cfg.shuffle_before_merging_datasets:
LOG.info("Shuffling each dataset individually before merging...")
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)

View File

@@ -179,6 +179,12 @@ class AxolotlInputConfig(
"description": "If false, the datasets will not be shuffled and will keep their original order in `datasets`. The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true."
},
)
shuffle_before_merging_datasets: bool | None = Field(
default=False,
json_schema_extra={
"description": "If true, each dataset in `datasets` will be shuffled before merging. This allows curriculum learning strategies to be applied at the dataset level. Default is false."
},
)
dataset_prepared_path: str | None = Field(
default=None,
json_schema_extra={
@@ -597,7 +603,7 @@ class AxolotlInputConfig(
)
tiled_mlp_use_original_mlp: bool | None = Field(
default=None,
default=True,
json_schema_extra={
"description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama."
},

View File

@@ -512,19 +512,6 @@ class TrainingValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_tiled_mlp_deepspeed(cls, data):
capabilities = data.get("capabilities")
n_gpu = 0
if capabilities and capabilities.get("n_gpu", 0) >= 1:
n_gpu = capabilities.get("n_gpu", 0)
if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")):
raise ValueError(
"tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu"
)
return data
class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""
@@ -1104,16 +1091,10 @@ class ModelCompatibilityValidationMixin:
"`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or `activation_offloading: legacy`"
)
self.gradient_checkpointing = True
if self.adapter and "lora" in self.adapter:
LOG.warning(
"offloading with CUDA streams is not supported for LoRA adapters, using the `activation_offloading: legacy` implementation."
)
self.activation_offloading = "legacy"
else:
LOG.warning(
"`offload` uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
)
self.activation_offloading = True
LOG.warning(
"`offload` now uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
)
self.activation_offloading = True
if self.gradient_checkpointing == "offload_disk":
LOG.warning(
"`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
@@ -1122,19 +1103,6 @@ class ModelCompatibilityValidationMixin:
self.activation_offloading = "disk"
return self
@model_validator(mode="after")
def check_activation_offloading_w_lora(self):
if (
self.activation_offloading is True
and self.adapter
and "lora" in self.adapter
):
LOG.warning(
"activation_offloading with CUDA streams is not supported for LoRA adapters. Setting `activation_offloading: legacy`"
)
self.activation_offloading = "legacy"
return self
@model_validator(mode="after")
def check_activation_offloading_wo_gc(self):
if self.activation_offloading and not self.gradient_checkpointing:

View File

@@ -0,0 +1,83 @@
"""
E2E tests for activation offloading
"""
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists
# pylint: disable=duplicate-code
class TestActivationOffloading:
"""
E2E test cases for activation offloading
"""
@pytest.mark.parametrize(
"adapter",
["lora", "qlora", None],
)
def test_activation_offloading(
self,
temp_dir,
adapter,
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"eos_token": "<|im_end|>",
},
"datasets": [
{
"chat_template": "chatml",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": "auto",
"save_safetensors": True,
"gradient_checkpointing": True,
"activation_offloading": True,
"save_first_step": False,
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
}
)
if adapter == "lora":
cfg["adapter"] = "lora"
if adapter == "qlora":
cfg["adapter"] = "qlora"
cfg["load_in_4bit"] = True
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -21,62 +21,6 @@ class TestActivationOffloading:
assert cfg.gradient_checkpointing is True
assert cfg.activation_offloading is True
def test_gc_converts_offload_w_lora(self, min_base_cfg):
cfg = (
DictDefault(
gradient_checkpointing="offload",
adapter="lora",
)
| min_base_cfg
)
cfg = validate_config(cfg)
assert cfg.gradient_checkpointing is True
assert cfg.activation_offloading == "legacy"
def test_gc_converts_offload_w_qlora(self, min_base_cfg):
cfg = (
DictDefault(
gradient_checkpointing="offload",
adapter="qlora",
load_in_4bit=True,
)
| min_base_cfg
)
cfg = validate_config(cfg)
assert cfg.gradient_checkpointing is True
assert cfg.activation_offloading == "legacy"
def test_ac_impl_changes_w_lora(self, min_base_cfg):
cfg = (
DictDefault(
gradient_checkpointing=True,
activation_offloading=True,
adapter="lora",
)
| min_base_cfg
)
cfg = validate_config(cfg)
assert cfg.gradient_checkpointing is True
assert cfg.activation_offloading == "legacy"
def test_ac_impl_changes_w_qlora(self, min_base_cfg):
cfg = (
DictDefault(
gradient_checkpointing=True,
activation_offloading=True,
adapter="qlora",
load_in_4bit=True,
)
| min_base_cfg
)
cfg = validate_config(cfg)
assert cfg.gradient_checkpointing is True
assert cfg.activation_offloading == "legacy"
def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
cfg = (
DictDefault(