Compare commits

..

1 Commits

Author SHA1 Message Date
Dan Saunders
c25990fd4f additional RL trainers SP support 2025-05-16 18:19:36 +00:00
19 changed files with 160 additions and 296 deletions

View File

@@ -47,18 +47,11 @@ jobs:
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
suffix: "-hopper"
torch_cuda_arch_list: "9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -94,7 +87,7 @@ jobs:
context: .
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}

View File

@@ -32,25 +32,21 @@ jobs:
pytorch: 2.6.0
axolotl_extras: vllm
num_gpus: 2
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
suffix: "-hopper"
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -72,6 +68,7 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |

View File

@@ -32,11 +32,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
image=cicd_image,
gpu=GPU_CONFIG,
timeout=90 * 60,
cpu=16.0,
cpu=8.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)

View File

@@ -1,5 +1,5 @@
ARG CUDA_VERSION="12.4.1"
ARG CUDNN_VERSION=""
ARG CUDA_VERSION="11.8.0"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
@@ -7,16 +7,16 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION A
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.5.1"
ARG CUDA="124"
ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
ARG CUDA="118"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
@@ -38,10 +38,6 @@ RUN git lfs install --skip-repo && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -633,9 +633,7 @@ weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_beta3: # only used for CAME Optimizer
adam_epsilon:
adam_epsilon2: # only used for CAME Optimizer
# Gradient clipping max norm
max_grad_norm:

View File

@@ -387,12 +387,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
if self.cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
if self.cfg.adam_beta3:
training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3
if self.cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
if self.cfg.adam_epsilon2:
training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2
if self.cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
@@ -717,7 +713,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
@@ -1174,8 +1170,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
if self.cfg.rl is not RLType.GRPO:
trainer_kwargs["peft_config"] = self.peft_config
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs

View File

@@ -156,6 +156,9 @@ class AxolotlTrainer(
Helper method to get the sampler for evaluation. Handles sequence parallelism
and sample packing cases.
Args:
eval_dataset: Evaluation dataset.
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
@@ -237,9 +240,6 @@ class AxolotlTrainer(
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader

View File

@@ -1,33 +1,25 @@
"""
DPO trainer for axolotl
"""
"""DPO trainer for Axolotl"""
import gc
import random
from functools import wraps
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Union
import pandas as pd
import torch
import wandb
from accelerate import PartialState
from datasets import Dataset, IterableDataset
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Sampler
from transformers import (
BaseImageProcessor,
FeatureExtractionMixin,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
)
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
from trl.trainer.utils import log_table_to_comet_experiment
from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import (
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -37,10 +29,10 @@ if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
class AxolotlDPOTrainer(
RngLoaderMixin, SchedulerMixin, SequenceParallelMixin, DPOTrainer
):
"""Extend the base DPOTrainer for axolotl helpers"""
tag_names = ["axolotl", "dpo"]
@@ -95,64 +87,6 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs)
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[
PreTrainedTokenizerBase,
BaseImageProcessor,
FeatureExtractionMixin,
ProcessorMixin,
],
args: DPOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
# Build the kwargs for the `map` function
map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
map_kwargs["num_proc"] = args.dataset_num_proc
with PartialState().main_process_first():
# Extract prompt if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
# Apply the chat template if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
dataset = dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
**map_kwargs,
)
# Tokenize the dataset
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
dataset = dataset.map(
self.tokenize_row if not self.is_vision_model else self.process_row,
remove_columns=["chosen", "rejected"],
fn_kwargs={
"processing_class": processing_class,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
"add_special_tokens": False,
},
**map_kwargs,
)
return dataset
@staticmethod
def tokenize_row(
features,
@@ -193,68 +127,48 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache()
return loss
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
def _get_train_sampler(self) -> Sampler | None:
"""
Overriding built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Helper method to get the sampler for training. Handles cases for sequence
parallelism, sample packing, and curriculum sampling (sequential).
Works both with or without labels.
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
import torch.distributed as dist
# Sample and save to game log if requested (for one batch to save time)
if self.generate_during_eval:
# Generate random indices within the range of the total number of samples
num_samples = len(dataloader.dataset)
random_indices = random.sample(
range(num_samples), k=self.args.eval_batch_size
)
if dist.get_rank() == 0:
import ipdb
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
random_batch_dataset = dataloader.dataset.select(random_indices)
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)
ipdb.set_trace()
dist.barrier()
if dist.get_rank() == 1:
import ipdb
policy_output_decoded, ref_output_decoded = (
self.generate_from_model_and_ref(self.model, random_batch)
)
ipdb.set_trace()
dist.barrier()
table = pd.DataFrame(
columns=["Prompt", "Policy", "Ref Model"],
data=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(
random_batch_dataset["prompt"],
policy_output_decoded,
ref_output_decoded,
)
],
)
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
wandb.log({"game_log": wandb.Table(data=table)})
if self.args.sequence_parallel_degree > 1:
return self._sp_get_train_sampler(self.train_dataset)
if "comet_ml" in self.args.report_to:
log_table_to_comet_experiment(
name="game_log.csv",
table=table,
)
return super()._get_train_sampler()
# Base evaluation
initial_output = super( # pylint: disable=bad-super-call
DPOTrainer, self
).evaluation_loop(
dataloader,
description,
prediction_loss_only,
ignore_keys,
metric_key_prefix,
)
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
"""
Helper method to get the sampler for evaluation. Handles sequence parallelism
and sample packing cases.
return initial_output
Args:
eval_dataset: Evaluation dataset.
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
if self.args.sequence_parallel_degree > 1:
return self._sp_get_eval_sampler(eval_dataset)
return super()._get_eval_sampler(eval_dataset)

View File

@@ -3,6 +3,7 @@
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from contextlib import nullcontext
from typing import Any
import datasets
@@ -13,7 +14,7 @@ from accelerate.utils import (
broadcast_object_list,
gather,
gather_object,
is_peft_available,
is_peft_model,
)
from datasets import Dataset, IterableDataset
from torch import nn
@@ -29,13 +30,15 @@ from transformers import (
TrainerCallback,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available
from trl import GRPOTrainer
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_deepspeed_available
from trl.models import unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
@@ -49,12 +52,62 @@ if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
if is_deepspeed_available():
import deepspeed
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"]
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
gather_if_zero3 = (
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
)
if is_peft_model(self.model):
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
# adapters in a sharded manner is not supported.
with gather_if_zero3(list(self.model.parameters())):
self.model.merge_adapter()
# Update vLLM weights while parameters are gathered
for name, param in self.model.named_parameters():
# When using PEFT, we need to recover the original parameter name and discard some parameters
name = (
name.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", "")
)
if self.model.prefix in name:
continue
# When module to save, remove its prefix and discard the original module
if "original_module" in name:
continue
name = name.replace("modules_to_save.default.", "")
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
# Unmerge adapters while parameters are still gathered
self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
else:
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""
@@ -213,9 +266,6 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader

View File

@@ -227,19 +227,6 @@ class AxolotlTrainingMixins:
},
)
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(

View File

@@ -1,7 +1,6 @@
"""MLFlow module for trainer callbacks"""
import logging
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
@@ -17,11 +16,6 @@ if TYPE_CHECKING:
LOG = logging.getLogger("axolotl.callbacks")
def should_log_artifacts() -> bool:
truths = ["TRUE", "1", "YES"]
return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
# pylint: disable=duplicate-code
"""Callback to save axolotl config to mlflow"""
@@ -38,18 +32,13 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
):
if is_main_process():
try:
if should_log_artifacts():
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
else:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
"The Axolotl config has been saved to the MLflow artifacts."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")

View File

@@ -1,6 +1,7 @@
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
import functools
import inspect
import torch
import torch.distributed as dist
@@ -32,7 +33,7 @@ def apply_sequence_parallelism(
to only keep the last N tokens in the sequence during generation.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
batch: Dictionary of model arguments (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
@@ -206,12 +207,26 @@ class SequenceParallelContextManager:
def __enter__(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs and get original sequence length and padding info
kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(batch=kwargs)
# Convert all args to kwargs using the model's forward function signature
updated_kwargs = kwargs.copy()
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
)
return args, kwargs
# Map args to their parameter names
for i, arg in enumerate(args):
if i < len(forward_params):
param_name = forward_params[i]
updated_kwargs[param_name] = arg
# Apply sequence parallelism to empty args and updated kwargs
updated_kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(updated_kwargs)
)
return (), updated_kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:

View File

@@ -629,49 +629,6 @@ class ModelLoader:
)
if self.cfg.flash_attention:
use_fa3 = False
if self.cfg.use_flash_attention_3 is True:
use_fa3 = True
elif self.cfg.use_flash_attention_3 == "auto":
if torch.cuda.get_device_capability() >= (9, 0):
# FA3 is only available on Hopper GPUs and newer
use_fa3 = True
if not importlib.util.find_spec("flash_attn_interface"):
use_fa3 = False
if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
# this can happen when use_flash_attention_3 is explicity set to True
# and flash_attn_interface is not installed
raise ModuleNotFoundError(
"Please install the flash_attn_interface library to use Flash Attention 3.x"
)
if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
def flash_attn_func_v3_wrapper(*args, **kwargs):
kwargs.pop("dropout_p", None)
if "softmax_scale" in kwargs and len(args) >= 4:
# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
args = (*args[:3],) + args[4:]
return flash_attn_func_v3(*args, **kwargs)[0]
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
kwargs.pop("dropout_p", None)
if "softmax_scale" in kwargs and len(args) >= 4:
# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
args = (*args[:3],) + args[4:]
return flash_attn_varlen_func_v3(*args, **kwargs)[0]
transformers.modeling_flash_attention_utils.flash_attn_func = (
flash_attn_func_v3_wrapper
)
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
flash_attn_varlen_func_v3_wrapper
)
LOG.info("Switched to Flash Attention v3")
self.patch_attention()
if self.cfg.sample_packing and self.cfg.s2_attention:
@@ -742,7 +699,6 @@ class ModelLoader:
patch_mllama()
# TODO deprecate soon
if self.model_config.model_type == "btlm":
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
replace_btlm_attn_with_flash_attn,
@@ -750,7 +706,6 @@ class ModelLoader:
replace_btlm_attn_with_flash_attn(self.cfg.base_model)
# TODO deprecate soon
if (
self.model_config.model_type == "stablelm_epoch"
and self.cfg.sample_packing

View File

@@ -233,7 +233,6 @@ class AxolotlInputConfig(
flash_attn_fuse_qkv: bool | None = None
flash_attn_fuse_mlp: bool | None = None
flash_optimum: bool | None = None
use_flash_attention_3: Literal["auto"] | bool | None = None
eager_attention: bool | None = None

View File

@@ -421,7 +421,6 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
import transformers.modeling_flash_attention_utils
from transformers import Trainer
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
LlamaAttention,
@@ -435,19 +434,6 @@ def cleanup_monkeypatches():
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
original_fa_func = None
original_fa_varlen_func = None
if (
importlib.util.find_spec("flash_attn")
and hasattr(transformers.modeling_flash_attention_utils, "flash_attn_func")
and hasattr(
transformers.modeling_flash_attention_utils, "flash_attn_varlen_func"
)
):
original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
original_fa_varlen_func = (
transformers.modeling_flash_attention_utils.flash_attn_varlen_func
)
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
@@ -458,11 +444,6 @@ def cleanup_monkeypatches():
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
if original_fa_func:
transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
original_fa_varlen_func
)
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
@@ -477,7 +458,6 @@ def cleanup_monkeypatches():
("transformers.trainer",),
("transformers", ["Trainer"]),
("transformers.loss.loss_utils",),
("transformers.modeling_flash_attention_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]

View File

@@ -166,6 +166,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"""
)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -230,6 +231,8 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process = start_vllm(
cfg.base_model,
@@ -263,6 +266,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
finally:
recursive_kill(vllm_process)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -321,6 +325,8 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process = start_vllm(
cfg.base_model,

View File

@@ -101,13 +101,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps",
[1, 2],
)
@pytest.mark.parametrize(
"use_flash_attention_3",
[False, "auto"],
)
def test_lora_ddp_packed(
self, temp_dir, gradient_accumulation_steps, use_flash_attention_3
):
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
@@ -144,7 +138,6 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
"use_flash_attention_3": use_flash_attention_3,
}
)

View File

@@ -4,6 +4,7 @@ E2E tests for packed training
import logging
import os
import unittest
from transformers.utils import is_torch_bf16_gpu_available
@@ -13,17 +14,18 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPackedLlama:
class TestPackedLlama(unittest.TestCase):
"""
Test case for Packed training of llama models
"""
@with_temp_dir
def test_loss_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(