Compare commits
10 Commits (sp-fix-mas...llama4)
- 1aec93cf9e
- 37630fc6ef
- 4b28b2a0b4
- b38f70e068
- cf4c84e21d
- 98d98ea1dd
- 0cf42ab8a3
- 3d0ab75a0c
- d375be90ff
- 98827e8f3b
@@ -231,7 +231,6 @@ website:
        - docs/reward_modelling.qmd
        - docs/lr_groups.qmd
        - docs/lora_optims.qmd
        - docs/dataset_loading.qmd

      - section: "Core Concepts"
        contents:
@@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace

RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
    python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
    python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
@@ -109,7 +109,7 @@ datasets:
    preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)

    name: # Optional[str] name of dataset configuration to load
    split: train # Optional[str] name of dataset split to load from
    train_on_split: train # Optional[str] name of dataset split to load from
    revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
    trust_remote_code: # Optional[bool] Trust remote code for untrusted source
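For example, to preprocess a large dataset in four sequential chunks (the path and type here are placeholders):

```yaml
datasets:
  - path: org/large-dataset
    type: alpaca
    preprocess_shards: 4
```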
@@ -165,9 +165,7 @@ datasets:
        content: value
    # ...

    # Optional[Dict[str, List]]. Roles mapping in the messages.
    # The format is {target_role: [source_roles]}. All source roles will be mapped to the target role.
    # The default is:
    # Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
    roles:
      user: ["human", "user"]
      assistant: ["gpt", "assistant"]
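Following the `{target_role: [source_roles]}` format above, you could for example fold a hypothetical `bot` source role into `assistant`:

```yaml
roles:
  user: ["human", "user"]
  assistant: ["gpt", "assistant", "bot"]
```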
@@ -13,13 +13,6 @@ As there are a lot of available options in Axolotl, this guide aims to provide a

Axolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has its own dataset format, described below.

::: {.callout-tip}

This guide will mainly use JSONL as an introduction. Please refer to the [dataset loading docs](../dataset_loading.qmd) to understand how to load datasets from other sources.

For `pretraining_dataset:` specifically, please refer to the [Pre-training section](#pre-training).

:::

## Pre-training

When aiming to train on large corpora of text, pre-training is your go-to choice. Due to the size of these datasets, downloading the entire dataset before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to load only a batch at a time into memory, as sketched below.
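A minimal sketch of the corresponding config, assuming `pretraining_dataset` takes a dataset entry with a `path` much like the `datasets` section (the path is a placeholder; see the config reference for the exact schema):

```yaml
pretraining_dataset:
  - path: org/large-text-corpus
```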
@@ -1,276 +0,0 @@
---
title: Dataset Loading
description: Understanding how to load datasets from different sources
back-to-top-navigation: true
toc: true
toc-depth: 5
---

## Overview

Datasets can be loaded in a number of different ways depending on how they are saved (the extension of the file) and where they are stored.

## Loading Datasets

We use the `datasets` library, with a mix of `load_dataset` and `load_from_disk`, to load datasets.

You may recognize the similarly named configs between `load_dataset` and the `datasets` section of the config file.

```yaml
datasets:
  - path:
    name:
    data_files:
    split:
    revision:
    trust_remote_code:
```

::: {.callout-tip}

Do not feel overwhelmed by the number of options here. Most of them are optional. In fact, the most common configs to use are `path` and sometimes `data_files`.

:::

This matches the API of [`datasets.load_dataset`](https://github.com/huggingface/datasets/blob/0b5998ac62f08e358f8dcc17ec6e2f2a5e9450b6/src/datasets/load.py#L1838-L1858), so if you're familiar with that, you will feel right at home.

For HuggingFace's guide to loading different dataset types, see [here](https://huggingface.co/docs/datasets/loading).

For full details on the config, see [config.qmd](config.qmd).

::: {.callout-note}

You can set multiple datasets in the config file by adding more than one entry under `datasets`.

```yaml
datasets:
  - path: /path/to/your/dataset
  - path: /path/to/your/other/dataset
```

:::

### Local dataset

#### Files

Usually, to load a JSON file, you would do something like this:

```python
from datasets import load_dataset

dataset = load_dataset("json", data_files="data.json")
```

This translates to the following config:

```yaml
datasets:
  - path: json
    data_files: /path/to/your/file.jsonl
```

However, to make things easier, we have added a few shortcuts for loading local dataset files.

You can simply point `path` to the file or directory, along with the `ds_type` to load the dataset. The example below shows a JSON file:

```yaml
datasets:
  - path: /path/to/your/file.jsonl
    ds_type: json
```

This works for CSV, JSON, Parquet, and Arrow files.

::: {.callout-tip}

If `path` points to a file and `ds_type` is not specified, we will automatically infer the dataset type from the file extension, so you can omit `ds_type` if you'd like.

:::

#### Directory

If you're loading a directory, point `path` to the directory.

Then, you have two options:

##### Loading entire directory

You do not need any additional configs.

We will attempt to load in the following order:

- datasets saved with `datasets.save_to_disk`
- loading the entire directory of files (such as with Parquet/Arrow files)

```yaml
datasets:
  - path: /path/to/your/directory
```

##### Loading specific files in directory

Provide `data_files` with a list of files to load.

```yaml
datasets:
  # single file
  - path: /path/to/your/directory
    ds_type: csv
    data_files: file1.csv

  # multiple files
  - path: /path/to/your/directory
    ds_type: json
    data_files:
      - file1.jsonl
      - file2.jsonl

  # multiple files for parquet
  - path: /path/to/your/directory
    ds_type: parquet
    data_files:
      - file1.parquet
      - file2.parquet
```

### HuggingFace Hub

The method you use to load the dataset depends on how the dataset was created: whether files were uploaded to the Hub directly or a HuggingFace Dataset was pushed.

::: {.callout-note}

If you're using a private dataset, you will need to enable the `hf_use_auth_token` flag at the root level of the config file, as shown below.

:::
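For example (`org/private-dataset-name` is a placeholder):

```yaml
hf_use_auth_token: true

datasets:
  - path: org/private-dataset-name
```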
#### Folder uploaded

This means the dataset is a single file or several files uploaded to the Hub.

```yaml
datasets:
  - path: org/dataset-name
    data_files:
      - file1.jsonl
      - file2.jsonl
```

#### HuggingFace Dataset

This means the dataset was created as a HuggingFace Dataset and pushed to the Hub via `datasets.push_to_hub`.

```yaml
datasets:
  - path: org/dataset-name
```

::: {.callout-note}

Some other configs may be required, like `name`, `split`, `revision`, `trust_remote_code`, etc., depending on the dataset.

:::

### Remote Filesystems

Via the `storage_options` config under `load_dataset`, you can load datasets from remote filesystems like S3, GCS, Azure, and OCI.

::: {.callout-warning}

This is currently experimental. Please let us know if you run into any issues!

:::

The only difference between the providers is that you need to prepend the path with the respective protocol.

```yaml
datasets:
  # Single file
  - path: s3://bucket-name/path/to/your/file.jsonl

  # Directory
  - path: s3://bucket-name/path/to/your/directory
```

For directories, we load via `load_from_disk`.
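Under the hood, an fsspec-compatible filesystem is constructed from the path's protocol prefix. A simplified Python sketch of that dispatch, mirroring the `load_dataset_w_config` changes further down this diff (each provider's credential resolution is described in the sections below):

```python
# Simplified sketch of protocol dispatch for remote dataset paths.
# Mirrors the load_dataset_w_config changes shown later in this diff.
def filesystem_for(path: str):
    if path.startswith("s3://"):
        import aiobotocore.session
        import s3fs

        # Uses the default AWS profile (~/.aws/credentials); env vars and the
        # EC2 IAM metadata provider are also consulted by botocore.
        session = aiobotocore.session.AioSession(profile="default")
        return s3fs.S3FileSystem(session=session)
    if path.startswith(("gs://", "gcs://")):
        import gcsfs

        # token=None lets gcsfs try gcloud credentials, the GCP metadata
        # service, and finally anonymous access.
        return gcsfs.GCSFileSystem(token=None)
    if path.startswith(("adl://", "abfs://", "az://")):
        import adlfs

        # adlfs reads the AZURE_STORAGE_* environment variables.
        return adlfs.AzureBlobFileSystem(anon=False)
    if path.startswith("oci://"):
        import ocifs

        # ocifs reads OCIFS_* environment variables or the resource principal.
        return ocifs.OCIFileSystem()
    raise ValueError(f"Unsupported protocol in path: {path}")
```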
#### S3

Prepend the path with `s3://`.

Credentials are pulled in the following order:

- the `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_SESSION_TOKEN` environment variables
- the `~/.aws/credentials` file
- for nodes on EC2, the IAM metadata provider

::: {.callout-note}

We assume you have credentials set up and are not using anonymous access. If you want to use anonymous access, let us know! We may have to open a config option for this.

:::

Other environment variables that can be set can be found in the [boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables).
#### GCS

Prepend the path with `gs://` or `gcs://`.

Credentials are loaded in the following order:

- gcloud credentials
- for nodes on GCP, the google metadata service
- anonymous access

#### Azure

##### Gen 1

Prepend the path with `adl://`.

Ensure you have the following environment variables set:

- `AZURE_STORAGE_TENANT_ID`
- `AZURE_STORAGE_CLIENT_ID`
- `AZURE_STORAGE_CLIENT_SECRET`

##### Gen 2

Prepend the path with `abfs://` or `az://`.

Ensure you have the following environment variables set:

- `AZURE_STORAGE_ACCOUNT_NAME`
- `AZURE_STORAGE_ACCOUNT_KEY`

Other environment variables that can be set can be found in the [adlfs docs](https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials).

#### OCI

Prepend the path with `oci://`.

Credentials are read in the following order:

- the `OCIFS_IAM_TYPE`, `OCIFS_CONFIG_LOCATION`, and `OCIFS_CONFIG_PROFILE` environment variables
- when on an OCI resource, the resource principal

Other environment variables:

- `OCI_REGION_METADATA`

Please see the [ocifs docs](https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables).

### HTTPS

The path should start with `https://`.

```yaml
datasets:
  - path: https://path/to/your/dataset/file.jsonl
```

The file must be publicly accessible.

## Next steps

Now that you know how to load datasets, see the [dataset formats docs](dataset-formats) to learn how to map your specific dataset format to your target output format.
@@ -9,7 +9,6 @@ format:
## Supported Models

- [Mllama](#sec-mllama)
- [Llama4](#sec-llama4)
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
@@ -64,14 +63,6 @@ base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
chat_template: llama3_2_vision
```

### Llama4 {#sec-llama4}

```yaml
base_model: meta-llama/Llama-4-Scout-17B-16E-Instruct

chat_template: llama4
```

### Pixtral {#sec-pixtral}

```yaml
@@ -49,8 +49,7 @@ python-dotenv==1.0.1
# remote filesystems
s3fs>=2024.5.0
gcsfs>=2024.5.0
adlfs>=2024.5.0
ocifs==1.3.2
# adlfs

zstandard==0.22.0
fastcore
@@ -235,9 +235,6 @@ class AxolotlTrainer(
            self.accelerator.even_batches = False

        # Return unprepared dataloader if using sequence parallelism
        # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
        # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
        # slice each batch along the sequence dimension).
        if self.args.sequence_parallel_degree > 1:
            return dataloader
@@ -1,22 +1,34 @@
"""Module for Axolotl trainer sequence parallelism mixin"""

import logging
from typing import Any

import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler

from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group

LOG = logging.getLogger(__name__)

try:
    from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
    # We pass silently here, but raise an ImportError in our Axolotl config validation
    # if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
    pass


class SequenceParallelMixin:
    """
    Mixin class for sequence parallelism support in trainers.

    This mixin provides functionality for handling sequence parallelism,
    specifically for creating appropriate data samplers.
    including creating appropriate samplers, managing data partitioning,
    and updating ring flash attention parameters during training.
    """

    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
@@ -87,3 +99,84 @@ class SequenceParallelMixin:
        return self._create_sequence_parallel_sampler(
            eval_dataset, shuffle=False, is_eval=True
        )

    def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
        """
        Calculate the cu_seqlens for the current forward pass and pass the value to
        the substituted ring_flash_attn. This is accomplished by using the passed
        `input_ids`.

        Args:
            inputs: Current batch of inputs.
        """
        # At this point, inputs should already be partitioned by the sequence
        # parallel data collator
        batch_size = inputs["input_ids"].shape[0]
        seq_len = inputs["input_ids"].shape[1]
        packed_seq_lens = [seq_len] * batch_size

        # Calculate the full sequence length across all GPUs in this SP group
        total_seq_len = seq_len * self.args.sequence_parallel_degree

        cu_seqlens = torch.cumsum(
            torch.tensor(
                packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
            ),
            dim=-1,
            dtype=torch.int32,
        )
        cu_seqlens = F.pad(
            F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
        )

        update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)

    def training_step(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        num_items_in_batch: int | None = None,
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs. Overrides the
        `transformers.trainer.Trainer` method to handle sequence parallelism if
        enabled.

        Args:
            model: Model to perform training step for.
            inputs: Dictionary mapping of inputs.
        """
        # Set up sequence parallelism for this step if enabled
        if self.args.sequence_parallel_degree > 1:
            self._update_ring_flash_attn_params(inputs)

        # Proceed with normal training step
        return super().training_step(model, inputs, num_items_in_batch)  # type: ignore

    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        prediction_loss_only: bool,
        ignore_keys: list[str] | None = None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
        """
        Perform a prediction step on a batch of inputs. Overrides the
        `transformers.trainer.Trainer` method to handle sequence parallelism if
        enabled.

        Args:
            model: Model to perform prediction step for.
            inputs: Dictionary mapping of inputs.
            prediction_loss_only: Whether to return only the loss.
            ignore_keys: Keys to ignore in the inputs.

        Returns:
            Tuple of (loss, logits, labels).
        """
        # Set up sequence parallelism for this prediction step if enabled
        if self.args.sequence_parallel_degree > 1:
            self._update_ring_flash_attn_params(inputs)

        # Proceed with normal prediction step
        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)  # type: ignore
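The nested `F.pad` calls above prepend a leading zero and append the full (unsliced) sequence length. A standalone sketch of what this produces, using illustrative sizes (batch of 2, local sequence length 4, `sequence_parallel_degree` of 2) on CPU:

```python
import torch
import torch.nn.functional as F

# Standalone sketch of the cu_seqlens construction above (CPU, for illustration).
batch_size, seq_len, sp_degree = 2, 4, 2
packed_seq_lens = [seq_len] * batch_size
total_seq_len = seq_len * sp_degree

cu_seqlens = torch.cumsum(
    torch.tensor(packed_seq_lens, dtype=torch.int32), dim=-1, dtype=torch.int32
)
# Prepend 0 and append the full (unsliced) sequence length
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
print(cu_seqlens)  # tensor([0, 4, 8, 8], dtype=torch.int32)
```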
@@ -6,12 +6,10 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch
their sequence parallel version of Flash Attention 2.
"""

import torch
import torch.distributed as dist
from accelerate.logging import get_logger

from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

configure_logging()
LOG = get_logger(__name__)
@@ -100,27 +98,3 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
    substitute_hf_flash_attn(
        process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
    )


def update_ring_attn_params(batch: dict[str, torch.Tensor]):
    """
    Calculate the cumulative sequence lengths for the current forward pass and pass the
    value to the substituted `ring_flash_attn`.

    Args:
        batch: A dictionary with a batch of data. May or may not contain `position_ids`
            data; if not, we compute it.
    """
    from ring_flash_attn import update_ring_flash_attn_params

    input_ids = batch["input_ids"]
    position_ids = batch.get("position_ids")
    if position_ids is None:
        seq_len = input_ids.shape[1]
        position_ids = torch.arange(
            0, seq_len, dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)

    cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
    cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
    update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
@@ -96,9 +96,7 @@ def get_cu_seqlens(attn_mask):
    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)


def get_cu_seqlens_from_pos_ids(
    position_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
def get_cu_seqlens_from_pos_ids(position_ids):
    """generate a cumulative sequence length mask for flash attention using pos ids"""
    if len(position_ids.shape) == 1:
        position_ids = position_ids.unsqueeze(0)
@@ -268,7 +268,6 @@ def get_processing_strategy(
    )
    if chat_template_type in [
        "llama3_2_vision",
        "llama4",
        "llava",
        "mistral_v7_tekken",
        "pixtral",
File diff suppressed because one or more lines are too long
@@ -3,6 +3,7 @@ Data collators for axolotl to pad labels and position_ids for packed sequences.
includes logic for handling sequence parallelism collation.
"""

import logging
from dataclasses import dataclass
from typing import Any, Optional, Union
@@ -12,7 +13,46 @@ import torch.distributed as dist
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params

logger = logging.getLogger(__name__)


def adjust_position_ids_for_slice(
    position_ids: torch.Tensor, start_idx: int
) -> torch.Tensor:
    """
    Adjust position IDs for a sliced sequence to maintain proper relative positions.
    This handles the case where position IDs might not be contiguous due to sample
    packing.
    """
    # Find the boundaries between samples (where position_ids reset)
    adjusted_pos_ids = position_ids.clone()

    # Process each sequence in the batch
    for i in range(position_ids.shape[0]):
        seq = position_ids[i]

        # Find sample boundaries
        boundaries = []
        for j in range(1, len(seq)):
            if seq[j] < seq[j - 1]:
                boundaries.append(j)

        # No need to adjust if there are no boundaries or this is a single sample
        if not boundaries:
            adjusted_pos_ids[i] = seq - start_idx
            continue

        # Adjust each segment separately
        prev_boundary = 0
        for boundary in boundaries:
            adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
            prev_boundary = boundary

        # Last segment
        adjusted_pos_ids[i, prev_boundary:] -= start_idx

    return adjusted_pos_ids


@dataclass
@@ -156,20 +196,23 @@ class DataCollatorForSeq2Seq:
        Returns:
            Sliced batch dictionary.
        """
        # Get local (start, end) for sequence parallelism slicing
        total_seq_len = batch["input_ids"].shape[1]
        slice_size = total_seq_len // self.local_world_size
        start = self.local_rank * slice_size
        end = start + slice_size

        # Update params for ring attention calculation
        update_ring_attn_params(batch=batch)

        # Slice batch for sequence parallel processing
        keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]

        for key in keys_to_slice:
            if key in batch:
                batch[key] = batch[key][:, start:end]
                seq_len = batch[key].shape[1]
                slice_size = seq_len // self.local_world_size
                start_idx = self.local_rank * slice_size
                end_idx = (
                    start_idx + slice_size
                    if self.local_rank < self.local_world_size - 1
                    else seq_len
                )
                batch[key] = batch[key][:, start_idx:end_idx]

                # Special handling for position_ids
                if key == "position_ids" and self.local_rank > 0:
                    batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)

        return batch
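The per-rank slice arithmetic introduced above hands any remainder tokens to the last rank when the sequence length is not divisible by the group size. A standalone check with hypothetical sizes:

```python
# Standalone check of the per-rank slice arithmetic above (hypothetical sizes).
seq_len, local_world_size = 10, 4

for local_rank in range(local_world_size):
    slice_size = seq_len // local_world_size
    start_idx = local_rank * slice_size
    # The last rank absorbs the remainder when seq_len is not evenly divisible
    end_idx = start_idx + slice_size if local_rank < local_world_size - 1 else seq_len
    print(local_rank, (start_idx, end_idx))
# 0 (0, 2), 1 (2, 4), 2 (4, 6), 3 (6, 10)
```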
@@ -96,17 +96,20 @@ def load_dataset_w_config(
        pass

    ds_from_cloud = False
    storage_options: dict = {}
    storage_options = {}
    remote_file_system = None
    if config_dataset.path.startswith("s3://"):
        try:
            import aiobotocore.session  # type: ignore
            import s3fs  # type: ignore
        except ImportError as exc:
            raise ImportError("s3:// paths require s3fs to be installed") from exc
            raise ImportError(
                "s3:// paths require aiobotocore and s3fs to be installed"
            ) from exc

        # Reads env, credentials from ~/.aws/credentials, or IAM metadata provider
        # https://s3fs.readthedocs.io/en/latest/index.html?highlight=storage_options#credentials
        storage_options = {"anon": False}
        # Takes credentials from ~/.aws/credentials for default profile
        s3_session = aiobotocore.session.AioSession(profile="default")
        storage_options = {"session": s3_session}
        remote_file_system = s3fs.S3FileSystem(**storage_options)
    elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
        "gcs://"
@@ -122,44 +125,28 @@ def load_dataset_w_config(
        # https://gcsfs.readthedocs.io/en/latest/#credentials
        storage_options = {"token": None}
        remote_file_system = gcsfs.GCSFileSystem(**storage_options)
    elif (
        config_dataset.path.startswith("adl://")
        or config_dataset.path.startswith("abfs://")
        or config_dataset.path.startswith("az://")
    ):
        try:
            import adlfs
        except ImportError as exc:
            raise ImportError(
                "adl:// or abfs:// paths require adlfs to be installed"
            ) from exc
    # TODO: Figure out how to get auth creds passed
    # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
    #     try:
    #         import adlfs
    #     except ImportError as exc:
    #         raise ImportError(
    #             "adl:// or abfs:// paths require adlfs to be installed"
    #         ) from exc

    #     # Ensure you have the following environment variables set:
    #     # Gen 1
    #     storage_options = {
    #         "tenant_id": AZURE_STORAGE_TENANT_ID,
    #         "client_id": AZURE_STORAGE_CLIENT_ID,
    #         "client_secret": AZURE_STORAGE_CLIENT_SECRET,
    #     }
    #     # Gen 2
    #     storage_options = {
    #         "account_name": AZURE_STORAGE_ACCOUNT_NAME,
    #         "account_key": AZURE_STORAGE_ACCOUNT_KEY,
    #     }

        # Reads env
        # https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials
        storage_options = {"anon": False}
        remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
    elif config_dataset.path.startswith("oci://"):
        try:
            import ocifs
        except ImportError as exc:
            raise ImportError("oci:// paths require ocifs to be installed") from exc

        # https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables
        remote_file_system = ocifs.OCIFileSystem(**storage_options)
    #     # Gen 1
    #     storage_options = {
    #         "tenant_id": TENANT_ID,
    #         "client_id": CLIENT_ID,
    #         "client_secret": CLIENT_SECRET,
    #     }
    #     # Gen 2
    #     storage_options = {
    #         "account_name": ACCOUNT_NAME,
    #         "account_key": ACCOUNT_KEY,
    #     }

    #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
    try:
        if remote_file_system and remote_file_system.exists(config_dataset.path):
            ds_from_cloud = True
@@ -36,7 +36,6 @@ from transformers import (
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
    GPTQConfig,
    Llama4ForConditionalGeneration,
    LlavaForConditionalGeneration,
    Mistral3ForConditionalGeneration,
    MllamaForConditionalGeneration,
@@ -77,7 +76,6 @@ LOG = logging.getLogger(__name__)

MULTIMODAL_AUTO_MODEL_MAPPING = {
    "mllama": MllamaForConditionalGeneration,
    "llama4": Llama4ForConditionalGeneration,
    "llava": LlavaForConditionalGeneration,
    "qwen2_vl": Qwen2VLForConditionalGeneration,
    "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
@@ -1156,12 +1156,6 @@ class AxolotlInputConfig(
                    "flash_attention: true must be set with sequence_parallel_degree > 1"
                )

            if not info.data["micro_batch_size"] == 1:
                raise ValueError(
                    "micro_batch_size must be set to 1 "
                    "due to a `ring-flash-attn` requirement"
                )

            try:
                import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
            except ImportError as exception:
@@ -1171,18 +1165,6 @@ class AxolotlInputConfig(
                    "or `pip install ring-flash-attn>=0.1.4`."
                ) from exception

            # TODO: monkeypatch / callback to average losses correctly across SP ranks
            # / fix gradient scaling across SP ranks. Losses, grads should be scaled
            # according to the proportion of non-padding tokens per rank.
            LOG.warning(
                "Sequence parallelism (SP) is enabled with "
                f"sequence_parallel_degree={value}. Please note that logged losses may "
                "differ slightly to the non-SP losses due to transformers Trainer "
                "implementation details. Please see "
                "https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
                "for more details."
            )

        return value

    @model_validator(mode="before")
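Taken together, the validations above imply a minimal config shape for enabling sequence parallelism (a sketch; the degree value is illustrative):

```yaml
sequence_parallel_degree: 2   # > 1 enables SP; requires ring-flash-attn installed
flash_attention: true         # required when sequence_parallel_degree > 1
micro_batch_size: 1           # required by ring-flash-attn
```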
@@ -39,6 +39,7 @@ class SFTDataset(BaseModel):
    input_format: str | None = None
    name: str | None = None
    ds_type: str | None = None
    train_on_split: str | None = None
    field: str | None = None
    field_human: str | None = None
    field_model: str | None = None

@@ -26,8 +26,8 @@ class ChatTemplate(str, Enum):
    gemma = "gemma"  # pylint: disable=invalid-name
    cohere = "cohere"  # pylint: disable=invalid-name
    llama3 = "llama3"  # pylint: disable=invalid-name
    llama3_2_vision = "llama3_2_vision"  # pylint: disable=invalid-name
    llama4 = "llama4"  # pylint: disable=invalid-name
    llama3_2_vision = "llama3_2_vision"  # pylint: disable=invalid-name
    phi_3 = "phi_3"  # pylint: disable=invalid-name
    phi_35 = "phi_35"  # pylint: disable=invalid-name
    deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name
@@ -1,87 +0,0 @@
"""E2E tests for sequence parallelism"""

import os
from pathlib import Path

import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault

from ..utils import check_tensorboard

os.environ["WANDB_DISABLED"] = "true"


class TestSequenceParallelism:
    """Test case for training with sequence parallelism enabled"""

    def test_sequence_parallel_training(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "load_in_8bit": False,
                "load_in_4bit": True,
                "strict": False,
                "sequence_len": 2048,
                "adapter": "qlora",
                "sample_packing": True,
                "eval_sample_packing": True,
                "pad_to_sequence_len": True,
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "lora_modules_to_save": ["embed_tokens", "lm_head"],
                "special_tokens": {"pad_token": "<|endoftext|>"},
                "datasets": [
                    {
                        "path": "tatsu-lab/alpaca",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "max_steps": 8,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 2,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_8bit",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "loss_watchdog_threshold": 5.0,
                "loss_watchdog_patience": 3,
                "bf16": "auto",
                "warmup_steps": 1,
                "saves_per_epoch": 1,
                "logging_steps": 1,
                "weight_decay": 0.0,
                "use_tensorboard": True,
                "sequence_parallel_degree": 2,
            }
        )

        # write cfg to yaml file
        Path(temp_dir).mkdir(parents=True, exist_ok=True)
        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        execute_subprocess_async(
            [
                "accelerate",
                "launch",
                "--num-processes",
                "2",
                "--main_process_port",
                f"{get_torch_dist_unique_port()}",
                "-m",
                "axolotl.cli.train",
                str(Path(temp_dir) / "config.yaml"),
            ]
        )

        check_tensorboard(
            temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
        )
@@ -12,6 +12,7 @@ from axolotl.monkeypatch.attention.ring_attn import (
    get_ring_attn_group,
    set_ring_attn_group,
)
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
from axolotl.utils.dict import DictDefault

@@ -47,6 +48,33 @@ def fixture_cfg():
    return cfg


class TestSequenceParallelHelpers:
    """Test helper functions used in sequence parallelism."""

    def test_adjust_position_ids_for_slice(self, partial_state):
        """Test position_ids adjustment for sequence slices."""
        # Create sample position_ids with multiple sequences
        position_ids = torch.tensor(
            [
                # First sequence with 2 samples
                [0, 1, 2, 3, 4, 0, 1, 2, 3],
                # Second sequence with 3 samples
                [0, 1, 2, 0, 1, 2, 3, 0, 1],
            ]
        )

        # Adjust as if this was the second slice (start_idx = 4)
        adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)

        # For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
        # For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
        expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
        expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4

        assert torch.all(adjusted[0] == expected_first_seq)
        assert torch.all(adjusted[1] == expected_second_seq)


class TestRingAttention:
    """Tests for the ring attention functionality."""