Compare commits: rl-trainer...fa3-hopper (23 commits)

9bdf4b1c23, d6f64a3684, 0735454782, bb6464c4c6, 323a9cb153, b22150751f, 8c4bc59bfc, a064f1c9b4, fb5ef6d445, 34b68ddaae, 9a3d0c919b, bd34d0b861, 37220ab90a, e1b74d710b, 79daf5b934, ddd7c55576, 65c6c98a76, 4ef2e8293f, c126d5cd04, 9b0be4f15c, a27b909c5c, 6cb07b9d12, 288653adb6

.github/workflows/base.yml (11 changes, vendored)

@@ -47,11 +47,18 @@ jobs:
           pytorch: 2.7.0
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
         - cuda: "128"
-          cuda_version: 12.6.3
+          cuda_version: 12.8.1
           cudnn_version: ""
           python_version: "3.11"
           pytorch: 2.7.0
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+        - cuda: "126"
+          cuda_version: 12.6.3
+          cudnn_version: ""
+          python_version: "3.11"
+          pytorch: 2.6.0
+          suffix: "-hopper"
+          torch_cuda_arch_list: "9.0+PTX"
         - cuda: "128"
           cuda_version: 12.8.1
           cudnn_version: ""

@@ -87,7 +94,7 @@ jobs:
           context: .
           file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
           push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
           labels: ${{ steps.metadata.outputs.labels }}
           build-args: |
             CUDA_VERSION=${{ matrix.cuda_version }}

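Note: the `${{ matrix.suffix || '' }}` appended to `tags:` is what gives the new cu126/torch-2.6.0 matrix entry its `-hopper` image name. A rough Python stand-in for how that expression composes a tag (function and argument names here are illustrative, not part of the workflow):

```python
def image_tag(base: str, python_version: str, cuda: str, pytorch: str,
              axolotl_extras: str = "", suffix: str = "") -> str:
    """Mirror the `tags:` expression in the workflow hunk above."""
    tag = f"{base}-base-py{python_version}-cu{cuda}-{pytorch}"
    # `${{ matrix.axolotl_extras != '' && '-' || '' }}` adds a dash only when extras are set
    if axolotl_extras:
        tag += f"-{axolotl_extras}"
    # `${{ matrix.suffix || '' }}` falls back to '' for matrix entries without a suffix
    return tag + suffix

print(image_tag("main", "3.11", "126", "2.6.0", suffix="-hopper"))
# -> main-base-py3.11-cu126-2.6.0-hopper
```
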
.github/workflows/multi-gpu-e2e.yml (11 changes, vendored)

@@ -32,21 +32,25 @@ jobs:
           pytorch: 2.6.0
           axolotl_extras: vllm
           num_gpus: 2
-          nightly_build: "true"
+        - cuda: 126
+          cuda_version: 12.6.3
+          python_version: "3.11"
+          pytorch: 2.6.0
+          axolotl_extras:
+          suffix: "-hopper"
+          num_gpus: 2
         - cuda: 124
           cuda_version: 12.4.1
           python_version: "3.11"
           pytorch: 2.5.1
           axolotl_extras:
           num_gpus: 2
-          nightly_build: "true"
         - cuda: 126
           cuda_version: 12.6.3
           python_version: "3.11"
           pytorch: 2.7.0
           axolotl_extras:
           num_gpus: 2
-          nightly_build: "true"
     runs-on: [self-hosted, modal]
     timeout-minutes: 120
     steps:

@@ -68,7 +72,6 @@ jobs:
           echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
           echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
-          echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
           echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
       - name: Run tests job on Modal
         run: |

@@ -32,6 +32,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
     fi

 RUN pip install packaging==23.2 setuptools==75.8.0
+RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
+        curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+    fi
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \

@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
     image=cicd_image,
     gpu=GPU_CONFIG,
     timeout=90 * 60,
-    cpu=8.0,
+    cpu=16.0,
     memory=131072 * N_GPUS,
     volumes=VOLUME_CONFIG,
 )

@@ -1,5 +1,5 @@
-ARG CUDA_VERSION="11.8.0"
-ARG CUDNN_VERSION="8"
+ARG CUDA_VERSION="12.4.1"
+ARG CUDNN_VERSION=""
 ARG UBUNTU_VERSION="22.04"
 ARG MAX_JOBS=4

@@ -7,16 +7,16 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION A

 ENV PATH="/root/miniconda3/bin:${PATH}"

-ARG PYTHON_VERSION="3.10"
-ARG PYTORCH_VERSION="2.1.2"
-ARG CUDA="118"
+ARG PYTHON_VERSION="3.11"
+ARG PYTORCH_VERSION="2.5.1"
+ARG CUDA="124"
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"

 ENV PYTHON_VERSION=$PYTHON_VERSION
 ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

 RUN apt-get update \
-    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
+    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
     && wget \
         https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
     && mkdir /root/.conda \

@@ -38,6 +38,10 @@ RUN git lfs install --skip-repo && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10

-RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
+RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
+        curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+    elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
         pip3 install flash-attn==2.7.4.post1; \
     fi

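Both Dockerfile changes install the prebuilt Flash Attention 3 wheel (import name `flash_attn_interface`) only for specific build combinations, so downstream code has to probe for it at runtime. A minimal sketch of that probe, matching the `importlib.util.find_spec` check the model loader uses later in this compare:

```python
import importlib.util

# The flash_attn_3 wheel installed above exposes the `flash_attn_interface` module
if importlib.util.find_spec("flash_attn_interface") is not None:
    from flash_attn_interface import flash_attn_func  # FA3 kernels are available
    print("Flash Attention 3 is installed")
else:
    print("FA3 wheel absent; falling back to flash-attn 2.x")
```
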
@@ -633,7 +633,9 @@ weight_decay:
 # adamw hyperparams
 adam_beta1:
 adam_beta2:
+adam_beta3: # only used for CAME Optimizer
 adam_epsilon:
+adam_epsilon2: # only used for CAME Optimizer
 # Gradient clipping max norm
 max_grad_norm:

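For context, a config exercising the two new fields might look like the sketch below, written in the `DictDefault` style the e2e tests in this compare use. The `optimizer: came_pytorch` value is an assumption about which optimizer name consumes these hyperparameters; it is not shown in this diff.

```python
from axolotl.utils.dict import DictDefault

# Hypothetical config fragment; only adam_beta3/adam_epsilon2 are documented above.
cfg = DictDefault(
    {
        "optimizer": "came_pytorch",  # assumption: the CAME optimizer these fields target
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "adam_beta3": 0.9999,    # only used for CAME
        "adam_epsilon": 1e-30,
        "adam_epsilon2": 1e-16,  # only used for CAME
    }
)
```
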
@@ -387,8 +387,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
         if self.cfg.adam_beta2:
             training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
+        if self.cfg.adam_beta3:
+            training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3
         if self.cfg.adam_epsilon:
             training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
+        if self.cfg.adam_epsilon2:
+            training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2
         if self.cfg.max_grad_norm:
             training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm

@@ -713,7 +717,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

             beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
             beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
-            beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
+            beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
             eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
             eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
             adam_kwargs["betas"] = (beta1, beta2, beta3)

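The second hunk fixes a copy-paste bug: `beta3` previously read `adam_beta2` back out of the kwargs, so a user-supplied `adam_beta3` was silently ignored. The `(beta1, beta2, beta3)` tuple and `(eps1, eps2)` pair match the constructor of the `came_pytorch` package; a minimal sketch of how these defaults would be consumed (the model and `lr` are illustrative, and the package dependency is assumed):

```python
import torch
from torch import nn
from came_pytorch import CAME  # assumed dependency: pip install came-pytorch

model = nn.Linear(16, 16)
optimizer = CAME(
    model.parameters(),
    lr=2e-4,                     # illustrative
    betas=(0.9, 0.999, 0.9999),  # (adam_beta1, adam_beta2, adam_beta3) defaults above
    eps=(1e-30, 1e-16),          # (adam_epsilon, adam_epsilon2) defaults above
)
loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()
```
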
@@ -1170,7 +1174,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.eval_dataset:
             trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
-            trainer_kwargs["peft_config"] = self.peft_config
+            if self.cfg.rl is not RLType.GRPO:
+                trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:
             trainer_kwargs["precompute_ref_log_probs"] = (
                 self.cfg.precompute_ref_log_probs

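One plausible reading of this guard: for GRPO the policy model is already adapter-wrapped before the trainer is constructed, and handing TRL a `peft_config` as well would wrap it a second time. A toy illustration of that failure mode (model and config choices are hypothetical):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # arbitrary tiny model
lora = LoraConfig(r=8, target_modules=["c_attn"])

model = get_peft_model(base, lora)   # adapter applied upstream
model = get_peft_model(model, lora)  # what a trainer does if it also receives peft_config
print(next(iter(model.named_parameters()))[0])
# the doubled "base_model.model." prefix shows the model is now wrapped twice
```
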
@@ -156,9 +156,6 @@ class AxolotlTrainer(
         Helper method to get the sampler for evaluation. Handles sequence parallelism
         and sample packing cases.

-        Args:
-            eval_dataset: Evaluation dataset.
-
         Returns:
             If the dataset is non-empty, a sampler is returned, the type of which
             depends on the passed training args.

@@ -240,6 +237,9 @@ class AxolotlTrainer(
             self.accelerator.even_batches = False

         # Return unprepared dataloader if using sequence parallelism
+        # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
+        # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
+        # slice each batch along the sequence dimension).
         if self.args.sequence_parallel_degree > 1:
             return dataloader

@@ -1,25 +1,33 @@
-"""DPO trainer for Axolotl"""
+"""
+DPO trainer for axolotl
+"""

 import gc
+import random
 from functools import wraps
-from typing import Any, Dict, Union
+from typing import Any, Dict, Optional, Union

+import pandas as pd
 import torch
-from datasets import Dataset
+import wandb
+from accelerate import PartialState
+from datasets import Dataset, IterableDataset
 from peft.optimizers import create_loraplus_optimizer
 from torch import nn
-from torch.utils.data import Sampler
+from torch.utils.data import DataLoader
 from transformers import (
+    BaseImageProcessor,
+    FeatureExtractionMixin,
+    PreTrainedTokenizerBase,
+    ProcessorMixin,
     Trainer,
 )
+from transformers.trainer_utils import EvalLoopOutput
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer
+from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
+from trl.trainer.utils import log_table_to_comet_experiment

-from axolotl.core.trainers.mixins import (
-    RngLoaderMixin,
-    SchedulerMixin,
-    SequenceParallelMixin,
-)
+from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
 from axolotl.core.trainers.utils import (
     sanitize_kwargs_for_ds_tagging,
     sanitize_kwargs_for_tagging,

@@ -29,10 +37,10 @@ if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp


-class AxolotlDPOTrainer(
-    RngLoaderMixin, SchedulerMixin, SequenceParallelMixin, DPOTrainer
-):
-    """Extend the base DPOTrainer for axolotl helpers"""
+class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
+    """
+    Extend the base DPOTrainer for axolotl helpers
+    """

     tag_names = ["axolotl", "dpo"]

@@ -87,6 +95,64 @@

         return super().push_to_hub(*args, **kwargs)

+    # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
+    def _prepare_dataset(
+        self,
+        dataset: Union[Dataset, IterableDataset],
+        processing_class: Union[
+            PreTrainedTokenizerBase,
+            BaseImageProcessor,
+            FeatureExtractionMixin,
+            ProcessorMixin,
+        ],
+        args: DPOConfig,
+        dataset_name: str,
+    ) -> Union[Dataset, IterableDataset]:
+        # Build the kwargs for the `map` function
+        map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
+        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
+            map_kwargs["num_proc"] = args.dataset_num_proc
+
+        with PartialState().main_process_first():
+            # Extract prompt if needed
+            if isinstance(
+                dataset, Dataset
+            ):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
+            dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
+
+            # Apply the chat template if needed
+            if isinstance(
+                dataset, Dataset
+            ):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
+            dataset = dataset.map(
+                maybe_apply_chat_template,
+                fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
+                **map_kwargs,
+            )
+
+            # Tokenize the dataset
+            if isinstance(
+                dataset, Dataset
+            ):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+            dataset = dataset.map(
+                self.tokenize_row if not self.is_vision_model else self.process_row,
+                remove_columns=["chosen", "rejected"],
+                fn_kwargs={
+                    "processing_class": processing_class,
+                    "max_prompt_length": args.max_prompt_length,
+                    "max_completion_length": args.max_completion_length,
+                    # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
+                    "add_special_tokens": False,
+                },
+                **map_kwargs,
+            )
+
+        return dataset
+
     @staticmethod
     def tokenize_row(
         features,

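The three `map` passes reproduce TRL's preference-data pipeline: factor out the shared prompt, render the chat template, then tokenize. A standalone sketch of the first two stages on a toy preference pair (the tokenizer is an arbitrary choice):

```python
from transformers import AutoTokenizer
from trl import maybe_apply_chat_template, maybe_extract_prompt

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # arbitrary

example = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "Blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "Green."},
    ],
}

# Stage 1: move the shared conversation prefix into a "prompt" field
example = maybe_extract_prompt(example)
# Stage 2: render prompt/chosen/rejected to strings via the chat template
example = maybe_apply_chat_template(example, tokenizer=tokenizer)
print(example["prompt"])
```
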
@@ -127,48 +193,68 @@
             torch.cuda.empty_cache()
         return loss

-    def _get_train_sampler(self) -> Sampler | None:
-        """
-        Helper method to get the sampler for training. Handles cases for sequence
-        parallelism, sample packing, and curriculum sampling (sequential).
-
-        Returns:
-            If the dataset is non-empty, a sampler is returned, the type of which
-            depends on the passed training args.
-        """
-        import torch.distributed as dist
-
-        if dist.get_rank() == 0:
-            import ipdb
-
-            ipdb.set_trace()
-        dist.barrier()
-        if dist.get_rank() == 1:
-            import ipdb
-
-            ipdb.set_trace()
-        dist.barrier()
-
-        if self.args.sequence_parallel_degree > 1:
-            return self._sp_get_train_sampler(self.train_dataset)
-
-        return super()._get_train_sampler()
-
-    def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
-        """
-        Helper method to get the sampler for evaluation. Handles sequence parallelism
-        and sample packing cases.
-
-        Args:
-            eval_dataset: Evaluation dataset.
-
-        Returns:
-            If the dataset is non-empty, a sampler is returned, the type of which
-            depends on the passed training args.
-        """
-        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-
-        if self.args.sequence_parallel_degree > 1:
-            return self._sp_get_eval_sampler(eval_dataset)
-
-        return super()._get_eval_sampler(eval_dataset)
+    # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
+    def evaluation_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[list[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        """
+        Overriding built-in evaluation loop to store metrics for each batch.
+        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+        Works both with or without labels.
+        """
+
+        # Sample and save to game log if requested (for one batch to save time)
+        if self.generate_during_eval:
+            # Generate random indices within the range of the total number of samples
+            num_samples = len(dataloader.dataset)
+            random_indices = random.sample(
+                range(num_samples), k=self.args.eval_batch_size
+            )
+
+            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
+            random_batch_dataset = dataloader.dataset.select(random_indices)
+            random_batch = self.data_collator(random_batch_dataset)
+            random_batch = self._prepare_inputs(random_batch)
+
+            policy_output_decoded, ref_output_decoded = (
+                self.generate_from_model_and_ref(self.model, random_batch)
+            )
+
+            table = pd.DataFrame(
+                columns=["Prompt", "Policy", "Ref Model"],
+                data=[
+                    [prompt, pol[len(prompt) :], ref[len(prompt) :]]
+                    for prompt, pol, ref in zip(
+                        random_batch_dataset["prompt"],
+                        policy_output_decoded,
+                        ref_output_decoded,
+                    )
+                ],
+            )
+            if "wandb" in self.args.report_to and self.accelerator.is_main_process:
+                wandb.log({"game_log": wandb.Table(data=table)})
+
+            if "comet_ml" in self.args.report_to:
+                log_table_to_comet_experiment(
+                    name="game_log.csv",
+                    table=table,
+                )
+
+        # Base evaluation
+        initial_output = super(  # pylint: disable=bad-super-call
+            DPOTrainer, self
+        ).evaluation_loop(
+            dataloader,
+            description,
+            prediction_loss_only,
+            ignore_keys,
+            metric_key_prefix,
+        )
+
+        return initial_output

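This hunk also deletes leftover `ipdb.set_trace()` debugging from `_get_train_sampler`. In the replacement, `super(DPOTrainer, self)` is the interesting part: it starts the MRO lookup *after* `DPOTrainer`, so TRL's own `evaluation_loop` is bypassed and the plain `transformers.Trainer` implementation runs once the game-log sampling is done. A minimal, self-contained illustration of that pattern:

```python
class Trainer:  # stands in for transformers.Trainer
    def evaluation_loop(self):
        return "Trainer.evaluation_loop"

class DPOTrainer(Trainer):  # stands in for trl.DPOTrainer
    def evaluation_loop(self):
        return "DPOTrainer.evaluation_loop"

class AxolotlDPOTrainer(DPOTrainer):
    def evaluation_loop(self):
        # Start attribute lookup after DPOTrainer in the MRO, skipping its override
        return super(DPOTrainer, self).evaluation_loop()

assert AxolotlDPOTrainer().evaluation_loop() == "Trainer.evaluation_loop"
```
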
@@ -3,7 +3,6 @@
 # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member

 import warnings
-from contextlib import nullcontext
 from typing import Any

 import datasets

@@ -14,7 +13,7 @@ from accelerate.utils import (
     broadcast_object_list,
     gather,
     gather_object,
-    is_peft_model,
+    is_peft_available,
 )
 from datasets import Dataset, IterableDataset
 from torch import nn

@@ -30,15 +29,13 @@ from transformers import (
     TrainerCallback,
 )
 from transformers.trainer_utils import seed_worker
-from transformers.utils import is_peft_available
 from trl import GRPOTrainer
 from trl.data_utils import (
     apply_chat_template,
     is_conversational,
     maybe_apply_chat_template,
 )
-from trl.extras.profiling import profiling_context, profiling_decorator
-from trl.import_utils import is_deepspeed_available
+from trl.extras.profiling import profiling_context
 from trl.models import unwrap_model_for_generation
 from trl.trainer.grpo_config import GRPOConfig
 from trl.trainer.grpo_trainer import RewardFunc, nanstd

@@ -52,62 +49,12 @@ if is_peft_available():
     # pylint: disable=unused-import
     from peft import PeftConfig

-if is_deepspeed_available():
-    import deepspeed
-
-
 class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
     """Extend the base GRPOTrainer for axolotl helpers"""

     _tag_names = ["trl", "grpo", "axolotl"]

-    @profiling_decorator
-    def _move_model_to_vllm(self):
-        # For DeepSpeed ZeRO-3, we need to gather all parameters before operations
-        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
-        gather_if_zero3 = (
-            deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
-        )
-
-        if is_peft_model(self.model):
-            # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
-            # adapters in a sharded manner is not supported.
-            with gather_if_zero3(list(self.model.parameters())):
-                self.model.merge_adapter()
-
-                # Update vLLM weights while parameters are gathered
-                for name, param in self.model.named_parameters():
-                    # When using PEFT, we need to recover the original parameter name and discard some parameters
-                    name = (
-                        name.removeprefix("base_model.model.")
-                        .removeprefix("base_model.model.")
-                        .replace(".base_layer", "")
-                    )
-                    if self.model.prefix in name:
-                        continue
-                    # When module to save, remove its prefix and discard the original module
-                    if "original_module" in name:
-                        continue
-                    name = name.replace("modules_to_save.default.", "")
-
-                    if self.accelerator.is_main_process:
-                        self.vllm_client.update_named_param(name, param.data)
-
-                # Unmerge adapters while parameters are still gathered
-                self.model.unmerge_adapter()
-                # Parameters will automatically be repartitioned when exiting the context
-        else:
-            # For non-PEFT models, simply gather and update each parameter individually.
-            for name, param in self.model.named_parameters():
-                with gather_if_zero3([param]):
-                    if self.accelerator.is_main_process:
-                        self.vllm_client.update_named_param(name, param.data)
-
-        # Reset cache on main process
-        if self.accelerator.is_main_process:
-            self.vllm_client.reset_prefix_cache()
-
-
 class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
     """Extend the base GRPOTrainer for sequence parallelism handling"""

@@ -266,6 +213,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
             self.accelerator.even_batches = False

         # Return unprepared dataloader if using sequence parallelism
+        # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
+        # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
+        # slice each batch along the sequence dimension).
         if self.args.sequence_parallel_degree > 1:
             return dataloader

@@ -227,6 +227,19 @@ class AxolotlTrainingMixins:
         },
     )

+    adam_beta3: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "The beta3 hyperparameter used in some optimizers such as CAME"
+        },
+    )
+    adam_epsilon2: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
+        },
+    )
+
     # multi-modal section

     image_size: int | tuple[int, int] | None = field(

@@ -1,6 +1,7 @@
 """MLFlow module for trainer callbacks"""

 import logging
+import os
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING

@@ -16,6 +17,11 @@ if TYPE_CHECKING:
 LOG = logging.getLogger("axolotl.callbacks")


+def should_log_artifacts() -> bool:
+    truths = ["TRUE", "1", "YES"]
+    return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths
+
+
 class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
     # pylint: disable=duplicate-code
     """Callback to save axolotl config to mlflow"""

@@ -32,13 +38,18 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
     ):
         if is_main_process():
             try:
-                with NamedTemporaryFile(
-                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
-                ) as temp_file:
-                    copyfile(self.axolotl_config_path, temp_file.name)
-                    mlflow.log_artifact(temp_file.name, artifact_path="")
-                LOG.info(
-                    "The Axolotl config has been saved to the MLflow artifacts."
-                )
+                if should_log_artifacts():
+                    with NamedTemporaryFile(
+                        mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
+                    ) as temp_file:
+                        copyfile(self.axolotl_config_path, temp_file.name)
+                        mlflow.log_artifact(temp_file.name, artifact_path="")
+                    LOG.info(
+                        "The Axolotl config has been saved to the MLflow artifacts."
+                    )
+                else:
+                    LOG.info(
+                        "Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
+                    )
             except (FileNotFoundError, ConnectionError) as err:
                 LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")

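`HF_MLFLOW_LOG_ARTIFACTS` is the same environment variable the transformers `MLflowCallback` honors, so artifact upload is now opt-in and consistent with upstream. A quick check of the parsing, reusing the helper added above:

```python
import os

def should_log_artifacts() -> bool:
    truths = ["TRUE", "1", "YES"]
    return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths

os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "yes"  # TRUE/1/YES enable it, case-insensitively
assert should_log_artifacts()

os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "0"    # anything else disables artifact logging
assert not should_log_artifacts()
```
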
@@ -1,7 +1,6 @@
 """Module for Axolotl trainer sequence parallelism manager and utilities"""

 import functools
-import inspect

 import torch
 import torch.distributed as dist

@@ -33,7 +32,7 @@ def apply_sequence_parallelism(
     to only keep the last N tokens in the sequence during generation.

     Args:
-        batch: Dictionary of model arguments (e.g., input_ids, attention_mask, etc.).
+        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
         local_rank: Local rank in the sequence parallel group.
         local_world_size: World size of the sequence parallel group.
         gradient_accumulation_steps: Number of steps to accumulate gradients over.

@@ -207,26 +206,12 @@ class SequenceParallelContextManager:
     def __enter__(self):
         # Forward pre-hook to apply sequence parallelism
         def sequence_parallel_pre_hook(_, args, kwargs):
-            # Convert all args to kwargs using the model's forward function signature
-            updated_kwargs = kwargs.copy()
-
-            # Get parameter names from the model's forward function
-            forward_params = list(
-                inspect.signature(self.models[0].forward).parameters.keys()
-            )
-
-            # Map args to their parameter names
-            for i, arg in enumerate(args):
-                if i < len(forward_params):
-                    param_name = forward_params[i]
-                    updated_kwargs[param_name] = arg
-
-            # Apply sequence parallelism to empty args and updated kwargs
-            updated_kwargs, self.original_seq_len, self.pad_len = (
-                self.apply_sequence_parallelism(updated_kwargs)
-            )
-
-            return (), updated_kwargs
+            # Apply sequence parallelism to kwargs and get original sequence length and padding info
+            kwargs, self.original_seq_len, self.pad_len = (
+                self.apply_sequence_parallelism(batch=kwargs)
+            )
+            return args, kwargs

         # Forward post-hook to gather outputs
         def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:

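The simplification leans on PyTorch's kwargs-aware hooks: a pre-hook registered with `with_kwargs=True` receives `(module, args, kwargs)` and may return a replacement `(args, kwargs)` tuple, so remapping positional args through the forward signature is unnecessary as long as callers pass batches as keyword arguments. A minimal sketch of the mechanism (module and hook names are illustrative):

```python
import torch
from torch import nn

model = nn.Linear(8, 8)

def pre_hook(module, args, kwargs):
    # Return (args, kwargs) to replace the inputs forward() will see,
    # mirroring sequence_parallel_pre_hook above.
    kwargs = dict(kwargs)
    kwargs["input"] = kwargs["input"] * 2
    return args, kwargs

handle = model.register_forward_pre_hook(pre_hook, with_kwargs=True)
out = model(input=torch.ones(1, 8))  # the hook doubles the input before forward runs
handle.remove()
```
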
@@ -629,6 +629,49 @@ class ModelLoader:
             )

         if self.cfg.flash_attention:
+            use_fa3 = False
+            if self.cfg.use_flash_attention_3 is True:
+                use_fa3 = True
+            elif self.cfg.use_flash_attention_3 == "auto":
+                if torch.cuda.get_device_capability() >= (9, 0):
+                    # FA3 is only available on Hopper GPUs and newer
+                    use_fa3 = True
+                if not importlib.util.find_spec("flash_attn_interface"):
+                    use_fa3 = False
+            if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
+                # this can happen when use_flash_attention_3 is explicitly set to True
+                # and flash_attn_interface is not installed
+                raise ModuleNotFoundError(
+                    "Please install the flash_attn_interface library to use Flash Attention 3.x"
+                )
+            if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
+                from flash_attn_interface import flash_attn_func as flash_attn_func_v3
+                from flash_attn_interface import (
+                    flash_attn_varlen_func as flash_attn_varlen_func_v3,
+                )
+
+                def flash_attn_func_v3_wrapper(*args, **kwargs):
+                    kwargs.pop("dropout_p", None)
+                    if "softmax_scale" in kwargs and len(args) >= 4:
+                        # if softmax_scale is a kwarg, dropout_p was passed positionally (args[3]) and must be dropped
+                        args = (*args[:3],) + args[4:]
+                    return flash_attn_func_v3(*args, **kwargs)[0]
+
+                def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
+                    kwargs.pop("dropout_p", None)
+                    if "softmax_scale" in kwargs and len(args) >= 4:
+                        # if softmax_scale is a kwarg, dropout_p was passed positionally (args[3]) and must be dropped
+                        args = (*args[:3],) + args[4:]
+                    return flash_attn_varlen_func_v3(*args, **kwargs)[0]
+
+                transformers.modeling_flash_attention_utils.flash_attn_func = (
+                    flash_attn_func_v3_wrapper
+                )
+                transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
+                    flash_attn_varlen_func_v3_wrapper
+                )
+                LOG.info("Switched to Flash Attention v3")
+
             self.patch_attention()

         if self.cfg.sample_packing and self.cfg.s2_attention:

@@ -699,6 +742,7 @@ class ModelLoader:

         patch_mllama()

+        # TODO deprecate soon
         if self.model_config.model_type == "btlm":
             from axolotl.monkeypatch.btlm_attn_hijack_flash import (
                 replace_btlm_attn_with_flash_attn,

@@ -706,6 +750,7 @@ class ModelLoader:

             replace_btlm_attn_with_flash_attn(self.cfg.base_model)

+        # TODO deprecate soon
         if (
             self.model_config.model_type == "stablelm_epoch"
             and self.cfg.sample_packing

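The `[0]` indexing in both wrappers implies FA3's functions return a tuple (the output plus auxiliary values such as the softmax LSE), and the `kwargs.pop("dropout_p", None)` reflects that the FA3 interface takes no dropout argument. A guarded sketch of a direct call (shapes and scale are illustrative; this only runs where the wheel and a CUDA device are present):

```python
import importlib.util

import torch

if importlib.util.find_spec("flash_attn_interface") is not None and torch.cuda.is_available():
    from flash_attn_interface import flash_attn_func

    # (batch, seqlen, num_heads, head_dim) in bf16, as flash attention expects
    q = k = v = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
    out = flash_attn_func(q, k, v, softmax_scale=64 ** -0.5)[0]  # drop the auxiliary output
    print(out.shape)  # torch.Size([1, 128, 8, 64])
```
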
@@ -233,6 +233,7 @@ class AxolotlInputConfig(
     flash_attn_fuse_qkv: bool | None = None
     flash_attn_fuse_mlp: bool | None = None
     flash_optimum: bool | None = None
+    use_flash_attention_3: Literal["auto"] | bool | None = None

     eager_attention: bool | None = None

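Usage follows the e2e test further down in this compare: the new key rides alongside `flash_attention`. A config fragment in the `DictDefault` style the tests use:

```python
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "flash_attention": True,
        # True forces FA3 and raises if flash_attn_interface is missing;
        # "auto" enables it only on compute capability >= 9.0 with the wheel installed;
        # False/None keeps the flash-attn 2.x path.
        "use_flash_attention_3": "auto",
    }
)
```
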
@@ -421,6 +421,7 @@ def temp_dir():

 @pytest.fixture(scope="function", autouse=True)
 def cleanup_monkeypatches():
+    import transformers.modeling_flash_attention_utils
     from transformers import Trainer
     from transformers.models.llama.modeling_llama import (  # LlamaFlashAttention2,
         LlamaAttention,

@@ -434,6 +435,19 @@ def cleanup_monkeypatches():
         Trainer._inner_training_loop  # pylint: disable=protected-access
     )
     original_trainer_training_step = Trainer.training_step
+    original_fa_func = None
+    original_fa_varlen_func = None
+    if (
+        importlib.util.find_spec("flash_attn")
+        and hasattr(transformers.modeling_flash_attention_utils, "flash_attn_func")
+        and hasattr(
+            transformers.modeling_flash_attention_utils, "flash_attn_varlen_func"
+        )
+    ):
+        original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
+        original_fa_varlen_func = (
+            transformers.modeling_flash_attention_utils.flash_attn_varlen_func
+        )
     # monkey patches can happen inside the tests
     yield
     # Reset LlamaFlashAttention2 forward

@@ -444,6 +458,11 @@ def cleanup_monkeypatches():
         original_trainer_inner_training_loop
     )
     Trainer.training_step = original_trainer_training_step
+    if original_fa_func:
+        transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
+        transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
+            original_fa_varlen_func
+        )

     # Reset other known monkeypatches
     modules_to_reset: list[tuple[str, list[str]]] = [

@@ -458,6 +477,7 @@ def cleanup_monkeypatches():
         ("transformers.trainer",),
         ("transformers", ["Trainer"]),
         ("transformers.loss.loss_utils",),
+        ("transformers.modeling_flash_attention_utils",),
     ]
     for module_name_tuple in modules_to_reset:
         module_name = module_name_tuple[0]

@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
     """
     )

-@pytest.mark.skip(reason="flaky test")
 @pytest.mark.parametrize(
     "num_gpus",
     [1, 2],

@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
         "NCCL_P2P_LEVEL": "LOC",
         **current_env,
         "CUDA_VISIBLE_DEVICES": "1",
-        "VLLM_DISABLE_COMPILE_CACHE": "1",
-        # "VLLM_USE_V1": "0",
     }
     vllm_process = start_vllm(
         cfg.base_model,

@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
     finally:
         recursive_kill(vllm_process)

-@pytest.mark.skip(reason="flaky test")
 @pytest.mark.parametrize(
     "num_gpus",
     [1, 2],

@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
         "NCCL_P2P_LEVEL": "LOC",  # nccl can be brittle, assume P2P isn't reliable
         **current_env,
         "CUDA_VISIBLE_DEVICES": "1",
-        "VLLM_DISABLE_COMPILE_CACHE": "1",
-        # "VLLM_USE_V1": "0",
     }
     vllm_process = start_vllm(
         cfg.base_model,

@@ -101,7 +101,13 @@ class TestMultiGPULlama:
         "gradient_accumulation_steps",
         [1, 2],
     )
-    def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
+    @pytest.mark.parametrize(
+        "use_flash_attention_3",
+        [False, "auto"],
+    )
+    def test_lora_ddp_packed(
+        self, temp_dir, gradient_accumulation_steps, use_flash_attention_3
+    ):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {

@@ -138,6 +144,7 @@ class TestMultiGPULlama:
                 "flash_attention": True,
                 "use_tensorboard": True,
                 "bf16": True,
+                "use_flash_attention_3": use_flash_attention_3,
             }
         )

@@ -4,7 +4,6 @@ E2E tests for packed training

 import logging
 import os
-import unittest

 from transformers.utils import is_torch_bf16_gpu_available

@@ -14,18 +13,17 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_tensorboard, with_temp_dir
+from .utils import check_tensorboard

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"


-class TestPackedLlama(unittest.TestCase):
+class TestPackedLlama:
     """
     Test case for Packed training of llama models
     """

-    @with_temp_dir
     def test_loss_packed(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(