Compare commits

...

29 Commits

Author SHA1 Message Date
Wing Lian
9bdf4b1c23 improve handling and error if fa3 requested but not installed 2025-05-19 10:11:14 -07:00
Wing Lian
d6f64a3684 handle args to drop dropout 2025-05-18 15:17:40 -07:00
Wing Lian
0735454782 move fa3 tests to multigpu since we only run those on hopper 2025-05-18 15:17:39 -07:00
Wing Lian
bb6464c4c6 use get_device_capability since CI setting in cfg is unreliable 2025-05-18 15:17:39 -07:00
Wing Lian
323a9cb153 handle return sig change for fa3 2025-05-18 15:17:39 -07:00
Wing Lian
b22150751f check for fa first 2025-05-18 15:17:39 -07:00
Wing Lian
8c4bc59bfc fa3 doesn't support dropout_p, fix unpatching 2025-05-18 15:17:39 -07:00
Wing Lian
a064f1c9b4 ci for fa3 2025-05-18 15:17:39 -07:00
Wing Lian
fb5ef6d445 use updated package name for fa3 2025-05-18 15:17:38 -07:00
Wing Lian
34b68ddaae curl with apt instead of pip 2025-05-18 15:17:38 -07:00
Wing Lian
9a3d0c919b make sure curl is installed 2025-05-18 15:17:38 -07:00
Wing Lian
bd34d0b861 install for hopper from pre-built wheel 2025-05-18 15:17:38 -07:00
Wing Lian
37220ab90a install pybind11 for fa3 build 2025-05-18 15:17:38 -07:00
Wing Lian
e1b74d710b update docker args to minimums used and use MAX_JOBS already set as arg 2025-05-18 15:17:38 -07:00
Wing Lian
79daf5b934 reduce max jobs for build of fa3 2025-05-18 15:17:38 -07:00
Wing Lian
ddd7c55576 build hopper w fa3 on torch 2.6 2025-05-18 15:17:37 -07:00
Wing Lian
65c6c98a76 whitespace fix in dockerfile 2025-05-18 15:17:37 -07:00
Wing Lian
4ef2e8293f fix the bash in docker base 2025-05-18 15:17:37 -07:00
Wing Lian
c126d5cd04 fix suffix for tag 2025-05-18 15:17:37 -07:00
Wing Lian
9b0be4f15c fix 12.8 image and add flash-attn v3 hopper base image 2025-05-18 15:17:37 -07:00
Wing Lian
a27b909c5c GRPO fixes (peft) (#2676)
* don't set peft_config on grpo to prevent double peft wrap

* remove overrides needed to support bug

* fix grpo tests

* require more CPU for multigpu to help with torch compile for vllm
2025-05-16 15:47:03 -04:00
xzuyn
6cb07b9d12 Fix for setting adam_beta3 and adam_epsilon2 for CAME Optimizer (#2654) [skip ci]
* make setting `adam_beta3` and `adam_epsilon2` work correctly

* update config docs so users know args are specific to CAME optim

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:50 -04:00
C080
288653adb6 Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifacts setting (#2675) [skip ci]
* Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifacts setting

* cleanup and lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:31 -04:00
NanoCode012
3a5b495a74 Fix: improve doc on merge/inference cli visibility (#2674)
* feat: improve visibility for merge doc

* feat: add tip on reuse config between modes
2025-05-16 13:07:40 -04:00
xzuyn
f661858fc4 Print dataset name (#2668) [skip ci] 2025-05-16 13:06:58 -04:00
Eric Meier
c837c4a424 Add missing init file to liger plugin (#2670) [skip ci] 2025-05-16 13:06:46 -04:00
michelyang
c9797de6bb Add num_proc to fix slow dataset processing issue (#2681) [skip ci] 2025-05-16 13:06:20 -04:00
Wing Lian
8f8a7afb05 Add ci and images for CUDA 12.8 for B200s (#2683) [skip ci]
* Add ci and images for CUDA 12.8 for B200s

* add comments explaining CI [skip e2e]
2025-05-16 13:06:08 -04:00
NanoCode012
86472715da fix: remove doc string imports in monkeypatches (#2671) [skip ci] 2025-05-16 13:05:55 -04:00
34 changed files with 200 additions and 209 deletions

View File

@@ -47,11 +47,18 @@ jobs:
 pytorch: 2.7.0
 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
 - cuda: "128"
-cuda_version: 12.6.3
+cuda_version: 12.8.1
 cudnn_version: ""
 python_version: "3.11"
 pytorch: 2.7.0
 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+- cuda: "126"
+cuda_version: 12.6.3
+cudnn_version: ""
+python_version: "3.11"
+pytorch: 2.6.0
+suffix: "-hopper"
+torch_cuda_arch_list: "9.0+PTX"
 - cuda: "128"
 cuda_version: 12.8.1
 cudnn_version: ""
@@ -87,7 +94,7 @@ jobs:
 context: .
 file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
 push: ${{ github.event_name != 'pull_request' }}
-tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
 labels: ${{ steps.metadata.outputs.labels }}
 build-args: |
 CUDA_VERSION=${{ matrix.cuda_version }}

View File

@@ -31,6 +31,11 @@ jobs:
 python_version: "3.11"
 pytorch: 2.7.0
 axolotl_extras:
+- cuda: 128
+cuda_version: 12.8.1
+python_version: "3.11"
+pytorch: 2.7.0
+axolotl_extras:
 runs-on: axolotl-gpu-runner
 steps:
 - name: Checkout
@@ -94,6 +99,11 @@ jobs:
 python_version: "3.11"
 pytorch: 2.7.0
 axolotl_extras:
+- cuda: 128
+cuda_version: 12.8.1
+python_version: "3.11"
+pytorch: 2.7.0
+axolotl_extras:
 runs-on: axolotl-gpu-runner
 steps:
 - name: Checkout

View File

@@ -32,21 +32,25 @@ jobs:
 pytorch: 2.6.0
 axolotl_extras: vllm
 num_gpus: 2
-nightly_build: "true"
+- cuda: 126
+cuda_version: 12.6.3
+python_version: "3.11"
+pytorch: 2.6.0
+axolotl_extras:
+suffix: "-hopper"
+num_gpus: 2
 - cuda: 124
 cuda_version: 12.4.1
 python_version: "3.11"
 pytorch: 2.5.1
 axolotl_extras:
 num_gpus: 2
-nightly_build: "true"
 - cuda: 126
 cuda_version: 12.6.3
 python_version: "3.11"
 pytorch: 2.7.0
 axolotl_extras:
 num_gpus: 2
-nightly_build: "true"
 runs-on: [self-hosted, modal]
 timeout-minutes: 120
 steps:
@@ -68,7 +72,6 @@ jobs:
 echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
 echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
 echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
-echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
 echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
 - name: Run tests job on Modal
 run: |

View File

@@ -295,6 +295,7 @@ jobs:
 find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
 docker-e2e-tests-1st:
+# Run this job first as a gate for running the remainder of the test matrix
 if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
 # this job needs to be run on self-hosted GPU runners...
 runs-on: [self-hosted, modal]
@@ -341,6 +342,8 @@ jobs:
 # this job needs to be run on self-hosted GPU runners...
 runs-on: [self-hosted, modal]
 timeout-minutes: 90
+# Only run the remainder of the matrix if the first e2e check passed;
+# this is to save on wasted compute costs for known failures that get caught in the first run
 needs: [pre-commit, pytest, docker-e2e-tests-1st]
 strategy:
@@ -365,6 +368,12 @@ jobs:
 pytorch: 2.7.0
 num_gpus: 1
 axolotl_extras:
+- cuda: 128
+cuda_version: 12.8.1
+python_version: "3.11"
+pytorch: 2.7.0
+num_gpus: 1
+axolotl_extras:
 steps:
 - name: Checkout
 uses: actions/checkout@v4

View File

@@ -32,6 +32,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
 fi
 RUN pip install packaging==23.2 setuptools==75.8.0
+RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
+curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+fi
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
 pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
 else \

View File

@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
 image=cicd_image,
 gpu=GPU_CONFIG,
 timeout=90 * 60,
-cpu=8.0,
+cpu=16.0,
 memory=131072 * N_GPUS,
 volumes=VOLUME_CONFIG,
 )

View File

@@ -1,5 +1,5 @@
-ARG CUDA_VERSION="11.8.0"
-ARG CUDNN_VERSION="8"
+ARG CUDA_VERSION="12.4.1"
+ARG CUDNN_VERSION=""
 ARG UBUNTU_VERSION="22.04"
 ARG MAX_JOBS=4
@@ -7,16 +7,16 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION A
 ENV PATH="/root/miniconda3/bin:${PATH}"
-ARG PYTHON_VERSION="3.10"
-ARG PYTORCH_VERSION="2.1.2"
-ARG CUDA="118"
+ARG PYTHON_VERSION="3.11"
+ARG PYTORCH_VERSION="2.5.1"
+ARG CUDA="124"
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
 ENV PYTHON_VERSION=$PYTHON_VERSION
 ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
 RUN apt-get update \
-&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
+&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
 && wget \
 https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
 && mkdir /root/.conda \
@@ -38,6 +38,10 @@ RUN git lfs install --skip-repo && \
 # The base image ships with `pydantic==1.8.2` which is not working
 pip3 install -U --no-cache-dir pydantic==1.10.10
-RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
+RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
+curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
 pip3 install flash-attn==2.7.4.post1; \
 fi

View File

@@ -633,7 +633,9 @@ weight_decay:
 # adamw hyperparams
 adam_beta1:
 adam_beta2:
+adam_beta3: # only used for CAME Optimizer
 adam_epsilon:
+adam_epsilon2: # only used for CAME Optimizer
 # Gradient clipping max norm
 max_grad_norm:
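For reference, a minimal config sketch that exercises these CAME-specific fields could look like the snippet below; the optimizer name and the values shown are illustrative assumptions (the fallbacks mirror those used in the trainer builder later in this compare), not part of this diff.

```yaml
# hypothetical axolotl config snippet for the CAME optimizer
optimizer: came_pytorch   # assumption: CAME is available via the optimizers extra
adam_beta1: 0.9
adam_beta2: 0.999
adam_beta3: 0.9999        # only used for CAME
adam_epsilon: 1e-30
adam_epsilon2: 1e-16      # only used for CAME
```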

View File

@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
 Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
 format them.
-2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
+2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
 format):
 ```json
@@ -120,6 +120,12 @@ axolotl train my_training.yml
 ## Common Tasks {#sec-common-tasks}
+::: {.callout-tip}
+The same yaml file is used for training, inference, and merging.
+:::
 ### Testing Your Model {#sec-testing}
 After training, test your model:
@@ -128,6 +134,16 @@ After training, test your model:
 axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
 ```
+More details can be found in [Inference](inference.qmd).
+### Using a UI {#sec-ui}
+Launch a Gradio interface:
+```bash
+axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
+```
 ### Preprocessing Data {#sec-preprocessing}
 For large datasets, preprocess first:
@@ -136,14 +152,22 @@ For large datasets, preprocess first:
 axolotl preprocess my_training.yml
 ```
-### Using a UI {#sec-ui}
-Launch a Gradio interface:
+Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.
+More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
+### Merging LoRA weights {#sec-merging-lora}
+To merge the LoRA weights back into the base model, run:
 ```bash
-axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
+axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
 ```
+The merged model will be saved in the `{output_dir}/merged` directory.
+More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
 ## Next Steps {#sec-next-steps}
 Now that you have the basics, you might want to:
@@ -156,6 +180,7 @@ Now that you have the basics, you might want to:
 Check our other guides for details on these topics:
 - [Configuration Guide](config.qmd) - Full configuration options
+- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources
 - [Dataset Formats](dataset-formats) - Working with different data formats
 - [Multi-GPU Training](multi-gpu.qmd)
 - [Multi-Node Training](multi-node.qmd)

View File

@@ -387,8 +387,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
 if self.cfg.adam_beta2:
 training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
+if self.cfg.adam_beta3:
+training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3
 if self.cfg.adam_epsilon:
 training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
+if self.cfg.adam_epsilon2:
+training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2
 if self.cfg.max_grad_norm:
 training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
@@ -713,7 +717,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
 beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
-beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
+beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
 eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
 eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
 adam_kwargs["betas"] = (beta1, beta2, beta3)
@@ -1170,7 +1174,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 if self.eval_dataset:
 trainer_kwargs["eval_dataset"] = self.eval_dataset
 if self.cfg.adapter and self.peft_config:
-trainer_kwargs["peft_config"] = self.peft_config
+if self.cfg.rl is not RLType.GRPO:
+trainer_kwargs["peft_config"] = self.peft_config
 if self.cfg.precompute_ref_log_probs is not None:
 trainer_kwargs["precompute_ref_log_probs"] = (
 self.cfg.precompute_ref_log_probs

View File

@@ -3,7 +3,6 @@
 # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
 import warnings
-from contextlib import nullcontext
 from typing import Any
 import datasets
@@ -14,7 +13,7 @@ from accelerate.utils import (
 broadcast_object_list,
 gather,
 gather_object,
-is_peft_model,
+is_peft_available,
 )
 from datasets import Dataset, IterableDataset
 from torch import nn
@@ -30,15 +29,13 @@ from transformers import (
 TrainerCallback,
 )
 from transformers.trainer_utils import seed_worker
-from transformers.utils import is_peft_available
 from trl import GRPOTrainer
 from trl.data_utils import (
 apply_chat_template,
 is_conversational,
 maybe_apply_chat_template,
 )
-from trl.extras.profiling import profiling_context, profiling_decorator
-from trl.import_utils import is_deepspeed_available
+from trl.extras.profiling import profiling_context
 from trl.models import unwrap_model_for_generation
 from trl.trainer.grpo_config import GRPOConfig
 from trl.trainer.grpo_trainer import RewardFunc, nanstd
@@ -52,62 +49,12 @@ if is_peft_available():
 # pylint: disable=unused-import
 from peft import PeftConfig
-if is_deepspeed_available():
-import deepspeed
 class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
 """Extend the base GRPOTrainer for axolotl helpers"""
 _tag_names = ["trl", "grpo", "axolotl"]
-@profiling_decorator
-def _move_model_to_vllm(self):
-# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
-deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
-gather_if_zero3 = (
-deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
-)
-if is_peft_model(self.model):
-# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
-# adapters in a sharded manner is not supported.
-with gather_if_zero3(list(self.model.parameters())):
-self.model.merge_adapter()
-# Update vLLM weights while parameters are gathered
-for name, param in self.model.named_parameters():
-# When using PEFT, we need to recover the original parameter name and discard some parameters
-name = (
-name.removeprefix("base_model.model.")
-.removeprefix("base_model.model.")
-.replace(".base_layer", "")
-)
-if self.model.prefix in name:
-continue
-# When module to save, remove its prefix and discard the original module
-if "original_module" in name:
-continue
-name = name.replace("modules_to_save.default.", "")
-if self.accelerator.is_main_process:
-self.vllm_client.update_named_param(name, param.data)
-# Unmerge adapters while parameters are still gathered
-self.model.unmerge_adapter()
-# Parameters will automatically be repartitioned when exiting the context
-else:
-# For non-PEFT models, simply gather and update each parameter individually.
-for name, param in self.model.named_parameters():
-with gather_if_zero3([param]):
-if self.accelerator.is_main_process:
-self.vllm_client.update_named_param(name, param.data)
-# Reset cache on main process
-if self.accelerator.is_main_process:
-self.vllm_client.reset_prefix_cache()
 class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
 """Extend the base GRPOTrainer for sequence parallelism handling"""

View File

@@ -227,6 +227,19 @@ class AxolotlTrainingMixins:
 },
 )
+adam_beta3: Optional[float] = field(
+default=None,
+metadata={
+"help": "The beta3 hyperparameter used in some optimizers such as CAME"
+},
+)
+adam_epsilon2: Optional[float] = field(
+default=None,
+metadata={
+"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
+},
+)
 # multi-modal section
 image_size: int | tuple[int, int] | None = field(

View File

@@ -20,25 +20,15 @@ from cut_cross_entropy.transformers.utils import (
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.cohere.modeling_cohere import (
-_CONFIG_FOR_DOC,
-COHERE_INPUTS_DOCSTRING,
 KwargsForCausalLM,
 )
 from transformers.processing_utils import Unpack
-from transformers.utils import (
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
 from transformers.utils.deprecation import deprecate_kwarg
 _PATCH_OPTS: PatchOptions | None = None
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
 self,
 input_ids: torch.LongTensor | None = None,

View File

@@ -17,25 +17,15 @@ from cut_cross_entropy.transformers.utils import (
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.gemma.modeling_gemma import (
-_CONFIG_FOR_DOC,
-GEMMA_INPUTS_DOCSTRING,
 KwargsForCausalLM,
 )
 from transformers.processing_utils import Unpack
-from transformers.utils import (
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
 from transformers.utils.deprecation import deprecate_kwarg
 _PATCH_OPTS: PatchOptions | None = None
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
 self,
 input_ids: torch.LongTensor | None = None,

View File

@@ -20,15 +20,11 @@ from torch import nn
 from transformers.cache_utils import Cache, HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.gemma3.modeling_gemma3 import (
-_CONFIG_FOR_DOC,
-GEMMA3_INPUTS_DOCSTRING,
 Gemma3CausalLMOutputWithPast,
 logger,
 )
 from transformers.utils import (
-add_start_docstrings_to_model_forward,
 is_torchdynamo_compiling,
-replace_return_docstrings,
 )
 from transformers.utils.deprecation import deprecate_kwarg
@@ -38,10 +34,6 @@ _PATCH_OPTS: PatchOptions | None = None
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
 self,
 input_ids: torch.LongTensor | None = None,
@@ -170,10 +162,6 @@ def cce_forward(
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward_multimodal(
 self,
 input_ids: torch.LongTensor | None = None,

View File

@@ -19,15 +19,9 @@ from transformers.modeling_outputs import (
 CausalLMOutputWithPast,
 )
 from transformers.models.llama.modeling_llama import (
-_CONFIG_FOR_DOC,
-LLAMA_INPUTS_DOCSTRING,
 KwargsForCausalLM,
 )
 from transformers.processing_utils import Unpack
-from transformers.utils import (
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
 from transformers.utils.deprecation import deprecate_kwarg
 from transformers.utils.generic import can_return_tuple
@@ -36,10 +30,6 @@ _PATCH_OPTS: PatchOptions | None = None
 @can_return_tuple
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
 self,
 input_ids: Optional[torch.LongTensor] = None,

View File

@@ -16,22 +16,12 @@ from torch import nn
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.llama4.modeling_llama4 import (
-_CONFIG_FOR_DOC,
-LLAMA4_INPUTS_DOCSTRING,
 Llama4CausalLMOutputWithPast,
 )
-from transformers.utils import (
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
 _PATCH_OPTS: PatchOptions | None = None
-@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
 self,
 input_ids: torch.LongTensor | None = None,
@@ -160,9 +150,6 @@ def cce_forward(
 )
-@replace_return_docstrings(
-output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward_multimodal(
 self,
 input_ids: torch.LongTensor | None = None,  # type: ignore

View File

@@ -19,15 +19,11 @@ from transformers.models.mistral3.modeling_mistral3 import (
 Mistral3CausalLMOutputWithPast,
 )
 from transformers.models.mistral.modeling_mistral import (
-_CONFIG_FOR_DOC,
-MISTRAL_INPUTS_DOCSTRING,
 KwargsForCausalLM,
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import (
-add_start_docstrings_to_model_forward,
 is_torchdynamo_compiling,
-replace_return_docstrings,
 )
 from transformers.utils.deprecation import deprecate_kwarg
@@ -35,10 +31,6 @@ _PATCH_OPTS: PatchOptions | None = None
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
 self,
 input_ids: torch.LongTensor | None = None,

View File

@@ -13,16 +13,10 @@ from cut_cross_entropy.transformers.utils import (
 apply_lce,
 )
 from transformers.models.qwen2_moe.modeling_qwen2_moe import (
-_CONFIG_FOR_DOC,
-QWEN2MOE_INPUTS_DOCSTRING,
 MoeCausalLMOutputWithPast,
 MoeModelOutputWithPast,
 load_balancing_loss_func,
 )
-from transformers.utils import (
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
 from transformers.utils.deprecation import deprecate_kwarg
 from transformers.utils.generic import can_return_tuple
@@ -31,10 +25,6 @@ _PATCH_OPTS: PatchOptions | None = None
 @can_return_tuple
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def forward(
 self,
 input_ids: Optional[torch.LongTensor] = None,

View File

@@ -14,22 +14,12 @@ from cut_cross_entropy.transformers.utils import (
 )
 from torch.nn import CrossEntropyLoss
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-_CONFIG_FOR_DOC,
-QWEN2_VL_INPUTS_DOCSTRING,
 Qwen2VLCausalLMOutputWithPast,
 )
-from transformers.utils import (
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
 _PATCH_OPTS: PatchOptions | None = None
-@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward_multimodal(
 self,
 input_ids: Optional[torch.LongTensor] = None,

View File

@@ -12,20 +12,13 @@ from cut_cross_entropy.transformers.utils import (
 TransformersModelT,
 apply_lce,
 )
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.qwen3_moe.modeling_qwen3_moe import (
-_CONFIG_FOR_DOC,
-QWEN3_MOE_INPUTS_DOCSTRING,
 KwargsForCausalLM,
 MoeCausalLMOutputWithPast,
 MoeModelOutputWithPast,
 load_balancing_loss_func,
 )
 from transformers.processing_utils import Unpack
-from transformers.utils import (
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
 from transformers.utils.deprecation import deprecate_kwarg
 from transformers.utils.generic import can_return_tuple
@@ -34,10 +27,6 @@ _PATCH_OPTS: PatchOptions | None = None
 @can_return_tuple
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def forward(
 self,
 input_ids: Optional[torch.LongTensor] = None,

View File

@@ -14,10 +14,6 @@ from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
-# @replace_return_docstrings(
-# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-# )
 def lce_forward(
 self,
 input_ids: torch.LongTensor = None,

View File

@@ -13,21 +13,11 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
 from transformers.models.jamba.modeling_jamba import (
-_CONFIG_FOR_DOC,
-JAMBA_INPUTS_DOCSTRING,
 HybridMambaAttentionDynamicCache,
 load_balancing_loss_func,
 )
-from transformers.utils import (
-add_start_docstrings_to_model_forward,
-replace_return_docstrings,
-)
-@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def lce_forward(
 self,
 input_ids: torch.LongTensor = None,

View File

@@ -7,24 +7,16 @@ from typing import Optional, Tuple, Union
 import torch
 from transformers.cache_utils import Cache
 from transformers.models.gemma3.modeling_gemma3 import (
-_CONFIG_FOR_DOC,
-GEMMA3_INPUTS_DOCSTRING,
 Gemma3CausalLMOutputWithPast,
 logger,
 )
 from transformers.utils import (
-add_start_docstrings_to_model_forward,
 is_torchdynamo_compiling,
-replace_return_docstrings,
 )
 from transformers.utils.deprecation import deprecate_kwarg
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def new_forward(
 self,
 input_ids: torch.LongTensor = None,

View File

@@ -1,6 +1,7 @@
 """MLFlow module for trainer callbacks"""
 import logging
+import os
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING
@@ -16,6 +17,11 @@ if TYPE_CHECKING:
 LOG = logging.getLogger("axolotl.callbacks")
+def should_log_artifacts() -> bool:
+truths = ["TRUE", "1", "YES"]
+return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths
 class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
 # pylint: disable=duplicate-code
 """Callback to save axolotl config to mlflow"""
@@ -32,13 +38,18 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
 ):
 if is_main_process():
 try:
-with NamedTemporaryFile(
-mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
-) as temp_file:
-copyfile(self.axolotl_config_path, temp_file.name)
-mlflow.log_artifact(temp_file.name, artifact_path="")
-LOG.info(
-"The Axolotl config has been saved to the MLflow artifacts."
-)
+if should_log_artifacts():
+with NamedTemporaryFile(
+mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
+) as temp_file:
+copyfile(self.axolotl_config_path, temp_file.name)
+mlflow.log_artifact(temp_file.name, artifact_path="")
+LOG.info(
+"The Axolotl config has been saved to the MLflow artifacts."
+)
+else:
+LOG.info(
+"Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
+)
 except (FileNotFoundError, ConnectionError) as err:
 LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
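A hedged sketch of how the restored behaviour is typically toggled: the callback now checks the HF_MLFLOW_LOG_ARTIFACTS environment variable, and the config keys below (which the config layer is assumed to map onto that variable) are illustrative rather than taken from this diff.

```yaml
# illustrative snippet: hf_mlflow_log_artifacts is assumed to surface as the
# HF_MLFLOW_LOG_ARTIFACTS environment variable that should_log_artifacts() reads
mlflow_experiment_name: my-experiment   # assumption: an existing MLflow setup
hf_mlflow_log_artifacts: true
```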

View File

@@ -72,6 +72,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
 data_set = data_set.map(
 ds_transform_fn,
 desc="Mapping RL Dataset",
+num_proc=cfg.dataset_processes,
 **map_kwargs,
 )
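The new num_proc argument is driven by the existing dataset_processes option; a minimal, illustrative snippet for parallelising the RL dataset mapping (the value shown is an assumption, sized to the machine):

```yaml
# illustrative only: dataset_processes feeds num_proc for dataset .map() calls
dataset_processes: 8
```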

View File

@@ -484,7 +484,7 @@ def get_dataset_wrapper(
 }
 LOG.info(
-f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
+f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
 )
 if (

View File

@@ -629,6 +629,49 @@ class ModelLoader:
 )
 if self.cfg.flash_attention:
+use_fa3 = False
+if self.cfg.use_flash_attention_3 is True:
+use_fa3 = True
+elif self.cfg.use_flash_attention_3 == "auto":
+if torch.cuda.get_device_capability() >= (9, 0):
+# FA3 is only available on Hopper GPUs and newer
+use_fa3 = True
+if not importlib.util.find_spec("flash_attn_interface"):
+use_fa3 = False
+if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
+# this can happen when use_flash_attention_3 is explicitly set to True
+# and flash_attn_interface is not installed
+raise ModuleNotFoundError(
+"Please install the flash_attn_interface library to use Flash Attention 3.x"
+)
+if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
+from flash_attn_interface import flash_attn_func as flash_attn_func_v3
+from flash_attn_interface import (
+flash_attn_varlen_func as flash_attn_varlen_func_v3,
+)
+def flash_attn_func_v3_wrapper(*args, **kwargs):
+kwargs.pop("dropout_p", None)
+if "softmax_scale" in kwargs and len(args) >= 4:
+# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
+args = (*args[:3],) + args[4:]
+return flash_attn_func_v3(*args, **kwargs)[0]
+def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
+kwargs.pop("dropout_p", None)
+if "softmax_scale" in kwargs and len(args) >= 4:
+# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
+args = (*args[:3],) + args[4:]
+return flash_attn_varlen_func_v3(*args, **kwargs)[0]
+transformers.modeling_flash_attention_utils.flash_attn_func = (
+flash_attn_func_v3_wrapper
+)
+transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
+flash_attn_varlen_func_v3_wrapper
+)
+LOG.info("Switched to Flash Attention v3")
 self.patch_attention()
 if self.cfg.sample_packing and self.cfg.s2_attention:
@@ -699,6 +742,7 @@ class ModelLoader:
 patch_mllama()
+# TODO deprecate soon
 if self.model_config.model_type == "btlm":
 from axolotl.monkeypatch.btlm_attn_hijack_flash import (
 replace_btlm_attn_with_flash_attn,
@@ -706,6 +750,7 @@ class ModelLoader:
 replace_btlm_attn_with_flash_attn(self.cfg.base_model)
+# TODO deprecate soon
 if (
 self.model_config.model_type == "stablelm_epoch"
 and self.cfg.sample_packing

View File

@@ -233,6 +233,7 @@ class AxolotlInputConfig(
 flash_attn_fuse_qkv: bool | None = None
 flash_attn_fuse_mlp: bool | None = None
 flash_optimum: bool | None = None
+use_flash_attention_3: Literal["auto"] | bool | None = None
 eager_attention: bool | None = None
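Putting the new schema field together with the loader logic earlier in this compare, enabling FA3 from a config would look roughly like this sketch; "auto" falls back to FA2 when no Hopper-class GPU or flash_attn_interface install is found, while an explicit true raises if the package is missing.

```yaml
flash_attention: true
use_flash_attention_3: auto   # or true to require FA3 explicitly
```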

View File

@@ -421,6 +421,7 @@ def temp_dir():
 @pytest.fixture(scope="function", autouse=True)
 def cleanup_monkeypatches():
+import transformers.modeling_flash_attention_utils
 from transformers import Trainer
 from transformers.models.llama.modeling_llama import (  # LlamaFlashAttention2,
 LlamaAttention,
@@ -434,6 +435,19 @@ def cleanup_monkeypatches():
 Trainer._inner_training_loop  # pylint: disable=protected-access
 )
 original_trainer_training_step = Trainer.training_step
+original_fa_func = None
+original_fa_varlen_func = None
+if (
+importlib.util.find_spec("flash_attn")
+and hasattr(transformers.modeling_flash_attention_utils, "flash_attn_func")
+and hasattr(
+transformers.modeling_flash_attention_utils, "flash_attn_varlen_func"
+)
+):
+original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
+original_fa_varlen_func = (
+transformers.modeling_flash_attention_utils.flash_attn_varlen_func
+)
 # monkey patches can happen inside the tests
 yield
 # Reset LlamaFlashAttention2 forward
@@ -444,6 +458,11 @@ def cleanup_monkeypatches():
 original_trainer_inner_training_loop
 )
 Trainer.training_step = original_trainer_training_step
+if original_fa_func:
+transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
+transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
+original_fa_varlen_func
+)
 # Reset other known monkeypatches
 modules_to_reset: list[tuple[str, list[str]]] = [
@@ -458,6 +477,7 @@ def cleanup_monkeypatches():
 ("transformers.trainer",),
 ("transformers", ["Trainer"]),
 ("transformers.loss.loss_utils",),
+("transformers.modeling_flash_attention_utils",),
 ]
 for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0] module_name = module_name_tuple[0]

View File

@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
""" """
 )
-@pytest.mark.skip(reason="flaky test")
 @pytest.mark.parametrize(
 "num_gpus",
 [1, 2],
@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC", "NCCL_P2P_LEVEL": "LOC",
 **current_env,
 "CUDA_VISIBLE_DEVICES": "1",
-"VLLM_DISABLE_COMPILE_CACHE": "1",
-# "VLLM_USE_V1": "0",
 }
 vllm_process = start_vllm(
 cfg.base_model,
@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
 finally:
 recursive_kill(vllm_process)
-@pytest.mark.skip(reason="flaky test")
 @pytest.mark.parametrize(
 "num_gpus",
 [1, 2],
@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable "NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
 **current_env,
 "CUDA_VISIBLE_DEVICES": "1",
-"VLLM_DISABLE_COMPILE_CACHE": "1",
-# "VLLM_USE_V1": "0",
 }
 vllm_process = start_vllm(
 cfg.base_model,

View File

@@ -101,7 +101,13 @@ class TestMultiGPULlama:
"gradient_accumulation_steps", "gradient_accumulation_steps",
 [1, 2],
 )
-def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
+@pytest.mark.parametrize(
+"use_flash_attention_3",
+[False, "auto"],
+)
+def test_lora_ddp_packed(
+self, temp_dir, gradient_accumulation_steps, use_flash_attention_3
+):
 # pylint: disable=duplicate-code
 cfg = DictDefault(
 {
@@ -138,6 +144,7 @@ class TestMultiGPULlama:
"flash_attention": True, "flash_attention": True,
"use_tensorboard": True, "use_tensorboard": True,
"bf16": True, "bf16": True,
"use_flash_attention_3": use_flash_attention_3,
} }
) )

View File

@@ -4,7 +4,6 @@ E2E tests for packed training
 import logging
 import os
-import unittest
 from transformers.utils import is_torch_bf16_gpu_available
@@ -14,18 +13,17 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
-from .utils import check_tensorboard, with_temp_dir
+from .utils import check_tensorboard
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
-class TestPackedLlama(unittest.TestCase):
+class TestPackedLlama:
 """
 Test case for Packed training of llama models
 """
-@with_temp_dir
 def test_loss_packed(self, temp_dir):
 # pylint: disable=duplicate-code
 cfg = DictDefault(