Compare commits

..

10 Commits

Author SHA1 Message Date
Wing Lian
37a66e6866 multigpu longer timeout 2025-04-09 01:54:35 -04:00
Wing Lian
9f69597a5f upgrade transformers to 4.51.1 2025-04-09 00:20:50 -04:00
Wing Lian
0dac2ddeac Llama4 linearized (#2502)
* llama4 support for linearized experts

* clean up fsdp2 sharding to prevent hang

* add yaml config

* cleanup example [skip ci]
2025-04-07 20:47:00 -04:00
NanoCode012
a6c03217f5 feat: add llama4 CCE (#2498)
* feat: add llama4 CCE

* fix: update model support list doc

* feat: include llama4_text
2025-04-07 17:12:28 -04:00
Dan Saunders
59cd472504 SP cu_seqlens fix, refactor (#2495)
* working on masking fix

* refactor and fix multipack seqlens

* pre-commit fix

* adding smoke test

* using existing packed seqlens util

* log warning re: logged losses / gradient scaling per rank
2025-04-07 14:47:57 -04:00
NanoCode012
9b89591ead Feat: Add doc on loading datasets and support for Azure/OCI (#2482)
* fix: remove unused config

* feat: add doc on dataset loading

* feat: enable azure and oci remote file system

* feat: add adlfs and ocifs to requirements

* fix: add links between dataset formats and dataset loading

* fix: remove unused condition

* Revert "fix: remove unused condition"

This reverts commit 5fe13be73e.
2025-04-07 12:41:13 -04:00
NanoCode012
31498d0230 fix(doc): clarify roles mapping in chat_template (#2490) [skip ci] 2025-04-07 12:40:32 -04:00
NanoCode012
d25daebea9 fix: duplicate llama4 chattemplate enum (#2500)
* fix: duplicate llama4 chattemplate enum

* fix: duplicate chat_template string
2025-04-07 12:39:19 -04:00
NanoCode012
e0e5d9b1d6 feat: add llama4 multimodal (#2499)
* feat: add llama4 multimodal

* feat: add torchvision to base docker

* just use latest torchvision

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-04-07 10:49:29 -04:00
Wing Lian
8bbad21bfd llama4 support (#2493)
* llama4 support

* add xet support [skip ci]

* be flexible on transformers version and skip test on version

* don't use deepspeed for the fix_untrained_tokens test

* reordering to trigger torch 2.6.0 tests first

* slightly smaller train set

* use 4.51.0 for now

* remove stray print, add llama4 chat template to schema, bump peft to 0.15.1

* patches to make llama4 performant

* add preliminary fp8 support
2025-04-07 10:49:15 -04:00
33 changed files with 1313 additions and 279 deletions

View File

@@ -231,6 +231,7 @@ website:
- docs/reward_modelling.qmd
- docs/lr_groups.qmd
- docs/lora_optims.qmd
- docs/dataset_loading.qmd
- section: "Core Concepts"
contents:

View File

@@ -68,7 +68,7 @@ def run_cmd(cmd: str, run_folder: str):
@app.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
timeout=90 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,

View File

@@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

View File

@@ -109,7 +109,7 @@ datasets:
preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
trust_remote_code: # Optional[bool] Trust remote code for untrusted source
@@ -165,7 +165,9 @@ datasets:
content: value
# ...
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
# Optional[Dict[str, List]]. Roles mapping in the messages.
# The format is {target_role: [source_roles]}. All source roles will be mapped to the target role.
# The default is:
roles:
user: ["human", "user"]
assistant: ["gpt", "assistant"]

View File

@@ -13,6 +13,13 @@ As there are a lot of available options in Axolotl, this guide aims to provide a
Axolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has its own dataset format, which is described below.
::: {.callout-tip}
This guide will mainly use JSONL as an introduction. Please refer to the [dataset loading docs](../dataset_loading.qmd) to understand how to load datasets from other sources.
For `pretraining_dataset:` specifically, please refer to the [Pre-training section](#pre-training).
:::
## Pre-training
When aiming to train on large corpora of text, pre-training is your go-to choice. Due to the size of these datasets, downloading them in full before training begins would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to load only a batch at a time into memory.
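As a minimal illustration of streaming with the `datasets` library (the dataset name below is only an example, not one used by this repo):
```python
from datasets import load_dataset

# Stream a large corpus instead of downloading it in full; "allenai/c4" is just
# an illustrative dataset name.
ds = load_dataset("allenai/c4", "en", split="train", streaming=True)
for example in ds.take(2):  # IterableDataset.take yields the first N examples
    print(example["text"][:80])
```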

276
docs/dataset_loading.qmd Normal file
View File

@@ -0,0 +1,276 @@
---
title: Dataset Loading
description: Understanding how to load datasets from different sources
back-to-top-navigation: true
toc: true
toc-depth: 5
---
## Overview
Datasets can be loaded in a number of different ways depending on how they are saved (the file extension) and where they are stored.
## Loading Datasets
We use the `datasets` library, with a mix of `load_dataset` and `load_from_disk`, to load datasets.
You may recognize the similarly named options shared between `load_dataset` and the `datasets` section of the config file.
```yaml
datasets:
- path:
name:
data_files:
split:
revision:
trust_remote_code:
```
::: {.callout-tip}
Do not feel overwhelmed by the number of options here; most of them are optional. In practice, the most commonly used options are `path` and, occasionally, `data_files`.
:::
This matches the API of [`datasets.load_dataset`](https://github.com/huggingface/datasets/blob/0b5998ac62f08e358f8dcc17ec6e2f2a5e9450b6/src/datasets/load.py#L1838-L1858), so if you're familiar with that, you will feel right at home.
For HuggingFace's guide to loading different dataset types, see [here](https://huggingface.co/docs/datasets/loading).
For full details on the config, see [config.qmd](config.qmd).
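For reference, one `datasets:` entry maps roughly onto a call like the following (a sketch of the correspondence, not Axolotl's exact internal code):
```python
from datasets import load_dataset

# Rough equivalent of one `datasets:` entry; a sketch, not the exact internals.
dataset = load_dataset(
    "org/dataset-name",       # path
    name=None,                # dataset configuration name
    data_files=None,          # specific files to load
    split="train",            # split to load
    revision=None,            # commit hash, tag, or branch on the Hub
    trust_remote_code=False,  # allow datasets that ship loading scripts
)
```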
::: {.callout-note}
You can set multiple datasets in the config file by adding more than one entry under `datasets`.
```yaml
datasets:
- path: /path/to/your/dataset
- path: /path/to/your/other/dataset
```
:::
### Local dataset
#### Files
Usually, to load a JSON file, you would do something like this:
```python
from datasets import load_dataset
dataset = load_dataset("json", data_files="data.json")
```
This translates to the following config:
```yaml
datasets:
- path: json
data_files: /path/to/your/file.jsonl
```
However, to make things easier, we have added a few shortcuts for loading local dataset files.
You can simply point `path` to the file or directory and set `ds_type` to load the dataset. The example below shows this for a JSON file:
```yaml
datasets:
- path: /path/to/your/file.jsonl
ds_type: json
```
This works for CSV, JSON, Parquet, and Arrow files.
::: {.callout-tip}
If `path` points to a file and `ds_type` is not specified, we will automatically infer the dataset type from the file extension, so you could omit `ds_type` if you'd like.
:::
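A hypothetical sketch of what this inference amounts to (illustrative only, not the actual Axolotl code):
```python
from pathlib import Path

# Map common file extensions to dataset builder types; illustrative only.
EXTENSION_TO_DS_TYPE = {".json": "json", ".jsonl": "json", ".csv": "csv",
                        ".parquet": "parquet", ".arrow": "arrow"}

path = "/path/to/your/file.jsonl"
ds_type = EXTENSION_TO_DS_TYPE[Path(path).suffix]  # -> "json"
```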
#### Directory
If you're loading a directory, you can point the `path` to the directory.
Then, you have two options:
##### Loading entire directory
You do not need any additional configs.
We will attempt to load in the following order:
- datasets saved with `datasets.save_to_disk`
- the entire directory of files (such as parquet/arrow files)
```yaml
datasets:
- path: /path/to/your/directory
```
##### Loading specific files in directory
Provide `data_files` with a list of files to load.
```yaml
datasets:
# single file
- path: /path/to/your/directory
ds_type: csv
data_files: file1.csv
# multiple files
- path: /path/to/your/directory
ds_type: json
data_files:
- file1.jsonl
- file2.jsonl
# multiple files for parquet
- path: /path/to/your/directory
ds_type: parquet
data_files:
- file1.parquet
- file2.parquet
```
### HuggingFace Hub
The method you use to load the dataset depends on how it was created: whether a folder was uploaded directly or a HuggingFace Dataset was pushed.
::: {.callout-note}
If you're using a private dataset, you will need to enable the `hf_use_auth_token` flag at the root level of the config file.
:::
#### Folder uploaded
This means the dataset consists of one or more files uploaded directly to the Hub.
```yaml
datasets:
- path: org/dataset-name
data_files:
- file1.jsonl
- file2.jsonl
```
#### HuggingFace Dataset
This means the dataset was created as a HuggingFace Dataset and pushed to the Hub via `Dataset.push_to_hub`.
```yaml
datasets:
- path: org/dataset-name
```
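For context, such a dataset would typically have been created and pushed like this (a hypothetical example; the repo id is a placeholder):
```python
from datasets import Dataset

# Hypothetical example of creating a dataset and pushing it to the Hub.
ds = Dataset.from_dict({"text": ["hello", "world"]})
ds.push_to_hub("org/dataset-name")  # then reference it with `path: org/dataset-name`
```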
::: {.callout-note}
Depending on the dataset, other configs such as `name`, `split`, `revision`, or `trust_remote_code` may also be required.
:::
### Remote Filesystems
Via the `storage_options` argument of `load_dataset`, you can load datasets from remote filesystems such as S3, GCS, Azure, and OCI.
::: {.callout-warning}
This is currently experimental. Please let us know if you run into any issues!
:::
The only difference between providers is that the path must be prefixed with the respective protocol.
```yaml
datasets:
# Single file
- path: s3://bucket-name/path/to/your/file.jsonl
# Directory
- path: s3://bucket-name/path/to/your/directory
```
For directories, we load via `load_from_disk`.
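Under the hood, this is roughly equivalent to the following (a simplified sketch assuming S3 with non-anonymous credentials; not the exact loading code):
```python
import s3fs
from datasets import load_dataset, load_from_disk

# Simplified sketch of loading from S3; not the exact Axolotl loading code.
storage_options = {"anon": False}  # credentials come from env vars, ~/.aws, or IAM
fs = s3fs.S3FileSystem(**storage_options)

path = "s3://bucket-name/path/to/your/directory"
if fs.isdir(path):
    # directories saved with `Dataset.save_to_disk` load via `load_from_disk`
    dataset = load_from_disk(path, storage_options=storage_options)
else:
    # single files go through `load_dataset` with the matching builder
    dataset = load_dataset("json", data_files=path, storage_options=storage_options)
```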
#### S3
Prepend the path with `s3://`.
The credentials are pulled in the following order:
- `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_SESSION_TOKEN` environment variables
- from the `~/.aws/credentials` file
- for nodes on EC2, the IAM metadata provider
::: {.callout-note}
We assume you have credentials set up and are not using anonymous access. If you want to use anonymous access, let us know! We may need to expose a config option for this.
:::
Other supported environment variables can be found in the [boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables).
#### GCS
Prepend the path with `gs://` or `gcs://`.
The credentials are loaded in the following order:
- gcloud credentials
- for nodes on GCP, the Google metadata service
- anonymous access
#### Azure
##### Gen 1
Prepend the path with `adl://`.
Ensure you have the following environment variables set:
- `AZURE_STORAGE_TENANT_ID`
- `AZURE_STORAGE_CLIENT_ID`
- `AZURE_STORAGE_CLIENT_SECRET`
##### Gen 2
Prepend the path with `abfs://` or `az://`.
Ensure you have the following environment variables set:
- `AZURE_STORAGE_ACCOUNT_NAME`
- `AZURE_STORAGE_ACCOUNT_KEY`
Other supported environment variables can be found in the [adlfs docs](https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials).
#### OCI
Prepend the path with `oci://`.
Credentials are read in the following order:
- `OCIFS_IAM_TYPE`, `OCIFS_CONFIG_LOCATION`, and `OCIFS_CONFIG_PROFILE` environment variables
- when running on an OCI resource, the resource principal
Other environment variables:
- `OCI_REGION_METADATA`
Please see the [ocifs docs](https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables).
### HTTPS
The path should start with `https://`.
```yaml
datasets:
- path: https://path/to/your/dataset/file.jsonl
```
The file must be publicly accessible.
## Next steps
Now that you know how to load datasets, you can learn how to map your specific dataset format to your target output format in the [dataset formats docs](dataset-formats).

View File

@@ -9,6 +9,7 @@ format:
## Supported Models
- [Mllama](#sec-mllama)
- [Llama4](#sec-llama4)
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
@@ -63,6 +64,14 @@ base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
chat_template: llama3_2_vision
```
### Llama4 {#sec-llama4}
```yaml
base_model: meta-llama/Llama-4-Scout-17B-16E-Instruct
chat_template: llama4
```
### Pixtral {#sec-pixtral}
```yaml

View File

@@ -0,0 +1,93 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
strict: false
# torch_compile: true
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
- lm_head
- embed_tokens
chat_template: llama4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_steps: 100
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- auto_wrap
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -4,3 +4,5 @@ mypy
types-requests
quartodoc
jupyter
blobfile
tiktoken

View File

@@ -12,7 +12,7 @@ liger-kernel==0.5.6
packaging==23.2
peft==0.15.1
transformers==4.51.0
transformers==4.51.1
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.0
@@ -49,7 +49,8 @@ python-dotenv==1.0.1
# remote filesystems
s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
adlfs>=2024.5.0
ocifs==1.3.2
zstandard==0.22.0
fastcore

View File

@@ -235,6 +235,9 @@ class AxolotlTrainer(
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader

View File

@@ -1,34 +1,22 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
import logging
from typing import Any
import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
LOG = logging.getLogger(__name__)
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
This mixin provides functionality for handling sequence parallelism,
including creating appropriate samplers, managing data partitioning,
and updating ring flash attention parameters during training.
specifically for creating appropriate data samplers.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -99,84 +87,3 @@ class SequenceParallelMixin:
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn. This is accomplished by using the passed
`input_ids`.
Args:
inputs: Current batch of inputs.
"""
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""
Perform a prediction step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform prediction step for.
inputs: Dictionary mapping of inputs.
prediction_loss_only: Whether to return only the loss.
ignore_keys: Keys to ignore in the inputs.
Returns:
Tuple of (loss, logits, labels).
"""
# Set up sequence parallelism for this prediction step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal prediction step
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore

View File

@@ -32,6 +32,9 @@ cut_cross_entropy: true
## Supported Models
- llama
- llama4_text
- llama4
- mllama
- phi3
- gemma
- gemma2

View File

@@ -0,0 +1,414 @@
"""Llama4 CCE patch. Adapted from transformers 4.51.0."""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama4.modeling_llama4 import (
_CONFIG_FOR_DOC,
LLAMA4_INPUTS_DOCSTRING,
Llama4CausalLMOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
_PATCH_OPTS: PatchOptions | None = None
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
defer_logits_calculation: bool = False,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
defer_logits_calculation (`bool`, *optional*, defaults to `False`):
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Llama4ForCausalLM
>>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
elif _PATCH_OPTS is not None and defer_logits_calculation:
# defer logits calculation to the ConditionalGeneration forward
logits = hidden_states[:, slice_indices, :]
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@replace_return_docstrings(
output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, list[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor | None = None,
**lm_kwargs,
) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration
>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
>>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
vision_feature_layer = (
vision_feature_layer
if vision_feature_layer is not None
else self.config.vision_config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
original_inputs_embeds_shape = inputs_embeds.shape
vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
final_mask = special_image_mask.to(inputs_embeds.device)
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore
final_mask_1d = final_mask[..., 0].reshape(-1)
num_tokens_to_fill = final_mask_1d.sum()
if num_tokens_to_fill != projected_vision_flat.size(0):
raise ValueError(
f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
)
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
inputs_embeds = inputs_embeds.masked_scatter(
expanded_mask, projected_vision_flat
) # type: ignore
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) # type: ignore
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
defer_logits_calculation=True, # enable deferred logits calculation
**lm_kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
# TODO: check if need to handle attention_mask
loss = apply_lce(
hidden_states,
self.language_model.lm_head.weight,
labels,
_PATCH_OPTS,
**lm_kwargs,
)
else:
logits = hidden_states
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
logits.device
)
shift_logits = logits[..., :-1, :][
shift_attention_mask.to(logits.device) != 0
].contiguous()
shift_labels = labels[..., 1:][
shift_attention_mask.to(labels.device) != 0
].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1).to(shift_logits.device),
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Llama4CausalLMOutputWithPast(
loss=loss,
logits=logits, # type: ignore # TODO: check if need to create dummy logits
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
def patch_llama4_text(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.llama4 import modeling_llama4
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_llama4.Llama4ForCausalLM
), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
setattr(
modeling_llama4.Llama4ForCausalLM,
"forward",
cce_forward,
)
return None
def patch_llama4(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.llama4 import modeling_llama4
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_llama4.Llama4ForConditionalGeneration
), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
# patch the language model
maybe_model.language_model.forward = MethodType(
cce_forward, maybe_model.language_model
)
return maybe_model
setattr(
modeling_llama4.Llama4ForConditionalGeneration,
"forward",
cce_forward_multimodal,
)
# patch the causal language model
setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
return None

View File

@@ -20,6 +20,10 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma3,
patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
patch_llama4,
patch_llama4_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
patch_mistral,
patch_mistral3,
@@ -28,6 +32,8 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mlla
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"llama": patch_llama,
"llama4": patch_llama4,
"llama4_text": patch_llama4_text,
"mllama": patch_mllama,
"phi3": patch_phi3,
"gemma": patch_gemma,
@@ -60,7 +66,14 @@ def cce_patch(
raise ValueError(f"Unknown {impl=}")
if isinstance(model_type_or_model, transformers.PreTrainedModel):
model_type = model_type_or_model.config.model_type
if hasattr(model_type_or_model, "config"):
model_type = getattr(
getattr(model_type_or_model, "config", None), "model_type", None
)
else:
raise ValueError(
"model_type_or_model is a PreTrainedModel but does not have a config attribute"
)
elif isinstance(model_type_or_model, transformers.PretrainedConfig):
model_type = model_type_or_model.model_type
else:

View File

@@ -0,0 +1,63 @@
"""
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration
"""
import logging
import sys
import torch
LOG = logging.getLogger(__name__)
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
"""
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
parameters from rank 0 to all other ranks. This function modifies the model in-place.
Args:
accelerator (`Accelerator`): The accelerator instance
model (`torch.nn.Module`): The model to load the state dict into
full_sd (`dict`): The full state dict to load, can only be on rank 0
"""
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor
LOG.info("Broadcasting full state dict to all ranks...")
sharded_sd = model.state_dict()
param_names = sorted(sharded_sd.keys())
for param_name in param_names:
mesh = sharded_sd[param_name].device_mesh
if accelerator.is_main_process:
# Use the corresponding tensor from full_sd (assuming the key exists in full_sd)
full_param = full_sd[param_name].detach().cuda()
dist.broadcast(full_param, src=0, group=mesh.get_group())
sharded_tensor = distribute_tensor(
full_param, mesh, sharded_sd[param_name].placements
)
sharded_sd[param_name] = sharded_tensor
else:
# Prepare a tensor of matching shape and dtype
full_tensor = torch.empty(
sharded_sd[param_name].size(),
device="cuda",
dtype=sharded_sd[param_name].dtype,
)
dist.broadcast(full_tensor, src=0, group=mesh.get_group())
sharded_tensor = distribute_tensor(
full_tensor, mesh, sharded_sd[param_name].placements
)
sharded_sd[param_name] = sharded_tensor
model.load_state_dict(sharded_sd)
def patch_accelerate_fsdp_utils():
from accelerate.utils import fsdp_utils
fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
setattr(
sys.modules["accelerate.utils.fsdp_utils"],
"fsdp2_load_full_state_dict",
fsdp2_load_full_state_dict,
)

View File

@@ -6,10 +6,12 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2.
"""
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging()
LOG = get_logger(__name__)
@@ -98,3 +100,27 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)
def update_ring_attn_params(batch: dict[str, torch.Tensor]):
"""
Calculate the cumulative sequence lengths for the current forward pass and pass the
value to the substituted `ring_flash_attn`.
Args:
batch: A dictionary with a batch of data. May or may not contain `position_ids`
data; if not, we compute it.
"""
from ring_flash_attn import update_ring_flash_attn_params
input_ids = batch["input_ids"]
position_ids = batch.get("position_ids")
if position_ids is None:
seq_len = input_ids.shape[1]
position_ids = torch.arange(
0, seq_len, dtype=torch.long, device=input_ids.device
).unsqueeze(0)
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())

View File

@@ -4,7 +4,7 @@ import importlib
import inspect
import logging
import types
from typing import Type
from typing import Generator, Tuple, Type
import torch
from accelerate.logging import get_logger
@@ -200,6 +200,46 @@ def patch_self_attn_lora(cfg: DictDefault):
)
def find_self_attn_in_layer(
layer: nn.Module,
) -> Generator[Tuple[nn.Module], None, None]:
# general case of most models
if hasattr(layer, "self_attn"):
if all(
hasattr(layer.self_attn, proj)
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
):
yield layer.self_attn
def find_mlp_in_layer(
layer: nn.Module,
) -> Generator[Tuple[nn.Module, nn.Module, nn.Module, nn.Module], None, None]:
# general case of most models
if hasattr(layer, "mlp"):
if all(
hasattr(layer.mlp, proj) for proj in ["gate_proj", "up_proj", "down_proj"]
):
yield layer.mlp.gate_proj, layer.mlp.up_proj, layer.mlp.down_proj, layer.mlp
# llama4 linearized experts
if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "shared_expert"):
mlp = layer.feedforward.shared_expert
yield mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp
if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "experts"):
if all(
hasattr(layer.feedforward.experts, proj)
for proj in ["gate_projs", "up_projs", "down_projs"]
):
for gate_proj, up_proj, down_proj in zip(
layer.feedforward.experts.gate_projs,
layer.feedforward.experts.up_projs,
layer.feedforward.experts.down_projs,
):
yield gate_proj, up_proj, down_proj, FakeMLP(
gate_proj, up_proj, down_proj
)
def apply_lora_kernel_patches(
model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM:
@@ -286,74 +326,82 @@ def apply_lora_kernel_patches(
for layer in layers:
# Add QKV, O fallback implementations to start
# These will be overwritten later (if some conditions apply)
layer.self_attn.apply_qkv = types.MethodType(
original_apply_qkv, layer.self_attn
)
layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
for self_attn in find_self_attn_in_layer(layer):
self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
self_attn.apply_o = types.MethodType(original_apply_o, self_attn)
if cfg.lora_mlp_kernel:
# MLP patching
gate_proj = layer.mlp.gate_proj
up_proj = layer.mlp.up_proj
down_proj = layer.mlp.down_proj
if cfg.lora_qkv_kernel:
# Query, key, value patching
layer_modules = [
getattr(self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
if can_patch_qkv:
# Add optimized implementation
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_mlp:
apply_fn = APPLY_FN_MAPPING[activation]
layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
if can_patch_o:
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
if cfg.lora_qkv_kernel:
# Query, key, value patching
layer_modules = [
getattr(layer.self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_qkv:
# Add optimized implementation
layer.self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_o:
layer.self_attn.apply_o = types.MethodType(
apply_lora_o, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
if can_patch_mlp:
apply_fn = APPLY_FN_MAPPING[activation]
layer.mlp.forward = types.MethodType(apply_fn, mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
)
LOG.setLevel(original_level)
return model
class FakeMLP(nn.Module):
"""
placeholder MLP for triton patching
"""
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
def __init__(self, gate_proj, up_proj, down_proj):
super().__init__()
self.gate_proj = gate_proj
self.up_proj = up_proj
self.down_proj = down_proj

View File

@@ -0,0 +1,101 @@
"""
Modified Llama-4 text experts modeling for linearized experts for improved LoRA support
"""
import sys
import torch
from torch import nn
from transformers import Llama4Config
from transformers.activations import ACT2FN
class Llama4TextExperts(nn.Module):
"""
Modified Llama-4 text experts modeling for linearized experts
"""
def __init__(self, config: Llama4Config):
super().__init__()
self.num_experts = config.num_local_experts
self.intermediate_size = config.intermediate_size
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
# Replace fused gate_up_proj with separate Linear modules
self.gate_projs = nn.ModuleList(
[
nn.Linear(self.hidden_size, self.expert_dim, bias=False)
for _ in range(self.num_experts)
]
)
self.up_projs = nn.ModuleList(
[
nn.Linear(self.hidden_size, self.expert_dim, bias=False)
for _ in range(self.num_experts)
]
)
# Replace down_proj Parameter with Linear modules
self.down_projs = nn.ModuleList(
[
nn.Linear(self.expert_dim, self.hidden_size, bias=False)
for _ in range(self.num_experts)
]
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Forward method using separate Linear layers for each expert.
Args:
hidden_states (torch.Tensor): (num_experts * batch_size, hidden_size)
The input should be organized by expert
Returns:
torch.Tensor: (num_experts * batch_size, hidden_size)
"""
# Reshape to separate by expert
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
# batch_size_per_expert = hidden_states.size(1)
# Initialize output tensor
next_states = torch.zeros_like(hidden_states)
# Process each expert separately
for i in range(self.num_experts):
# Get input for this expert
expert_input = hidden_states[
i
] # Shape: (batch_size_per_expert, hidden_size)
# Apply gate and up projections
gate = self.gate_projs[i](
expert_input
) # Shape: (batch_size_per_expert, expert_dim)
up = self.up_projs[i](
expert_input
) # Shape: (batch_size_per_expert, expert_dim)
# Apply activation and down projection
next_states[i] = self.down_projs[i](up * self.act_fn(gate))
# Flatten back to original shape
return next_states.view(-1, self.hidden_size)
def patch_llama4_linearized_modeling():
"""
Patch Llama4TextExperts to use separate Linear layers for each expert.
"""
from transformers.models.llama4 import modeling_llama4
modeling_llama4.Llama4TextExperts = Llama4TextExperts
setattr(
sys.modules["transformers.models.llama4"],
"Llama4TextExperts",
Llama4TextExperts,
)

View File

@@ -96,7 +96,9 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def get_cu_seqlens_from_pos_ids(position_ids):
def get_cu_seqlens_from_pos_ids(
position_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)

View File

@@ -268,6 +268,7 @@ def get_processing_strategy(
)
if chat_template_type in [
"llama3_2_vision",
"llama4",
"llava",
"mistral_v7_tekken",
"pixtral",

File diff suppressed because one or more lines are too long

View File

@@ -3,7 +3,6 @@ Data collators for axolotl to pad labels and position_ids for packed sequences.
includes logic for handling sequence parallelism collation.
"""
import logging
from dataclasses import dataclass
from typing import Any, Optional, Union
@@ -13,46 +12,7 @@ import torch.distributed as dist
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
logger = logging.getLogger(__name__)
def adjust_position_ids_for_slice(
position_ids: torch.Tensor, start_idx: int
) -> torch.Tensor:
"""
Adjust position IDs for a sliced sequence to maintain proper relative positions.
This handles the case where position IDs might not be contiguous due to sample
packing.
"""
# Convert to tensor if not already
# Find the boundaries between samples (where position_ids reset)
adjusted_pos_ids = position_ids.clone()
# Process each sequence in the batch
for i in range(position_ids.shape[0]):
seq = position_ids[i]
# Find sample boundaries
boundaries = []
for j in range(1, len(seq)):
if seq[j] < seq[j - 1]:
boundaries.append(j)
# No need to adjust if there are no boundaries or this is a single sample
if not boundaries:
adjusted_pos_ids[i] = seq - start_idx
continue
# Adjust each segment separately
prev_boundary = 0
for boundary in boundaries:
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
prev_boundary = boundary
# Last segment
adjusted_pos_ids[i, prev_boundary:] -= start_idx
return adjusted_pos_ids
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
@dataclass
@@ -196,23 +156,20 @@ class DataCollatorForSeq2Seq:
Returns:
Sliced batch dictionary.
"""
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
# Get local (start, end) for sequence parallelism slicing
total_seq_len = batch["input_ids"].shape[1]
slice_size = total_seq_len // self.local_world_size
start = self.local_rank * slice_size
end = start + slice_size
# Update params for ring attention calculation
update_ring_attn_params(batch=batch)
# Slice batch for sequence parallel processing
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
for key in keys_to_slice:
if key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // self.local_world_size
start_idx = self.local_rank * slice_size
end_idx = (
start_idx + slice_size
if self.local_rank < self.local_world_size - 1
else seq_len
)
batch[key] = batch[key][:, start_idx:end_idx]
# Special handling for position_ids
if key == "position_ids" and self.local_rank > 0:
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
batch[key] = batch[key][:, start:end]
return batch

View File

@@ -96,20 +96,17 @@ def load_dataset_w_config(
pass
ds_from_cloud = False
storage_options = {}
storage_options: dict = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc
raise ImportError("s3:// paths require s3fs to be installed") from exc
# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
# Reads env, credentials from ~/.aws/credentials, or IAM metadata provider
# https://s3fs.readthedocs.io/en/latest/index.html?highlight=storage_options#credentials
storage_options = {"anon": False}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
"gcs://"
@@ -125,28 +122,44 @@ def load_dataset_w_config(
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc
elif (
config_dataset.path.startswith("adl://")
or config_dataset.path.startswith("abfs://")
or config_dataset.path.startswith("az://")
):
try:
import adlfs
except ImportError as exc:
raise ImportError(
"adl:// or abfs:// paths require adlfs to be installed"
) from exc
# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }
# # Ensure you have the following environment variables set:
# # Gen 1
# storage_options = {
# "tenant_id": AZURE_STORAGE_TENANT_ID,
# "client_id": AZURE_STORAGE_CLIENT_ID,
# "client_secret": AZURE_STORAGE_CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": AZURE_STORAGE_ACCOUNT_NAME,
# "account_key": AZURE_STORAGE_ACCOUNT_KEY,
# }
# Reads env
# https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials
storage_options = {"anon": False}
remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
elif config_dataset.path.startswith("oci://"):
try:
import ocifs
except ImportError as exc:
raise ImportError("oci:// paths require ocifs to be installed") from exc
# https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables
remote_file_system = ocifs.OCIFileSystem(**storage_options)
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(config_dataset.path):
ds_from_cloud = True

View File

@@ -36,6 +36,7 @@ from transformers import (
BitsAndBytesConfig,
Gemma3ForConditionalGeneration,
GPTQConfig,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
MllamaForConditionalGeneration,
@@ -76,6 +77,7 @@ LOG = logging.getLogger(__name__)
MULTIMODAL_AUTO_MODEL_MAPPING = {
"mllama": MllamaForConditionalGeneration,
"llama4": Llama4ForConditionalGeneration,
"llava": LlavaForConditionalGeneration,
"qwen2_vl": Qwen2VLForConditionalGeneration,
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
@@ -542,8 +544,20 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils()
# patch gemma3 conditional generation forward before loading plugins
# as it could be overridden by plugins
if self.cfg.model_config_type == "llama4":
if self.cfg.llama4_linearized_experts:
from axolotl.monkeypatch.models.llama4.modeling import (
patch_llama4_linearized_modeling,
)
patch_llama4_linearized_modeling()
if self.cfg.model_config_type == "gemma3":
from axolotl.monkeypatch.gemma3 import (
patch_gemma3conditionalgeneration_forward,

View File

@@ -245,6 +245,8 @@ class AxolotlInputConfig(
lora_qkv_kernel: bool | None = None
lora_o_kernel: bool | None = None
llama4_linearized_experts: bool | None = None
deepspeed: str | dict[str, Any] | None = None
fsdp: list[str] | None = None
fsdp_config: dict[str, Any] | None = None
@@ -1156,6 +1158,12 @@ class AxolotlInputConfig(
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
if not info.data["micro_batch_size"] == 1:
raise ValueError(
"micro_batch_size must be set to 1 "
"due to a `ring-flash-attn` requirement"
)
try:
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
@@ -1165,6 +1173,18 @@ class AxolotlInputConfig(
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
# TODO: monkeypatch / callback to average losses correctly across SP ranks
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
# according to the proportion of non-padding tokens per rank.
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={value}. Please note that logged losses may "
"differ slightly to the non-SP losses due to transformers Trainer "
"implementation details. Please see "
"https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
return value
@model_validator(mode="before")

View File

@@ -39,7 +39,6 @@ class SFTDataset(BaseModel):
input_format: str | None = None
name: str | None = None
ds_type: str | None = None
train_on_split: str | None = None
field: str | None = None
field_human: str | None = None
field_model: str | None = None

View File

@@ -26,8 +26,8 @@ class ChatTemplate(str, Enum):
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama4 = "llama4" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
llama4 = "llama4" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name

View File

@@ -0,0 +1,87 @@
"""E2E tests for sequence parallelism"""
import os
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
os.environ["WANDB_DISABLED"] = "true"
class TestSequenceParallelism:
"""Test case for training with sequence parallelism enabled"""
def test_sequence_parallel_training(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": True,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 8,
"micro_batch_size": 1,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
"warmup_steps": 1,
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
"sequence_parallel_degree": 2,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
)

View File

@@ -12,7 +12,6 @@ from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
set_ring_attn_group,
)
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
from axolotl.utils.dict import DictDefault
@@ -48,33 +47,6 @@ def fixture_cfg():
return cfg
class TestSequenceParallelHelpers:
"""Test helper functions used in sequence parallelism."""
def test_adjust_position_ids_for_slice(self, partial_state):
"""Test position_ids adjustment for sequence slices."""
# Create sample position_ids with multiple sequences
position_ids = torch.tensor(
[
# First sequence with 2 samples
[0, 1, 2, 3, 4, 0, 1, 2, 3],
# Second sequence with 3 samples
[0, 1, 2, 0, 1, 2, 3, 0, 1],
]
)
# Adjust as if this was the second slice (start_idx = 4)
adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
# For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
# For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
assert torch.all(adjusted[0] == expected_first_seq)
assert torch.all(adjusted[1] == expected_second_seq)
class TestRingAttention:
"""Tests for the ring attention functionality."""