Compare commits

7 Commits

Author | SHA1 | Message | Date
Dan Saunders | cbcc795bb3 | commenting out unused | 2025-06-16 01:53:13 +00:00
Dan Saunders | e34b6f4dfe | temp: trying another approach | 2025-06-15 21:32:10 +00:00
Dan Saunders | f8f87321bd | progress | 2025-06-14 17:40:21 +00:00
Dan Saunders | 7a88de4fa8 | finish basic impl; change naming from SP -> CP to match torch | 2025-06-13 09:51:06 -04:00
Dan Saunders | aced809989 | progress (messy :O) | 2025-06-12 18:54:41 +00:00
Dan Saunders | ae73123eae | progress; move validation to pydantic model config | 2025-06-07 06:58:59 +00:00
Dan Saunders | 10d1e44943 | SDPA context parallel | 2025-06-06 00:34:12 +00:00
33 changed files with 874 additions and 1094 deletions

View File

@@ -8,7 +8,7 @@ on:
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/core/trainers/mixins/context_parallel.py'
- 'src/axolotl/utils/distributed.py'
workflow_dispatch:
schedule:

View File

@@ -75,7 +75,7 @@ quartodoc:
- title: Context Managers
desc: Context managers for altering trainer behaviors
contents:
- utils.ctx_managers.sequence_parallel
- utils.ctx_managers.context_parallel
- title: Prompt Strategies
desc: Prompt formatting strategies
contents:
@@ -274,7 +274,7 @@ website:
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/context_parallelism.qmd
- section: "Troubleshooting"
contents:

View File

@@ -764,13 +764,13 @@ ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:
# Sequence parallelism
# Context parallelism
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# See https://docs.axolotl.ai/docs/context_parallelism.html for more details.
context_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1

View File

@@ -18,7 +18,7 @@ Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- Context parallelism
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
@@ -80,14 +80,14 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Sequence parallelism {#sec-sequence-parallelism}
## Context parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the
We support context parallelism (CP) via the
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
See our [dedicated guide](sequence_parallelism.qmd) for more information.
See our [dedicated guide](context_parallelism.qmd) for more information.
### FSDP + QLoRA {#sec-fsdp-qlora}

View File

@@ -1,16 +1,16 @@
---
title: Sequence Parallelism
title: Context Parallelism
description: Train with long sequences split across multiple GPUs.
---
Sequence parallelism is a technique that splits sequences across multiple GPUs,
Context parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.
## When to Use Sequence Parallelism
## When to Use Context Parallelism
Use sequence parallelism when:
Use context parallelism when:
- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
@@ -18,11 +18,11 @@ Use sequence parallelism when:
## Configuration
To enable sequence parallelism, add the following to your configuration file:
To enable context parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
context_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,23 +30,23 @@ heads_k_stride: 1
ring_attn_func:
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
The `context_parallel_degree` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
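
The divisor rule above can be checked with a few lines of Python. This is a minimal sketch, not Axolotl's own validation; `valid_context_parallel_degrees` is a hypothetical helper name introduced only for illustration.

```python
def valid_context_parallel_degrees(num_gpus: int) -> list[int]:
    """Return every divisor of `num_gpus` that is greater than 1.

    Hypothetical helper: mirrors the documented rule that the degree
    must be a divisor (> 1) of the number of available GPUs.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be a positive integer")
    return [d for d in range(2, num_gpus + 1) if num_gpus % d == 0]


print(valid_context_parallel_degrees(8))  # [2, 4, 8]
print(valid_context_parallel_degrees(4))  # [2, 4]
```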
## Implementation Details
When sequence parallelism is enabled:
When context parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
1. Each sequence is divided into equal chunks across the GPUs in a context parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions
4. The trainer uses special ring communication patterns for attention operations
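
As a rough illustration of steps 1 to 3, the sketch below splits one sequence into equal contiguous chunks, one per rank. It assumes contiguous chunking and a sequence length that is already a multiple of the degree; the actual collator and ring attention variant may pad or order chunks differently. `shard_for_rank` is a hypothetical name, not an Axolotl function.

```python
def shard_for_rank(values: list[int], cp_degree: int, cp_rank: int) -> list[int]:
    """Return the contiguous chunk of `values` owned by `cp_rank`.

    Assumption: len(values) is divisible by cp_degree (no padding logic here).
    """
    if len(values) % cp_degree != 0:
        raise ValueError("sequence length must be divisible by cp_degree")
    chunk = len(values) // cp_degree
    return values[cp_rank * chunk : (cp_rank + 1) * chunk]


input_ids = list(range(100, 108))  # a toy sequence of 8 token IDs
position_ids = list(range(8))      # global positions 0..7

# Each rank keeps the global positions of its own chunk, so relative
# positions across ranks stay consistent.
for rank in range(4):
    print(rank, shard_for_rank(input_ids, 4, rank), shard_for_rank(position_ids, 4, rank))

shard_for_rank(position_ids, 4, 1)  # [2, 3]
```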
## Requirements
To use sequence parallelism, you need:
To use context parallelism, you need:
- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with:
@@ -66,7 +66,7 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
context_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -79,22 +79,22 @@ ring_attn_func:
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 4 subsequences of length 2048 across 4 GPUs.
## Sample Packing with Sequence Parallelism
## Sample Packing with Context Parallelism
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
Context parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
2. The packed sequences are then divided across GPUs in the context parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
When using context parallelism, your effective global batch size is **divided** by the `context_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- Each group of `context_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- With 8 GPUs and no context parallelism: 8 different batches processed per step
- With 8 GPUs and `context_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
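
The arithmetic in this example can be written out as a short sketch. `global_batch_size` is a hypothetical helper, and the `gradient_accumulation_steps` argument is an added assumption that defaults to 1 so the numbers match the bullets above.

```python
def global_batch_size(
    num_gpus: int,
    micro_batch_size: int,
    context_parallel_degree: int = 1,
    gradient_accumulation_steps: int = 1,
) -> int:
    """Effective global batch size when GPUs are grouped for context parallelism.

    Every group of `context_parallel_degree` GPUs works on the same batch,
    so only num_gpus // context_parallel_degree distinct batches exist per step.
    """
    if num_gpus % context_parallel_degree != 0:
        raise ValueError("context_parallel_degree must divide num_gpus")
    data_parallel_groups = num_gpus // context_parallel_degree
    return data_parallel_groups * micro_batch_size * gradient_accumulation_steps


print(global_batch_size(8, 2))                             # 16
print(global_batch_size(8, 2, context_parallel_degree=4))  # 4
```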

View File

@@ -20,7 +20,6 @@ datasets==3.6.0
deepspeed>=0.17.0
trl==0.18.1
hf_xet==1.1.2
mistral-common[hf-hub]==1.6.0
optimum==1.16.2
hf_transfer

View File

@@ -73,7 +73,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
sequence_parallel_degree=None,
context_parallel_degree=None,
deepspeed=None,
fsdp=None,
fsdp_config=None,

View File

@@ -305,8 +305,8 @@ def load_model_and_tokenizer(
ProcessorMixin | None,
]:
"""
Helper function for loading a model, tokenizer, and processor specified in the
given `axolotl` config.
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.

View File

@@ -54,7 +54,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
context_parallel=self.cfg.context_parallel_degree > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))

View File

@@ -5,7 +5,7 @@
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
from .grpo.trainer import AxolotlGRPOContextParallelTrainer, AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (

View File

@@ -7,11 +7,13 @@ from __future__ import annotations
import os
from collections import defaultdict
from functools import partial, wraps
from typing import Callable, Literal, Optional
from typing import Any, Callable, Literal, Optional
from axolotl.utils.ctx_managers.context_parallel.distributed import get_context_parallel_manager
import datasets
import torch
from datasets import Dataset
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
@@ -65,6 +67,32 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# SDPA device mesh init
import torch.distributed as dist
world_size = dist.get_world_size()
mesh_shape = (
world_size // 2,
2,
)
self.world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
def training_step(
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
) -> torch.Tensor:
ctx_manager = get_context_parallel_manager(
world_mesh=self.world_mesh,
model=model,
)
to_shard = {k: v for k, v in inputs.items() if v.ndim > 1}
with ctx_manager(list(to_shard.values())):
return super().training_step(model, inputs, num_items_in_batch)
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
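
Review note on the device mesh in the hunk above: reshaping `range(world_size)` to `(world_size // 2, 2)` places consecutive ranks in the same `cp` group, and it hard-codes a CP size of 2. The pure-Python sketch below reproduces that layout without touching `torch.distributed`; `mesh_coordinates` is a hypothetical helper, and the `cp_degree` parameter is an assumed generalization of the hard-coded 2.

```python
def mesh_coordinates(rank: int, world_size: int, cp_degree: int = 2) -> tuple[int, int]:
    """Return (dp_index, cp_index) of `rank` in a row-major (dp, cp) mesh.

    Matches reshaping range(world_size) to (world_size // cp_degree, cp_degree).
    """
    if world_size % cp_degree != 0:
        raise ValueError("world_size must be divisible by cp_degree")
    if not 0 <= rank < world_size:
        raise ValueError("rank out of range")
    return divmod(rank, cp_degree)


# With 4 GPUs the mesh is [[0, 1], [2, 3]]: ranks 0 and 1 share dp index 0,
# so they form one CP group; ranks 2 and 3 form the other.
for rank in range(4):
    print(rank, mesh_coordinates(rank, world_size=4))
```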

View File

@@ -8,7 +8,7 @@ from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
from axolotl.core.trainers.grpo.trainer import (
AxolotlGRPOSequenceParallelTrainer,
AxolotlGRPOContextParallelTrainer,
AxolotlGRPOTrainer,
)
from axolotl.utils.dict import DictDefault
@@ -23,10 +23,10 @@ class GRPOStrategy:
@classmethod
def get_trainer_class(
cls, sequence_parallel: bool
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
if sequence_parallel:
return AxolotlGRPOSequenceParallelTrainer
cls, context_parallel: bool
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOContextParallelTrainer]:
if context_parallel:
return AxolotlGRPOContextParallelTrainer
return AxolotlGRPOTrainer
@classmethod
@@ -69,8 +69,8 @@ class GRPOStrategy:
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if cfg.sequence_parallel_degree > 1:
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
if cfg.context_parallel_degree > 1:
grpo_args_kwargs["context_parallel_degree"] = cfg.context_parallel_degree
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights

View File

@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
sequence_parallel_degree: int | None = None
context_parallel_degree: int | None = None

View File

@@ -1,7 +1,7 @@
"""Repeat random sampler (similar to the one implemented in
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
sequence parallelism functionality; i.e., duplicating data across ranks in the same
sequence parallel group.
context parallelism functionality; i.e., duplicating data across ranks in the same
context parallel group.
"""
from typing import Iterator, Sized
@@ -10,26 +10,26 @@ import torch
from torch.utils.data import Sampler
class SequenceParallelRepeatRandomSampler(Sampler):
"""Sampler for GRPO training with sequence parallelism.
class ContextParallelRepeatRandomSampler(Sampler):
"""Sampler for GRPO training with context parallelism.
This sampler ensures:
- Ranks in the same sequence parallel (SP) group receive identical data.
- Ranks in the same context parallel (CP) group receive identical data.
- Each index is repeated multiple times for sampling different completions.
- Entire batches are repeated for reuse in multiple updates.
- Data is properly distributed across SP groups.
- Data is properly distributed across CP groups.
In the table below, the values represent dataset indices. Each SP group has
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
In the table below, the values represent dataset indices. Each CP group has
`context_parallel_degree = 2` GPUs working together on the same data. There are 2
CP groups (SP0 and SP1), with `world_size = 4` total GPUs.
Sequence Parallel Groups
Context Parallel Groups
| SP0 | SP1 |
| GPU 0 | GPU 1 | GPU 2 | GPU 3 |
global_step step <---> mini_repeat_count=3
<----------> batch_size=2 per SP group
grad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- SP groups get different data
▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each SP group GPU
<----------> batch_size=2 per CP group
grad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- CP groups get different data
▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each CP group GPU
|
| 1 2 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Repeat same indices for iterations
num_iterations=2 ▼ 1 3 [0 0 0 1 1 1] [2 2 2 3 3 3] <- When using gradient accumulation
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: Rank of current process.
batch_size: Number of samples per batch.
repeat_count: How many times to repeat the full sampling process.
sequence_parallel_degree: Number of ranks in a sequence parallel group.
context_parallel_degree: Number of ranks in a context parallel group.
shuffle: Whether to shuffle the dataset.
seed: Random seed for shuffling.
drop_last: Whether to drop the last incomplete batch.
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: int,
batch_size: int = 1,
repeat_count: int = 1,
sequence_parallel_degree: int = 1,
context_parallel_degree: int = 1,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
@@ -76,16 +76,16 @@ class SequenceParallelRepeatRandomSampler(Sampler):
self.world_size = world_size
self.rank = rank
# Sequence parallelism parameters
self.sequence_parallel_degree = sequence_parallel_degree
self.num_sp_groups = world_size // sequence_parallel_degree
self.sp_group_id = rank // sequence_parallel_degree
# Context parallelism parameters
self.context_parallel_degree = context_parallel_degree
self.num_sp_groups = world_size // context_parallel_degree
self.sp_group_id = rank // context_parallel_degree
# Adjust dataset size for distributed sampling
self.num_samples = len(self.dataset)
self.total_size = self.num_samples
# Calculate effective number of samples per SP group
# Calculate effective number of samples per CP group
if (
self.drop_last
and self.total_size % (self.num_sp_groups * self.batch_size) != 0
@@ -125,8 +125,8 @@ class SequenceParallelRepeatRandomSampler(Sampler):
padding = indices[: self.batch_size - len(indices) % self.batch_size]
indices += padding
# Subsample based on SP group ID
# Each SP group gets distinct batches of data
# Subsample based on CP group ID
# Each CP group gets distinct batches of data
batch_indices = []
for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
start_idx = i + self.sp_group_id * self.batch_size
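
Review note: the hunk ends partway through the subsampling loop. The standalone sketch below shows how that loop assigns indices to groups, assuming the elided remainder simply takes `batch_size` indices starting at `start_idx`. It leaves out `mini_repeat_count`, `repeat_count`, shuffling, and padding, and `indices_for_rank` is a hypothetical name rather than a method of the sampler.

```python
def indices_for_rank(
    indices: list[int],
    rank: int,
    world_size: int,
    cp_degree: int,
    batch_size: int,
) -> list[int]:
    """Pick the dataset indices seen by `rank`.

    Ranks in the same CP group (rank // cp_degree) get identical indices;
    different groups get disjoint batches.
    """
    num_groups = world_size // cp_degree
    group_id = rank // cp_degree
    selected = []
    for i in range(0, len(indices), batch_size * num_groups):
        start_idx = i + group_id * batch_size
        # Assumed continuation of the truncated loop body.
        selected.extend(indices[start_idx : start_idx + batch_size])
    return selected


all_indices = list(range(8))
for rank in range(4):
    print(rank, indices_for_rank(all_indices, rank, world_size=4, cp_degree=2, batch_size=2))
# Ranks 0 and 1 print [0, 1, 4, 5]; ranks 2 and 3 print [2, 3, 6, 7].
```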

View File

@@ -1,9 +1,8 @@
"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
"""Axolotl GRPO trainers (with and without context parallelism handling)"""
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from functools import partial
from typing import Any
import datasets
@@ -42,7 +41,7 @@ from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.grpo.sampler import ContextParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
@@ -59,45 +58,9 @@ class AxolotlGRPOTrainer(
_tag_names = ["trl", "grpo", "axolotl"]
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
if isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
dataloader_params = {
"batch_size": self._train_batch_size
* self.args.steps_per_generation, # < this is the change
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""
class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for context parallelism handling"""
def __init__(
self,
@@ -134,11 +97,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
)
# Get number of SP groups (number of processes divided by SP degree)
# Get number of CP groups (number of processes divided by CP degree)
num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree
num_sp_groups = num_processes // self.args.context_parallel_degree
# Calculate batch size per SP group (not per process)
# Calculate batch size per CP group (not per process)
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
possible_values = [
n_gen
@@ -148,7 +111,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
if self.num_generations not in possible_values:
raise ValueError(
f"The batch size per SP group ({num_sp_groups} x "
f"The batch size per CP group ({num_sp_groups} x "
f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
f"the number of generations per prompt ({self.num_generations}). Given "
"the current configuration, the valid values for the number of "
@@ -156,7 +119,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
)
if self.args.eval_strategy != "no":
# If sequence parallelism is enabled, calculate batch size per SP group
# If context parallelism is enabled, calculate batch size per CP group
sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups # type: ignore[union-attr]
possible_values = [
n_gen
@@ -166,8 +129,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
if self.num_generations not in possible_values:
raise ValueError(
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"With context parallelism (degree {self.args.context_parallel_degree}), "
f"the eval batch size per CP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"must be evenly divisible by the number of generations per prompt "
f"({self.num_generations}). Given the current eval batch size, "
f"the valid values for the number of generations are: {possible_values}."
@@ -180,7 +143,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
self.local_world_size = 1
def train(self, *args, **kwargs):
# Initialize the SP group
# Initialize the CP group
self.sp_group = get_ring_attn_group()
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
@@ -196,16 +159,16 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
* self.args.gradient_accumulation_steps
)
return SequenceParallelRepeatRandomSampler(
return ContextParallelRepeatRandomSampler(
dataset=self.train_dataset,
mini_repeat_count=self.num_generations,
world_size=self.world_size,
rank=self.rank,
batch_size=effective_batch_size
// self.num_generations
// self.args.sequence_parallel_degree,
// self.args.context_parallel_degree,
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
sequence_parallel_degree=self.args.sequence_parallel_degree,
context_parallel_degree=self.args.context_parallel_degree,
shuffle=True,
seed=self.args.seed,
drop_last=True,
@@ -263,11 +226,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# Return unprepared dataloader if using context parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
@@ -340,21 +303,21 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
if self.args.sequence_parallel_degree > 1:
# Calculate sequence parallel group information
if self.args.context_parallel_degree > 1:
# Calculate context parallel group information
world_size = self.accelerator.num_processes
sequence_parallel_degree = self.args.sequence_parallel_degree
num_sp_groups = world_size // sequence_parallel_degree
context_parallel_degree = self.args.context_parallel_degree
num_sp_groups = world_size // context_parallel_degree
# Since processes in the same SP group have the same prompts, we need to ensure
# we only take one copy of each prompt from each SP group
# Since processes in the same CP group have the same prompts, we need to ensure
# we only take one copy of each prompt from each CP group
ordered_set_of_prompts = []
for sp_group_id in range(num_sp_groups):
# Get the first process from each SP group (typically the group leader)
group_leader_rank = sp_group_id * sequence_parallel_degree
# Get the first process from each CP group (typically the group leader)
group_leader_rank = sp_group_id * context_parallel_degree
# Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group
# Extract prompts from this CP group, accounting for num_generations duplicates
# We only need prompts from one rank in each CP group
group_prompts = all_prompts_text[
group_leader_rank
* len(prompts_text) : (group_leader_rank + 1)
@@ -367,7 +330,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[
:: self.num_generations * self.args.sequence_parallel_degree
:: self.num_generations * self.args.context_parallel_degree
]
with profiling_context(self, "vLLM.generate"):
@@ -384,28 +347,28 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
)
else:
completion_ids = [None] * (
len(all_prompts_text) // self.args.sequence_parallel_degree
len(all_prompts_text) // self.args.context_parallel_degree
)
# Broadcast the completions from the main process to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0)
# Determine the appropriate slice based on sequence parallelism
if self.args.sequence_parallel_degree > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
# Determine the appropriate slice based on context parallelism
if self.args.context_parallel_degree > 1:
# Calculate CP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size
# Calculate the start index for this SP group
# Calculate the start index for this CP group
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
# All ranks in the same SP group get the same data slice
# All ranks in the same CP group get the same data slice
process_slice = slice(
sp_group_start,
sp_group_start + len(prompts),
)
completion_ids = completion_ids[process_slice]
else:
# Original behavior for non-sequence parallel case
# Original behavior for non-context parallel case
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
@@ -615,20 +578,20 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
advantages = advantages / (std_grouped_rewards + 1e-4)
# Slice to keep only the local part of the data
if self.args.sequence_parallel_degree > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
if self.args.context_parallel_degree > 1:
# Calculate CP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size
# Calculate the start index for this SP group
# Calculate the start index for this CP group
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
# All ranks in the same SP group get the same data slice
# All ranks in the same CP group get the same data slice
process_slice = slice(
sp_group_start,
sp_group_start + len(prompts),
)
else:
# Original behavior for non-sequence parallel case
# Original behavior for non-context parallel case
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),

View File

@@ -64,10 +64,6 @@ class TokenizedPromptDataset(Dataset):
desc="Strategy Filtering Rows",
)
import ipdb
ipdb.set_trace()
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,

View File

@@ -2,16 +2,8 @@
import json
import os
from typing import Any, Dict, List, Optional, Union
import torch
import transformers
from huggingface_hub import hf_hub_download
from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer,
)
from transformers import (
AddedToken,
AutoTokenizer,
@@ -31,622 +23,239 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
# Constants
LLAMA_TOKENIZER_CLASSES = {
"LlamaTokenizer",
"LlamaTokenizerFast",
"CodeLlamaTokenizer",
"CodeLlamaTokenizerFast",
}
FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
QWEN_DEFAULT_TOKEN = "<|endoftext|>"
GPTNEOX_PAD_TOKEN = "[PAD]"
CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
class MistralTokenizerWrapper:
def modify_tokenizer_files(
tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
) -> str:
"""
Wrapper to make MistralTokenizer compatible with Hugging Face tokenizer interface.
This provides a bridge between Mistral's native tokenizer and axolotl's expectations.
Modify tokenizer files to replace added_tokens strings, save to output directory,
and return the path to the modified tokenizer.
This only works with reserved tokens that were added to the tokenizer, not tokens
already part of the vocab.
Args:
tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer
Returns:
Path to the modified tokenizer directory
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
"""
# Create the tokenizer directory in output_dir if it doesn't exist
tokenizer_dir = os.path.join(output_dir, "tokenizer")
os.makedirs(tokenizer_dir, exist_ok=True)
def __init__(self, mistral_tokenizer: "MistralTokenizer", model_id: str):
self.mistral_tokenizer = mistral_tokenizer
self.model_id = model_id
self._system_prompt = None
self.padding_side = "right" # Default padding side
self.chat_template = None
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
# Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
# Cache token IDs by inspecting the actual tokenizer
self._token_ids = self._discover_token_ids()
# Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir)
# Try to load system prompt if available
try:
self._system_prompt = self._load_system_prompt(
model_id, "SYSTEM_PROMPT.txt"
)
except Exception as e:
LOG.debug(f"Could not load system prompt: {e}")
def _discover_token_ids(self) -> Dict[str, int]:
"""Discover the actual token IDs used by this Mistral tokenizer."""
token_ids = {}
try:
if hasattr(self.mistral_tokenizer, "instruct_tokenizer"):
instruct_tokenizer = self.mistral_tokenizer.instruct_tokenizer
# Get BOS token ID from instruct_tokenizer
token_ids["bos_token_id"] = getattr(instruct_tokenizer, "BOS", 1)
# Get token IDs from the underlying Tekkenizer
if hasattr(instruct_tokenizer, "tokenizer"):
tekkenizer = instruct_tokenizer.tokenizer
# Get BOS ID from tekkenizer (should match instruct_tokenizer.BOS)
if hasattr(tekkenizer, "bos_id"):
token_ids["bos_token_id"] = tekkenizer.bos_id
# Get vocab size to help find EOS token
vocab_size = getattr(tekkenizer, "_vocab_size", None)
# Check special tokens
if hasattr(tekkenizer, "_all_special_tokens"):
special_tokens = tekkenizer._all_special_tokens
keys = (
list(special_tokens.keys())
if hasattr(special_tokens, "keys")
else special_tokens
)
LOG.debug(f"Special tokens available: {keys}")
# Try to find EOS token in special tokens
if hasattr(special_tokens, "get"):
# Common EOS token patterns
for eos_pattern in ["</s>", "<|endoftext|>", "eos", "EOS"]:
if eos_pattern in special_tokens:
token_ids["eos_token_id"] = special_tokens[
eos_pattern
]
break
# Check special tokens reverse vocab
if hasattr(tekkenizer, "_special_tokens_reverse_vocab"):
reverse_vocab = tekkenizer._special_tokens_reverse_vocab
LOG.debug(f"Reverse special tokens: {reverse_vocab}")
# Look for common special token IDs
for token_id, token_str in reverse_vocab.items():
if token_str in ["</s>", "<|endoftext|>"]:
token_ids["eos_token_id"] = token_id
elif token_str in ["<unk>", "<UNK>"]:
token_ids["unk_token_id"] = token_id
# If we have vocab_size, EOS is often vocab_size - 1 or similar
if "eos_token_id" not in token_ids and vocab_size:
# Common patterns: EOS could be 2, vocab_size-1, or other values
# Let's try a safer approach by checking what tokens decode to
for candidate_id in [2, vocab_size - 1, vocab_size - 2]:
try:
# Try to decode and see if it looks like EOS
decoded = tekkenizer.decode([candidate_id])
if decoded in ["</s>", "<|endoftext|>", ""]:
token_ids["eos_token_id"] = candidate_id
break
except Exception:
continue
except Exception as e:
LOG.debug(f"Could not discover token IDs: {e}")
# Set reasonable defaults for any missing token IDs
token_ids.setdefault("bos_token_id", 1)
token_ids.setdefault("eos_token_id", 2)
token_ids.setdefault("unk_token_id", 0)
token_ids.setdefault(
"pad_token_id", token_ids["eos_token_id"]
) # Use EOS as pad
LOG.info(f"Discovered Mistral token IDs: {token_ids}")
return token_ids
def _load_system_prompt(self, repo_id: str, filename: str) -> str:
"""Load system prompt from HuggingFace Hub"""
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
with open(file_path, "r") as file:
return file.read()
def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]:
"""Encode text to token IDs"""
if isinstance(text, str):
# For simple string encoding, create a user message
messages = []
if self._system_prompt and add_special_tokens:
messages.append(SystemMessage(content=self._system_prompt))
messages.append(UserMessage(content=text))
tokenized = self.mistral_tokenizer.encode_chat_completion(
ChatCompletionRequest(messages=messages)
)
return tokenized.tokens
else:
raise ValueError("MistralTokenizer wrapper only supports string input")
def decode(
self,
token_ids: Union[List[int], torch.Tensor],
skip_special_tokens: bool = True,
) -> str:
"""Decode token IDs to text"""
if isinstance(token_ids, torch.Tensor):
token_ids = token_ids.tolist()
return self.mistral_tokenizer.decode(token_ids)
def __call__(self, text: str, **kwargs):
"""Make the tokenizer callable like HF tokenizers"""
tokens = self.encode(text, **kwargs)
return {"input_ids": torch.tensor([tokens])}
@property
def eos_token_id(self):
return self._token_ids["eos_token_id"]
@property
def bos_token_id(self):
return self._token_ids["bos_token_id"]
@property
def pad_token_id(self):
return self._token_ids["pad_token_id"]
@property
def unk_token_id(self):
return self._token_ids["unk_token_id"]
@property
def eos_token(self):
return "</s>" # Standard Mistral EOS token
@property
def bos_token(self):
return "<s>" # Standard Mistral BOS token
@property
def pad_token(self):
return self.eos_token # Use EOS as pad token
@property
def unk_token(self):
return "<unk>" # Standard UNK token
@property
def __class__(self):
# Create a mock class for compatibility checks
class MistralTokenizerWrapperClass:
__name__ = "MistralTokenizerWrapper"
return MistralTokenizerWrapperClass
def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int:
"""Placeholder for special token addition - Mistral tokenizer handles this internally"""
LOG.warning(
"add_special_tokens called on MistralTokenizer wrapper - this is handled internally"
)
return 0
def add_tokens(self, tokens) -> int:
"""Placeholder for token addition - Mistral tokenizer handles this internally"""
LOG.warning(
"add_tokens called on MistralTokenizer wrapper - this is handled internally"
)
return 0
class TokenizerFileModifier:
"""Handles modification of tokenizer files for token overrides."""
def __init__(
self, tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
):
self.tokenizer_path = tokenizer_path
self.token_mappings = token_mappings
self.output_dir = output_dir
self.tokenizer_dir = os.path.join(output_dir, "tokenizer")
def modify_and_save(self) -> str:
"""Modify tokenizer files and return path to modified tokenizer."""
os.makedirs(self.tokenizer_dir, exist_ok=True)
if is_local_main_process():
self._perform_modifications()
barrier()
return self.tokenizer_dir
def _perform_modifications(self):
"""Perform the actual file modifications."""
# Load and save tokenizer to output directory
temp_tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer_path, use_fast=True
)
temp_tokenizer.save_pretrained(self.tokenizer_dir)
# Convert token mappings to proper format
# Get the token IDs and map them to their new values
token_id_mappings = {
int(token_id): new_value
for token_id, new_value in self.token_mappings.items()
int(token_id): new_value for token_id, new_value in token_mappings.items()
}
# Update both tokenizer files
self._update_tokenizer_config(token_id_mappings)
self._update_tokenizer_json(token_id_mappings)
# 1. Update tokenizer_config.json - added_tokens_decoder
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
if os.path.exists(config_path):
with open(config_path, "r", encoding="utf-8") as f:
config_data = json.load(f)
def _update_tokenizer_config(self, token_id_mappings: Dict[int, str]):
"""Update tokenizer_config.json with new token mappings."""
config_path = os.path.join(self.tokenizer_dir, "tokenizer_config.json")
if not os.path.exists(config_path):
return
# Update added_tokens_decoder
if "added_tokens_decoder" in config_data:
for token_id, new_value in token_id_mappings.items():
token_id_str = str(token_id)
if token_id_str in config_data["added_tokens_decoder"]:
config_data["added_tokens_decoder"][token_id_str][
"content"
] = new_value
else:
raise ValueError(
f"Token ID {token_id_str} not found in added_tokens_decoder"
)
with open(config_path, "r", encoding="utf-8") as f:
config_data = json.load(f)
# Write the updated config back
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
if "added_tokens_decoder" in config_data:
self._update_added_tokens_decoder(config_data, token_id_mappings)
# 2. Update tokenizer.json - added_tokens
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
if os.path.exists(tokenizer_path):
with open(tokenizer_path, "r", encoding="utf-8") as f:
tokenizer_data = json.load(f)
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
# Update added_tokens
if "added_tokens" in tokenizer_data:
for token_id, new_value in token_id_mappings.items():
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
if token_entry["id"] == token_id:
tokenizer_data["added_tokens"][i]["content"] = new_value
break
else:
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
raise ValueError(
f"Token ID {token_id} not found in added_tokens"
)
if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]:
for token_id, new_value in token_id_mappings.items():
for entry_val, entry_id in tokenizer_data["model"]["vocab"].items():
if entry_id == token_id:
del tokenizer_data["model"]["vocab"][entry_val]
tokenizer_data["model"]["vocab"][new_value] = token_id
break
def _update_added_tokens_decoder(
self, config_data: Dict, token_id_mappings: Dict[int, str]
):
"""Update the added_tokens_decoder section."""
for token_id, new_value in token_id_mappings.items():
token_id_str = str(token_id)
if token_id_str in config_data["added_tokens_decoder"]:
config_data["added_tokens_decoder"][token_id_str]["content"] = new_value
else:
raise ValueError(
f"Token ID {token_id_str} not found in added_tokens_decoder"
)
# Write the updated tokenizer data back
with open(tokenizer_path, "w", encoding="utf-8") as f:
json.dump(tokenizer_data, f, indent=2)
def _update_tokenizer_json(self, token_id_mappings: Dict[int, str]):
"""Update tokenizer.json with new token mappings."""
tokenizer_json_path = os.path.join(self.tokenizer_dir, "tokenizer.json")
if not os.path.exists(tokenizer_json_path):
return
with open(tokenizer_json_path, "r", encoding="utf-8") as f:
tokenizer_data = json.load(f)
self._update_added_tokens_list(tokenizer_data, token_id_mappings)
self._update_vocab_mappings(tokenizer_data, token_id_mappings)
with open(tokenizer_json_path, "w", encoding="utf-8") as f:
json.dump(tokenizer_data, f, indent=2)
def _update_added_tokens_list(
self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
):
"""Update the added_tokens list in tokenizer.json."""
if "added_tokens" not in tokenizer_data:
return
for token_id, new_value in token_id_mappings.items():
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
if token_entry["id"] == token_id:
tokenizer_data["added_tokens"][i]["content"] = new_value
break
else:
raise ValueError(f"Token ID {token_id} not found in added_tokens")
def _update_vocab_mappings(
self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
):
"""Update vocab mappings in tokenizer.json."""
if not (tokenizer_data.get("model") and tokenizer_data["model"].get("vocab")):
return
vocab = tokenizer_data["model"]["vocab"]
for token_id, new_value in token_id_mappings.items():
# Find and update the vocab entry
for entry_val, entry_id in list(vocab.items()):
if entry_id == token_id:
del vocab[entry_val]
vocab[new_value] = token_id
break
class TokenizerConfiguration:
"""Handles tokenizer configuration and initialization."""
def __init__(self, cfg):
self.cfg = cfg
self.model_config = load_model_config(cfg)
def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
"""Load Mistral tokenizer from model configuration."""
# Instantiate Mistral tokenizer
model_id = self.cfg.base_model
mistral_tokenizer = MistralTokenizer.from_hf_hub(model_id)
# Wrap it for compatibility
tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
return tokenizer
def get_tokenizer_class(self):
"""Get the appropriate tokenizer class."""
if self.cfg.tokenizer_type:
return getattr(transformers, self.cfg.tokenizer_type)
return AutoTokenizer
def get_tokenizer_kwargs(self) -> Dict[str, Any]:
"""Build tokenizer initialization kwargs."""
kwargs = {}
if self.cfg.tokenizer_legacy is not None:
kwargs["legacy"] = self.cfg.tokenizer_legacy
return kwargs
def get_tokenizer_path(self) -> str:
"""Get the tokenizer path, applying overrides if needed."""
tokenizer_path = self.cfg.tokenizer_config
if self.cfg.added_tokens_overrides:
modifier = TokenizerFileModifier(
tokenizer_path, self.cfg.added_tokens_overrides, self.cfg.output_dir
)
tokenizer_path = modifier.modify_and_save()
return tokenizer_path
def should_use_fast_tokenizer(self) -> bool:
"""Determine if fast tokenizer should be used."""
return (
self.cfg.tokenizer_use_fast
if self.cfg.tokenizer_use_fast is not None
else True
)
class TokenizerPostProcessor:
"""Handles post-processing configuration of loaded tokenizers."""
def __init__(self, tokenizer, cfg):
self.tokenizer = tokenizer
self.cfg = cfg
self.model_config = load_model_config(cfg)
def apply_all_configurations(self):
"""Apply all post-processing configurations to the tokenizer."""
# Skip most configurations for Mistral wrapper
if isinstance(self.tokenizer, MistralTokenizerWrapper):
self._configure_mistral_wrapper()
return
self._configure_padding_token()
self._configure_gptneox_settings()
self._configure_mistral_padding()
self._configure_qwen_tokens()
self._add_special_tokens()
self._add_regular_tokens()
self._configure_chat_template()
def _configure_mistral_wrapper(self):
"""Apply limited configurations for Mistral wrapper."""
# Set padding side if needed
if (
self.cfg.is_mistral_derived_model
and self.cfg.flash_attention
and not self.cfg.sample_packing
):
self.tokenizer.padding_side = "left"
# Configure chat template for Mistral
self._configure_chat_template()
def _configure_padding_token(self):
"""Configure padding token for Llama-based tokenizers."""
if (
self.tokenizer.__class__.__name__ in LLAMA_TOKENIZER_CLASSES
and hasattr(self.tokenizer, "pad_token")
and not self.tokenizer.pad_token
):
self.tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
def _configure_gptneox_settings(self):
"""Configure GPTNeoX-specific settings."""
if self.tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
self.tokenizer.add_special_tokens({"pad_token": GPTNEOX_PAD_TOKEN})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def _configure_mistral_padding(self):
"""Configure left padding for Mistral models with Flash Attention."""
if (
self.cfg.is_mistral_derived_model
and self.cfg.flash_attention
and not self.cfg.sample_packing
):
self.tokenizer.padding_side = "left"
def _configure_qwen_tokens(self):
"""Configure special tokens for Qwen models."""
if not self.cfg.is_qwen_derived_model:
return
# Set token IDs
token_id_attributes = [
"bos_token_id",
"eos_token_id",
"pad_token_id",
"unk_token_id",
]
for attr_name in token_id_attributes:
if getattr(self.tokenizer, attr_name) is None:
setattr(self.tokenizer, attr_name, self.tokenizer.eod_id)
# Set token strings
token_name_attributes = ["bos_token", "eos_token", "pad_token", "unk_token"]
for attr_name in token_name_attributes:
if getattr(self.tokenizer, attr_name) is None:
setattr(self.tokenizer, attr_name, QWEN_DEFAULT_TOKEN)
def _add_special_tokens(self):
"""Add special tokens from configuration."""
if not self.cfg.special_tokens:
return
special_tokens_dict = self.cfg.special_tokens.to_dict()
additional_special_tokens = special_tokens_dict.pop(
"additional_special_tokens", None
)
self._validate_and_add_special_tokens(special_tokens_dict)
self._update_post_processor_if_needed(special_tokens_dict)
self._add_additional_special_tokens_if_present(additional_special_tokens)
def _validate_and_add_special_tokens(self, special_tokens: Dict[str, str]):
"""Validate special tokens for adapter training and add them."""
lora_modules_to_save = get_linear_embedding_layers(self.model_config.model_type)
for key, value in special_tokens.items():
self._validate_token_for_adapter(key, value, lora_modules_to_save)
self.tokenizer.add_special_tokens(
{key: AddedToken(value, rstrip=False, lstrip=False, normalized=False)}
)
def _validate_token_for_adapter(
self, key: str, value: str, lora_modules_to_save: List[str]
):
"""Validate a single token for adapter training requirements."""
if not self._should_validate_token_for_adapter(
key, value, lora_modules_to_save
):
return
modules_str = ", ".join(f"`{x}`" for x in lora_modules_to_save)
raise ValueError(
f"Please set lora_modules_to_save to [{modules_str}] "
f"when using an adapter and changing the special tokens."
)
def _should_validate_token_for_adapter(
self, key: str, value: str, lora_modules_to_save: List[str]
) -> bool:
"""Check if token should be validated for adapter configuration."""
if key == "pad_token" or not self.cfg.adapter:
return False
current_token = getattr(self.tokenizer, key)
token_changed = current_token is None or current_token != value
token_is_multi_char = (
len(self.tokenizer.encode(value, add_special_tokens=False)) > 2
)
lora_modules_missing = not self.cfg.lora_modules_to_save or not all(
x in self.cfg.lora_modules_to_save for x in lora_modules_to_save
)
return token_changed and token_is_multi_char and lora_modules_missing
def _update_post_processor_if_needed(self, special_tokens: Dict[str, str]):
"""Update post processor for Llama tokenizers when BOS/EOS tokens are added."""
has_bos_and_eos = (
"bos_token" in special_tokens and "eos_token" in special_tokens
)
is_fast_llama = (
self.tokenizer.__class__.__name__ in FAST_LLAMA_TOKENIZER_CLASSES
)
if is_fast_llama and has_bos_and_eos:
self.tokenizer.update_post_processor()
def _add_additional_special_tokens_if_present(
self, additional_special_tokens: Optional[List[str]]
):
"""Add additional special tokens if they exist."""
if additional_special_tokens is not None:
self.tokenizer.add_special_tokens(
{"additional_special_tokens": additional_special_tokens}
)
def _add_regular_tokens(self):
"""Add regular (non-special) tokens from configuration."""
if self.cfg.tokens:
self.tokenizer.add_tokens(
[
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
for token in self.cfg.tokens
]
)
def _configure_chat_template(self):
"""Configure chat template if specified."""
if not self.cfg.chat_template:
LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference."
)
return
chat_template_string = get_chat_template_from_config(
cfg=self.cfg,
tokenizer=self.tokenizer,
)
if self._should_replace_default_system_message():
chat_template_string = chat_template_string.replace(
CHATML_DEFAULT_SYSTEM_MESSAGE, self.cfg.default_system_message
)
self.tokenizer.chat_template = chat_template_string
def _should_replace_default_system_message(self) -> bool:
"""Check if default system message should be replaced."""
return self.cfg.default_system_message and self.cfg.chat_template == "chatml"
barrier()
return tokenizer_dir
def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config.
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default
This function handles the complete tokenizer loading pipeline:
- Check if Mistral tokenizer should be used
- Configure tokenizer parameters and get the appropriate class
- Handle token file modifications if needed
- Initialize the tokenizer with the correct parameters
- Apply all post-processing configurations (padding, special tokens, etc.)
- Set up chat templates and logging
if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
Returns:
Fully configured tokenizer instance.
"""
# Configure tokenizer parameters
config = TokenizerConfiguration(cfg)
# Set base tokenizer path
tokenizer_path = cfg.tokenizer_config
# Check if we should use Mistral tokenizer
try:
tokenizer = config.load_mistral_tokenizer()
except Exception:
# Standard tokenizer loading
tokenizer_cls = config.get_tokenizer_class()
tokenizer_path = config.get_tokenizer_path()
use_fast = config.should_use_fast_tokenizer()
tokenizer_kwargs = config.get_tokenizer_kwargs()
# Initialize the tokenizer
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_path,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
)
# Apply all post-processing configurations
post_processor = TokenizerPostProcessor(tokenizer, cfg)
post_processor.apply_all_configurations()
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_path,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
if (
tokenizer.__class__.__name__
in [
"LlamaTokenizer",
"LlamaTokenizerFast",
"CodeLlamaTokenizer",
"CodeLlamaTokenizerFast",
]
and hasattr(tokenizer, "pad_token")
and not tokenizer.pad_token
):
# set a pad_token, but use eos_token so we don't add a new token
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Mistral's official FA implementation requires left padding
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left"
# Qwen base only has single token, so we need to set the special tokens
if cfg.is_qwen_derived_model:
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
for attr_name in token_ids:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, tokenizer.eod_id)
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
for attr_name in token_names:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()
additional_special_tokens = special_tokens.pop(
"additional_special_tokens", None
)
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
for k, val in special_tokens.items():
# check if new special token is not already in tokenizer and
# is adapter training to make sure lora_modules_to_save is set
# pylint: disable=too-many-boolean-expressions
if (
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
and cfg.adapter
and (
not cfg.lora_modules_to_save
or not all(
x in cfg.lora_modules_to_save for x in lora_modules_to_save
)
)
and k != "pad_token"
):
lora_modules_to_save = ", ".join(
[f"`{x}`" for x in lora_modules_to_save]
)
raise ValueError(
f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
)
tokenizer.add_special_tokens(
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
)
# If we add bos_token and eos_token, we need to update the post processor to
# handle them correctly.
# https://github.com/huggingface/transformers/pull/24132
bos_or_eos_in_special_tokens = (
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
)
if (
tokenizer.__class__.__name__
in (
"LlamaTokenizerFast",
"CodeLlamaTokenizerFast",
)
and bos_or_eos_in_special_tokens
):
tokenizer.update_post_processor()
if cfg.tokens:
tokenizer.add_tokens(
[
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
for token in cfg.tokens
]
)
# Additional special tokens are a List, and need to be treated differently than regular special
# tokens. We add them after we have called `add_tokens` in case these additional special tokens
# are new tokens.
#
# Usage:
#
# ```py
# special_tokens:
# additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
# ```
if additional_special_tokens is not None:
tokenizer.add_special_tokens(
{"additional_special_tokens": additional_special_tokens}
)
if is_main_process(use_environ=True):
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -654,4 +263,19 @@ def load_tokenizer(cfg):
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if cfg.chat_template:
chat_template_string = get_chat_template_from_config(
cfg=cfg,
tokenizer=tokenizer,
)
if cfg.default_system_message and cfg.chat_template == "chatml":
chat_template_string = chat_template_string.replace(
"You are a helpful assistant.", cfg.default_system_message
)
tokenizer.chat_template = chat_template_string
else:
LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference."
)
return tokenizer
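
As a rough illustration of the vocab remapping done by `_update_vocab_mappings` above, here is a minimal standalone sketch. The function name `remap_vocab` and the sample vocab are hypothetical and not part of axolotl. It only assumes the `tokenizer.json` layout in which `model.vocab` maps token strings to integer IDs.

```python
from typing import Dict


def remap_vocab(vocab: Dict[str, int], overrides: Dict[int, str]) -> Dict[str, int]:
    """Return a copy of `vocab` where each overridden ID is keyed by its new string."""
    remapped = dict(vocab)
    for token_id, new_value in overrides.items():
        # Iterate over a snapshot so the dict is never mutated mid-iteration.
        for old_value, old_id in list(remapped.items()):
            if old_id == token_id:
                del remapped[old_value]
                remapped[new_value] = token_id
                break
    return remapped


# Hypothetical sample data: rename the string for token ID 3.
sample_vocab = {"<s>": 1, "</s>": 2, "<|reserved_0|>": 3}
print(remap_vocab(sample_vocab, {3: "<|im_start|>"}))
```

Token IDs stay fixed and only the string key changes, so the embedding rows already associated with those IDs are reused. IDs that do not appear in the vocab are silently skipped in this sketch.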

View File

@@ -2,10 +2,10 @@
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
their sequence parallel version of Flash Attention 2.
their context parallel version of Flash Attention 2.
We also provide some patches for accelerate functions to prepare the dataloader for
sequence parallelism training.
context parallelism training.
"""
import inspect
@@ -13,9 +13,9 @@ import inspect
import accelerate
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RingAttnFunc
LOG = get_logger(__name__)
@@ -63,15 +63,15 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
def register_ring_attn(
sequence_parallel_degree: int,
context_parallel_degree: int,
heads_k_stride: int | None,
ring_attn_func: RingAttnFunc | None,
):
"""Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed through to
context_parallel_degree: Context parallelism factor.
heads_k_stride: Context parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation.
ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
packing is enabled, it must be a `varlen` function; otherwise, it must be a
@@ -80,28 +80,18 @@ def register_ring_attn(
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
)
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"
LOG.info(
"Enabling ring attention context parallelism: "
f"each sequence will be processed across {context_parallel_degree} GPUs"
)
# Assign ranks to sequence parallel groups
# Assign ranks to context parallel groups
group_assignments = {}
for i in range(world_size // sequence_parallel_degree):
for i in range(world_size // context_parallel_degree):
ring_attn_ranks = list(
range(
i * sequence_parallel_degree,
(i + 1) * sequence_parallel_degree,
i * context_parallel_degree,
(i + 1) * context_parallel_degree,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
@@ -113,9 +103,7 @@ def register_ring_attn(
if rank in ring_attn_ranks:
set_ring_attn_group(group)
# Log the GPU group assignments
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
LOG.info(f"Context parallel group assignments: {group_assignments}")
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
from ring_flash_attn import substitute_hf_flash_attn
@@ -150,7 +138,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
def patch_prepare_data_loader():
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the CP degree.
Raises:
RuntimeError: If source code to patch does not exist.
@@ -176,15 +164,15 @@ def patch_prepare_data_loader():
patched_function = namespace["prepare_data_loader"]
accelerate.data_loader.prepare_data_loader = patched_function
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
LOG.info("Patched accelerate.data_loader.prepare_data_loader for CP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int):
def patch_prepare_device_mesh(context_parallel_degree: int):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree.
that includes context parallelism with the specified degree.
Args:
sequence_parallel_degree (int): The degree of sequence parallelism to use.
context_parallel_degree (int): The degree of context parallelism to use.
"""
def _prepare_device_mesh(self):
@@ -199,11 +187,11 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
):
return self.state.ds_device_mesh
# Create device mesh with sequence parallelism
# Create device mesh with context parallelism
world_size = dist.get_world_size()
mesh_shape = (
world_size // sequence_parallel_degree,
sequence_parallel_degree,
world_size // context_parallel_degree,
context_parallel_degree,
)
device_ids = list(range(world_size))
@@ -221,5 +209,5 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
LOG.info(
"Successfully patched Accelerator._prepare_device_mesh "
f"with sequence_parallel_degree={sequence_parallel_degree}"
f"with context_parallel_degree={context_parallel_degree}"
)
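
The rank grouping in `register_ring_attn` and the mesh shape in `patch_prepare_device_mesh` follow the same arithmetic, which can be sketched without `torch.distributed`. The helper name `context_parallel_layout` is hypothetical. The divisibility check mirrors the assertions removed in this diff (the commit log mentions moving validation to the pydantic model config) and is included here only to keep the sketch self-contained.

```python
from typing import List, Tuple


def context_parallel_layout(
    world_size: int, context_parallel_degree: int
) -> Tuple[List[List[int]], Tuple[int, int]]:
    """Compute contiguous rank groups and the (data parallel, context parallel) mesh shape."""
    if context_parallel_degree < 1 or world_size % context_parallel_degree != 0:
        raise ValueError(
            f"context_parallel_degree ({context_parallel_degree}) "
            f"must evenly divide world_size ({world_size})"
        )
    num_groups = world_size // context_parallel_degree
    # Each group holds consecutive ranks: [0..d-1], [d..2d-1], ...
    groups = [
        list(range(i * context_parallel_degree, (i + 1) * context_parallel_degree))
        for i in range(num_groups)
    ]
    mesh_shape = (num_groups, context_parallel_degree)
    return groups, mesh_shape


groups, mesh_shape = context_parallel_layout(world_size=8, context_parallel_degree=2)
print(groups, mesh_shape)
```

With 8 ranks and a degree of 2, each sequence is split across a pair of adjacent ranks, and the four pairs act as data parallel replicas.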

View File

@@ -67,10 +67,6 @@ class PromptTokenizingStrategy(abc.ABC):
LOG.warning("Empty text requested for tokenization.")
return empty
import ipdb
ipdb.set_trace()
result = self.tokenizer(
prompt,
truncation=True,

View File

@@ -25,14 +25,13 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.integrations.base import PluginManager
from axolotl.loaders import (
ModelLoader,
load_processor,
load_tokenizer,
)
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.ctx_managers import ContextParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
@@ -148,7 +147,7 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
cfg: DictDefault, model: PeftModel | PreTrainedModel, safe_serialization: bool
):
"""
Set up signal handler for graceful termination.
@@ -202,15 +201,20 @@ def execute_training(
)
)
if cfg.sequence_parallel_degree > 1:
if cfg.context_parallel_degree > 1 and not cfg.sdp_attention:
# Models to enter context parallel manager for
models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model)
# Attention backend
backend = "sdp_attention" if cfg.sdp_attention else "flash_attention"
stack.enter_context(
SequenceParallelContextManager(
ContextParallelContextManager(
models=models,
sequence_parallel_degree=cfg.sequence_parallel_degree,
backend=backend,
context_parallel_degree=cfg.context_parallel_degree,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
@@ -224,7 +228,7 @@ def execute_training(
def save_trained_model(
cfg: DictDefault,
trainer: Any,
model: PreTrainedModel,
model: PeftModel | PreTrainedModel,
safe_serialization: bool,
):
"""
@@ -375,7 +379,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
def save_initial_configs(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
model: PeftModel | PreTrainedModel,
peft_config: PeftConfig | None,
processor: ProcessorMixin | None,
):
@@ -429,7 +433,7 @@ def setup_model_card(cfg: DictDefault):
def handle_untrained_tokens_fix(
cfg: DictDefault,
model: PreTrainedModel,
model: PeftModel | PreTrainedModel,
tokenizer: PreTrainedTokenizer,
train_dataset: Dataset,
safe_serialization: bool,
@@ -472,7 +476,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
HFRLTrainerBuilder | HFCausalTrainerBuilder,
Trainer,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,

View File

@@ -1,6 +1,5 @@
"""Init for context manager submodule"""
"""Init for context manager submodule."""
# pylint: disable=unused-import
# flake8: noqa
from .context_parallel.manager import ContextParallelContextManager
from .sequence_parallel import SequenceParallelContextManager
__all__ = ["ContextParallelContextManager"]

View File

@@ -0,0 +1,146 @@
# BSD 3-Clause License
# Copyright 2024 Meta
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this list
# of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice, this
# list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its contributors may
# be used to endorse or promote products derived from this software without specific
# prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
# SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
"""
Distributed utils for SDPA context parallel implementation. Slightly modified from
https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5c2/torchtune/training/_distributed.py.
"""
import contextlib
from typing import Callable, Generator, Optional, Union
import torch
from torch import nn
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import BlockMask
def _get_sdpa_context() -> (
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
):
"""
Creates a context manager to confine to flash/efficient/cuDNN attention backends.
Returns:
A context manager function that takes an optional context parallel context.
"""
@contextlib.contextmanager
def context(cp_context: Union[Generator[None, None, None], None] = None):
with contextlib.ExitStack() as stack:
if cp_context is not None:
stack.enter_context(
sdpa_kernel(
[
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.CUDNN_ATTENTION,
]
)
)
stack.enter_context(cp_context)
yield
return context
def get_context_parallel_manager(
*,
world_mesh: torch.distributed.DeviceMesh,
model: nn.Module,
) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
"""
Context manager for applying context parallelism to a model. In addition to applying the
standard context manager to patch SDPA and shard model inputs and buffers along the sequence
dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.
Args:
world_mesh: Global device mesh.
model: Model to apply context parallelism to.
Returns:
A context manager function that takes the list of model inputs, shards them along the
sequence dimension over the "cp" mesh dimension, and restricts SDPA to the flash,
efficient, and cuDNN backends.
Raises:
ValueError: If world_mesh does not contain a "cp" dimension.
"""
if "cp" not in world_mesh.mesh_dim_names:
raise ValueError(
"Context parallel is enabled but no context parallel device mesh is provided."
)
# TODO: context parallel for multimodal models requires extra work
# if not isinstance(model, TransformerDecoder):
# raise ValueError("Context parallel is only supported for text models")
# model_buffers = list(model.buffers())
# def get_all_buffers(module, prefix=""):
# buffers = {}
# for name, buffer in module.named_buffers(recurse=False):
# full_name = f"{prefix}.{name}" if prefix else name
# buffers[full_name] = buffer
# for name, child in module.named_children():
# child_prefix = f"{prefix}.{name}" if prefix else name
# buffers.update(get_all_buffers(child, child_prefix))
# return buffers
# model_buffers = get_all_buffers(model)
@contextlib.contextmanager
def context(model_inputs: list[torch.Tensor]):
# Create context parallel context if enabled
cp_context = None
if any([isinstance(input, BlockMask) for input in model_inputs]):
raise ValueError(
"Context parallel with flex attention is not yet supported"
)
set_rotate_method("allgather")
cp_context = context_parallel(
world_mesh["cp"],
# buffers=model_inputs + model_buffers,
buffers=model_inputs,
# buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
buffer_seq_dims=[1] * len(model_inputs),
no_restore_buffers=set(model_inputs),
)
# Create and enter the train context with the optional cp_context
sdpa_context = _get_sdpa_context()
with sdpa_context(cp_context):
yield
return context
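
Review note: `_get_sdpa_context` uses `contextlib.ExitStack` to enter the backend filter and the context parallel context only when a `cp_context` is supplied. The following is a minimal, torch-free sketch of that pattern. The names `make_optional_context`, `named`, and `backend_filter` are hypothetical stand-ins for illustration, not part of this change; `named(...)` takes the place of `sdpa_kernel(...)` and `context_parallel(...)`.

```python
import contextlib


def make_optional_context(log):
    """Build a context manager that enters two contexts only when `inner` is given."""

    @contextlib.contextmanager
    def named(name):
        # Stand-in for a real context manager; records enter/exit order.
        log.append(f"enter {name}")
        try:
            yield
        finally:
            log.append(f"exit {name}")

    @contextlib.contextmanager
    def context(inner=None):
        with contextlib.ExitStack() as stack:
            if inner is not None:
                # Entered first, exited last (ExitStack unwinds in LIFO order).
                stack.enter_context(named("backend_filter"))
                stack.enter_context(inner)
            yield

    return context, named


log = []
context, named = make_optional_context(log)

with context(named("cp")):
    log.append("body")

with context():
    log.append("body without cp")
```

With no `inner` context, the body runs with nothing entered, which mirrors the `cp_context is None` branch in the helper above.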

View File

@@ -0,0 +1,216 @@
"""Module for Axolotl trainer context parallelism manager and utilities."""
import functools
import inspect
from typing import Callable, Literal
import torch
import torch.distributed as dist
from torch.utils.hooks import RemovableHandle
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
)
from axolotl.utils.ctx_managers.context_parallel.distributed import (
get_context_parallel_manager,
)
from axolotl.utils.ctx_managers.context_parallel.utils import (
AllGatherWithGrad,
apply_context_parallelism,
)
from axolotl.utils.schemas.enums import RingAttnFunc
class ContextParallelContextManager:
"""Context manager for context parallelism operations.
This class provides a context that will automatically apply context parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the context parallelism group using a post-forward hook.
Args:
models: List of models on which to register the context parallelism pre- and
post-forward hooks.
backend: Which attention backend to use.
context_parallel_degree: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused.
heads_k_stride: Context parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation.
"""
def __init__(
self,
models: list[PreTrainedModel],
backend: Literal["sdp_attention", "flash_attention"],
context_parallel_degree: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc,
heads_k_stride: int | None,
):
self.models = models
self.backend = backend
self.context_parallel_degree = context_parallel_degree
self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride
self._register_ring_attn()
# Store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
if self.backend == "flash_attention":
# Set distributed info for local rank
self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Create a partially applied version of the apply_context_parallelism function
self.apply_context_parallelism = functools.partial(
apply_context_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func,
)
# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0
else:
# SDPA device mesh init
world_size = dist.get_world_size()
mesh_shape = (
world_size // self.context_parallel_degree,
self.context_parallel_degree,
)
world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
# SDPA context parallel managers
self.context_parallel_managers = []
for model in models:
ctx_manager = get_context_parallel_manager(
world_mesh=world_mesh,
model=model,
)
self.context_parallel_managers.append(ctx_manager)
def __enter__(self):
self._register_model_hooks()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
def _register_ring_attn(self):
if self.backend == "flash_attention":
# Initialize ring attn for context parallelism
register_ring_attn(
context_parallel_degree=self.context_parallel_degree,
heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func,
)
# Patches for accelerate functionality
patch_prepare_data_loader()
patch_prepare_device_mesh(context_parallel_degree=self.context_parallel_degree)
def _register_model_hooks(self):
# Forward pre-hook to apply context parallelism
def cp_flash_pre_hook(_, args, kwargs):
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
)
updated_kwargs = kwargs.copy()
for i, arg in enumerate(args):
if i < len(forward_params):
updated_kwargs[forward_params[i]] = arg
# Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :]
# Apply context parallelism to updated kwargs
updated_kwargs, self.original_seq_len, self.pad_len = (
self.apply_context_parallelism(updated_kwargs)
)
return remaining_args, updated_kwargs
# Forward post-hook to gather outputs
def cp_flash_post_hook(_, __, output: ModelOutput) -> ModelOutput:
# Gather the sharded outputs
output = self._gather_outputs(output)
# Remove padding if it was added
if self.pad_len > 0:
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
if value.size(1) == self.original_seq_len + self.pad_len:
# Slice to remove padding
output[key] = value[:, : self.original_seq_len].contiguous()
return output
def make_sdpa_pre_hook(model_idx: int) -> Callable:
def cp_sdpa_pre_hook(_, args, kwargs):
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
)
updated_kwargs = kwargs.copy()
for i, arg in enumerate(args):
if i < len(forward_params):
updated_kwargs[forward_params[i]] = arg
# Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :]
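# Only shard tensor inputs with a sequence dimension; skip None and non-tensor kwargs
to_shard = {k: v for k, v in updated_kwargs.items() if isinstance(v, torch.Tensor) and v.ndim > 1}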
with self.context_parallel_managers[model_idx](list(to_shard.values())):
return remaining_args, updated_kwargs
return cp_sdpa_pre_hook
# Register both hooks
for i, model in enumerate(self.models):
if self.backend == "flash_attention":
self.hook_handles.append(
model.register_forward_pre_hook(cp_flash_pre_hook, with_kwargs=True)
)
self.hook_handles.append(
model.register_forward_hook(cp_flash_post_hook)
)
else:
self.hook_handles.append(
model.register_forward_pre_hook(
make_sdpa_pre_hook(i), with_kwargs=True
)
)
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
output[key] = AllGatherWithGrad.apply(value, self.process_group)
return output
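
Review note: both pre-hooks fold positional arguments into keyword arguments by reading the parameter names of the model's `forward` signature, so that later code can look inputs up by name. A minimal sketch of that step, assuming a plain function in place of a bound `forward` method; `fold_args_into_kwargs` and the toy `forward` are hypothetical names used only for illustration.

```python
import inspect


def fold_args_into_kwargs(forward, args, kwargs):
    """Map positional args onto forward's parameter names; return (leftover args, kwargs)."""
    params = list(inspect.signature(forward).parameters.keys())

    updated = dict(kwargs)
    for i, arg in enumerate(args):
        if i < len(params):
            updated[params[i]] = arg

    # Positional args beyond the named parameters are passed through untouched.
    remaining = args[len(params):]
    return remaining, updated


def forward(input_ids, attention_mask=None, labels=None):
    """Toy stand-in for a model forward function."""
    return input_ids


remaining, merged = fold_args_into_kwargs(
    forward,
    ([1, 2, 3], [1, 1, 1]),
    {"labels": [1, 2, 3]},
)
```

Note that both hooks in the diff inspect `self.models[0].forward` for every model, which assumes all registered models share the same forward signature.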

View File

@@ -1,28 +1,15 @@
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
import functools
import inspect
"""Utils for context parallel context manager."""
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
update_ring_attn_params,
)
from axolotl.monkeypatch.ring_attn.patch import update_ring_attn_params
from axolotl.utils.schemas.enums import RingAttnFunc
# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this
# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.
def apply_sequence_parallelism(
def apply_context_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
@@ -30,15 +17,15 @@ def apply_sequence_parallelism(
ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument
) -> tuple[dict[str, torch.Tensor], int, int]:
"""
Apply sequence parallelism slicing to a batch.
Apply context parallelism slicing to a batch.
Special handling is implemented for integer logits_to_keep, which indicates
to only keep the last N tokens in the sequence during generation.
to only keep the last N tokens in the input sequence during generation.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
local_rank: Local rank in the context parallel group.
local_world_size: World size of the context parallel group.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused, but
related to above TODO.
@@ -133,7 +120,7 @@ def apply_sequence_parallelism(
# Update the total sequence length after padding
total_seq_len = batch["input_ids"].size(1)
# Slice batch for sequence parallel
# Slice batch for context parallel
for key in batch:
if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:
continue
@@ -159,144 +146,6 @@ def apply_sequence_parallelism(
return batch, original_seq_len, pad_len
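
Review note: `apply_context_parallelism` returns the sliced batch together with `original_seq_len` and `pad_len`, which the post-hook later uses to strip padding from the gathered outputs. The sketch below shows that round trip on plain Python lists. The padding rule (pad to a multiple of the world size) and the helper names `shard_sequence` and `gather_and_unpad` are assumptions for illustration; the actual padding logic sits in a part of the function this hunk does not show.

```python
def shard_sequence(tokens, rank, world_size, pad_value=0):
    """Pad to a multiple of world_size, then return this rank's contiguous chunk."""
    original_len = len(tokens)
    pad_len = (-original_len) % world_size
    padded = list(tokens) + [pad_value] * pad_len

    chunk = len(padded) // world_size
    shard = padded[rank * chunk : (rank + 1) * chunk]
    return shard, original_len, pad_len


def gather_and_unpad(shards, original_len):
    """Concatenate per-rank shards in rank order and drop the trailing padding."""
    full = [token for shard in shards for token in shard]
    return full[:original_len]


tokens = list(range(7))
shard0, original_len, pad_len = shard_sequence(tokens, rank=0, world_size=2)
shard1, _, _ = shard_sequence(tokens, rank=1, world_size=2)
restored = gather_and_unpad([shard0, shard1], original_len)
```

This matches the contiguous split the unit tests expect: rank 0 receives the first half of the sequence and rank 1 the second half.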
class SequenceParallelContextManager:
"""Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook.
Args:
models: List of models to apply sequence parallelism to pre- and post- forward
hooks.
sequence_parallel_degree: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused.
heads_k_stride: Sequence parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation.
"""
def __init__(
self,
models: list[nn.Module],
sequence_parallel_degree: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc,
heads_k_stride: int | None,
):
self.models = models
self.sequence_parallel_degree = sequence_parallel_degree
self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride
self._register_ring_attn()
# Set distributed info for local rank
self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0
# Create a partially applied version of the apply_sequence_parallelism function
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func,
)
def __enter__(self):
self._register_model_hooks()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
def _register_ring_attn(self):
# Initialize ring attn for sequence parallelism
register_ring_attn(
sequence_parallel_degree=self.sequence_parallel_degree,
heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func,
)
# Patches for accelerate functionality
patch_prepare_data_loader()
patch_prepare_device_mesh(
sequence_parallel_degree=self.sequence_parallel_degree
)
def _register_model_hooks(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
)
updated_kwargs = kwargs.copy()
for i, arg in enumerate(args):
if i < len(forward_params):
updated_kwargs[forward_params[i]] = arg
# Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :]
# Apply sequence parallelism to updated kwargs
updated_kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(updated_kwargs)
)
return remaining_args, updated_kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
# Gather the sharded outputs
output = self._gather_outputs(output)
# Remove padding if it was added
if self.pad_len > 0:
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
if value.size(1) == self.original_seq_len + self.pad_len:
# Slice to remove padding
output[key] = value[:, : self.original_seq_len].contiguous()
return output
# Register both hooks
for model in self.models:
self.hook_handles.append(
model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
self.hook_handles.append(
model.register_forward_hook(sequence_parallel_post_hook)
)
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
output[key] = AllGatherWithGrad.apply(value, self.process_group)
return output
class AllGatherWithGrad(torch.autograd.Function):
"""Custom autograd function for all-gather to preserve gradients."""

View File

@@ -486,10 +486,6 @@ def get_dataset_wrapper(
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
)
import ipdb
ipdb.set_trace()
if (
isinstance(dataset, Dataset)
and "input_ids" in dataset.features

View File

@@ -262,7 +262,7 @@ class AxolotlInputConfig(
val_set_size: float | None = Field(default=0.0)
sequence_parallel_degree: int | None = None
context_parallel_degree: int | None = None
heads_k_stride: int | None = None
ring_attn_func: RingAttnFunc | None = None
@@ -1179,24 +1179,39 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_grpo_liger_sequence_parallel(cls, data):
def check_grpo_liger_context_parallel(cls, data):
if (
data.get("rl") == "grpo"
and data.get("trl", {})
and data.get("trl").get("use_liger_loss")
and data.get("sequence_parallel_degree", 1) > 1
and data.get("context_parallel_degree", 1) > 1
):
raise ValueError("GRPO + SP + Liger not currently supported")
raise ValueError("GRPO + CP + Liger not currently supported")
return data
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
self.sequence_parallel_degree = 1
elif self.sequence_parallel_degree > 1:
if not self.flash_attention:
def check_context_parallel_degree(self):
if not self.context_parallel_degree:
self.context_parallel_degree = 1
elif self.context_parallel_degree > 1:
import torch
world_size = torch.cuda.device_count()
if not world_size >= self.context_parallel_degree:
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"
f"World size ({world_size}) must be greater "
f"than or equal to CP degree ({self.context_parallel_degree})"
)
if not world_size % self.context_parallel_degree == 0:
raise ValueError(
f"CP degree ({self.context_parallel_degree}) "
f"must evenly divide world size ({world_size})"
)
if not (self.flash_attention or self.sdp_attention):
raise ValueError(
"flash_attention: true or sdp_attention: true "
"must be set with context_parallel_degree > 1"
)
if self.sample_packing and self.micro_batch_size > 1:
@@ -1205,21 +1220,22 @@ class AxolotlInputConfig(
"due to a `ring-flash-attn` requirement"
)
try:
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
raise ImportError(
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
if self.flash_attention:
try:
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
raise ImportError(
"context_parallel_degree > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn]` "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
# TODO: monkeypatch / callback to average losses correctly across SP ranks
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
# TODO: monkeypatch / callback to average losses correctly across CP ranks
# / fix gradient scaling across CP ranks. Losses, grads should be scaled
# according to the proportion of non-padding tokens per rank.
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
"Context parallelism (CP) is enabled with "
f"context_parallel_degree={self.context_parallel_degree}. "
"Please note that logged losses may differ slightly from the non-CP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
@@ -1230,7 +1246,7 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def validate_ring_attn_func(self):
if getattr(self, "sequence_parallel_degree", 1) == 1:
if getattr(self, "context_parallel_degree", 1) == 1:
return self
if self.ring_attn_func is not None:
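
Review note: the new validator requires the CP degree to be at most the world size and to divide it evenly, and the SDPA branch of the manager then reshapes the ranks into a `(dp, cp)` mesh. A torch-free sketch of how those two pieces fit together, assuming ranks are laid out row-major as in `torch.tensor(list(range(world_size))).reshape(mesh_shape)`; `cp_mesh_layout` is a hypothetical helper, not part of this change.

```python
def cp_mesh_layout(world_size: int, cp_degree: int) -> list[list[int]]:
    """Return the (dp, cp) rank grid; each row is one context parallel group."""
    if cp_degree < 1:
        raise ValueError(f"CP degree ({cp_degree}) must be positive")
    if world_size < cp_degree:
        raise ValueError(
            f"World size ({world_size}) must be greater than or equal to "
            f"CP degree ({cp_degree})"
        )
    if world_size % cp_degree != 0:
        raise ValueError(
            f"CP degree ({cp_degree}) must evenly divide world size ({world_size})"
        )

    dp_degree = world_size // cp_degree
    return [
        list(range(row * cp_degree, (row + 1) * cp_degree))
        for row in range(dp_degree)
    ]


layout = cp_mesh_layout(8, 2)
```

With 8 ranks and a CP degree of 2, this gives four data parallel rows of two ranks each, and each sequence is split across the two ranks in a row.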

View File

@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1
)
* cfg.num_epochs
* cfg.sequence_parallel_degree
* cfg.context_parallel_degree
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -479,7 +479,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
# on the agreed on value for sample_packing_eff_est
total_num_steps = int(
math.floor(
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
data_loader_len * cfg.num_epochs * cfg.context_parallel_degree
)
)
@@ -502,7 +502,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.ceil(
len(train_dataset)
* cfg.num_epochs
* cfg.sequence_parallel_degree
* cfg.context_parallel_degree
/ cfg.batch_size
)
)

View File

@@ -64,7 +64,7 @@ def fixture_base_cfg():
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2,
"sequence_parallel_degree": 1,
"context_parallel_degree": 1,
# Dtype
"fp16": False,
"bf16": False,

View File

@@ -1,4 +1,4 @@
"""E2E tests for sequence parallelism"""
"""E2E tests for context parallelism"""
from pathlib import Path
@@ -12,10 +12,10 @@ from axolotl.utils.dict import DictDefault
from ...utils import check_tensorboard
class TestSequenceParallelism:
"""Test case for training with sequence parallelism enabled"""
class TestContextParallelism:
"""Test case for training with context parallelism enabled"""
def _run_sequence_parallel_test(
def _run_context_parallel_test(
self,
temp_dir,
sample_packing=True,
@@ -24,7 +24,7 @@ class TestSequenceParallelism:
ring_attn_func=None,
threshold=2.0,
):
"""Helper method to run sequence parallel tests with different configurations"""
"""Helper method to run context parallel tests with different configurations"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -66,7 +66,7 @@ class TestSequenceParallelism:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
}
)
@@ -109,7 +109,7 @@ class TestSequenceParallelism:
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
],
)
def test_sequence_parallel_training(
def test_context_parallel_training(
self,
temp_dir,
sample_packing,
@@ -118,8 +118,8 @@ class TestSequenceParallelism:
ring_attn_func,
threshold,
):
"""Test sequence parallel training with different configurations"""
self._run_sequence_parallel_test(
"""Test context parallel training with different configurations"""
self._run_context_parallel_test(
temp_dir,
sample_packing=sample_packing,
micro_batch_size=micro_batch_size,

View File

@@ -296,7 +296,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {

View File

@@ -1,4 +1,4 @@
"""Tests for sequence parallelism functionality."""
"""Tests for context parallelism functionality."""
# pylint: disable=redefined-outer-name,unused-argument
@@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import (
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
from axolotl.utils.ctx_managers.context_parallel import apply_context_parallelism
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig
@@ -54,8 +54,8 @@ def fixture_cfg():
@pytest.fixture
def sequence_parallel_batch():
"""Create a test batch for sequence parallelism tests."""
def context_parallel_batch():
"""Create a test batch for context parallelism tests."""
batch_size = 1
seq_len = 8
@@ -110,7 +110,7 @@ class TestRingAttention:
# Call register_ring_attn with size 4
register_ring_attn(
sequence_parallel_degree=4,
context_parallel_degree=4,
heads_k_stride=1,
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
)
@@ -126,7 +126,7 @@ class TestRingAttention:
class TestConfigValidation:
"""Tests for validating sequence parallelism configurations."""
"""Tests for validating context parallelism configurations."""
@pytest.fixture(autouse=True)
def setup_mocks(self, monkeypatch):
@@ -155,24 +155,24 @@ class TestConfigValidation:
[
# Valid configuration
(
{"sequence_parallel_degree": 2, "flash_attention": True},
{"sequence_parallel_degree": 2, "flash_attention": True},
{"context_parallel_degree": 2, "flash_attention": True},
{"context_parallel_degree": 2, "flash_attention": True},
True,
None,
),
# Default sequence_parallel_degree
({}, {"sequence_parallel_degree": 1}, True, None),
# Invalid: sequence_parallel_degree > 1 without flash_attention
# Default context_parallel_degree
({}, {"context_parallel_degree": 1}, True, None),
# Invalid: context_parallel_degree > 1 without flash_attention
(
{"sequence_parallel_degree": 2, "flash_attention": False},
{"context_parallel_degree": 2, "flash_attention": False},
None,
False,
"flash_attention: true must be set",
),
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
# Invalid: context_parallel_degree > 1 with sample_packing and micro_batch_size > 1
(
{
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"micro_batch_size": 2,
@@ -185,32 +185,32 @@ class TestConfigValidation:
# Valid: Basic GRPO config
(
{
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
{
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": TRLConfig(use_liger_loss=True),
},
True,
"GRPO + SP + Liger not currently supported",
"GRPO + CP + Liger not currently supported",
),
# Invalid: GRPO config with Liger loss
(
{
"rl": "grpo",
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
None,
False,
"GRPO + SP + Liger not currently supported",
"GRPO + CP + Liger not currently supported",
),
],
ids=[
@@ -222,10 +222,10 @@ class TestConfigValidation:
"grpo_with_liger_loss",
],
)
def test_sequence_parallel_config_validation(
def test_context_parallel_config_validation(
self, base_cfg, config_updates, expected_values, should_pass, error_msg
):
"""Test various sequence parallelism configuration scenarios."""
"""Test various context parallelism configuration scenarios."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
@@ -261,7 +261,7 @@ class TestConfigValidation:
# Apply updates to base config
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"sample_packing": sample_packing,
}
@@ -281,7 +281,7 @@ class TestConfigValidation:
# Invalid configuration with invalid ring_attn_func
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"ring_attn_func": "INVALID_FUNC",
}
@@ -294,8 +294,8 @@ class TestConfigValidation:
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
class TestApplySequenceParallelism:
"""Tests for the apply_sequence_parallelism function."""
class TestApplyContextParallelism:
"""Tests for the apply_context_parallelism function."""
@pytest.fixture(autouse=True)
def mock_distributed(self, monkeypatch):
@@ -324,12 +324,12 @@ class TestApplySequenceParallelism:
)
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
def test_world_size_one(self, mock_get_ring_attn_group, context_parallel_batch):
"""Test that function returns original batch when world size is 1."""
mock_get_ring_attn_group.return_value = 0
result, _, _ = apply_sequence_parallelism(
batch=sequence_parallel_batch,
result, _, _ = apply_context_parallelism(
batch=context_parallel_batch,
local_rank=0,
local_world_size=1,
gradient_accumulation_steps=1,
@@ -337,17 +337,17 @@ class TestApplySequenceParallelism:
)
# Should return the original batch unchanged
assert result == sequence_parallel_batch
assert result == context_parallel_batch
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
def test_batch_ring_rank0(self, mock_get_ring_attn_group, context_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
batch = context_parallel_batch
seq_len = batch["input_ids"].size(1)
result, _, _ = apply_sequence_parallelism(
result, _, _ = apply_context_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
@@ -366,15 +366,15 @@ class TestApplySequenceParallelism:
)
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
def test_batch_ring_rank1(self, mock_get_ring_attn_group, context_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
batch = context_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
result, _, _ = apply_sequence_parallelism(
result, _, _ = apply_context_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
@@ -386,14 +386,14 @@ class TestApplySequenceParallelism:
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
# TODO(djsaunde): add back once implemented.
# def test_batch_zigzag(self, sequence_parallel_batch):
# def test_batch_zigzag(self, context_parallel_batch):
# """Test BATCH_ZIGZAG sharding pattern."""
# batch = sequence_parallel_batch
# batch = context_parallel_batch
# original_input_ids = batch["input_ids"].clone()
# seq_len = batch["input_ids"].size(1)
# # Test rank 0
# result_rank0 = apply_sequence_parallelism(
# result_rank0 = apply_context_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=0,
# local_world_size=2,
@@ -401,7 +401,7 @@ class TestApplySequenceParallelism:
# )
# # Test rank 1
# result_rank1 = apply_sequence_parallelism(
# result_rank1 = apply_context_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=1,
# local_world_size=2,
@@ -430,17 +430,17 @@ class TestApplySequenceParallelism:
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_partial_application(
self, mock_get_ring_attn_group, sequence_parallel_batch
self, mock_get_ring_attn_group, context_parallel_batch
):
"""Test that we can create a partially applied version of the function."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
batch = context_parallel_batch
original_input_ids = batch["input_ids"].clone()
# Create a partially applied function
rank0_ring_parallel = functools.partial(
apply_sequence_parallelism,
apply_context_parallelism,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
@@ -457,16 +457,14 @@ class TestApplySequenceParallelism:
original_input_ids[:, : original_input_ids.shape[1] // 2],
)
def test_missing_position_ids(self, sequence_parallel_batch):
def test_missing_position_ids(self, context_parallel_batch):
"""Test handling of batch without position_ids."""
# Create a batch without position_ids
batch = {
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
}
batch = {k: v for k, v in context_parallel_batch.items() if k != "position_ids"}
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
result, _, _ = apply_sequence_parallelism(
result, _, _ = apply_context_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
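
Note on the sharding behavior these tests check: for BATCH_RING, the assertions above expect rank 0 to receive the first half of each sequence and rank 1 the second half when `local_world_size=2`. The following is a minimal sketch of that contiguous split, not axolotl's implementation. `shard_batch_ring` is a hypothetical helper introduced only for illustration, and it operates on plain nested lists rather than torch tensors so it runs without dependencies.

```python
from functools import partial


def shard_batch_ring(batch, local_rank, local_world_size):
    """Return the contiguous slice of every sequence owned by `local_rank`.

    Hypothetical helper for illustration only; it is NOT axolotl's
    apply_context_parallelism and it works on nested lists, not tensors.
    """
    sharded = {}
    for key, rows in batch.items():
        seq_len = len(rows[0])
        if seq_len % local_world_size != 0:
            raise ValueError("sequence length must be divisible by world size")
        chunk = seq_len // local_world_size
        start = local_rank * chunk
        # Each rank keeps one equal-sized, contiguous chunk of the sequence.
        sharded[key] = [row[start : start + chunk] for row in rows]
    return sharded


# Toy batch with one sequence of length 8 and no position_ids key.
batch = {
    "input_ids": [[101, 102, 103, 104, 105, 106, 107, 108]],
    "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1]],
}

rank0 = shard_batch_ring(batch, local_rank=0, local_world_size=2)
rank1 = shard_batch_ring(batch, local_rank=1, local_world_size=2)

# Pre-bind rank and world size, mirroring the partial-application test.
rank0_shard = partial(shard_batch_ring, local_rank=0, local_world_size=2)
```

Because the helper iterates over whatever keys the batch contains, a batch without `position_ids` is sharded the same way, which is the property `test_missing_position_ids` exercises against the real function.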


@@ -1,4 +1,8 @@
"""Test cases for tokenizer loading."""
"""
Test cases for the tokenizer loading
"""
import unittest
import pytest
@@ -9,7 +13,9 @@ from tests.hf_offline_utils import enable_hf_offline
class TestTokenizers:
"""Test class for the load_tokenizer fn"""
"""
test class for the load_tokenizer fn
"""
@enable_hf_offline
def test_default_use_fast(self):
@@ -149,50 +155,6 @@ class TestTokenizers:
):
load_tokenizer(cfg)
def test_mistral_tokenizer_auto_detection(self):
"""Test that Mistral models are auto-detected and use MistralTokenizerWrapper"""
cfg = DictDefault(
{
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
def test_mixtral_tokenizer_auto_detection(self):
"""Test that Mixtral models are auto-detected and use MistralTokenizerWrapper"""
cfg = DictDefault(
{
"base_model": "model-hub/Mixtral-8x7B-v0.1",
"tokenizer_config": "model-hub/Mixtral-8x7B-v0.1",
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
def test_mistral_tokenizer_basic_functionality(self):
"""Test basic encode/decode functionality of MistralTokenizerWrapper"""
cfg = DictDefault(
{
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
}
)
tokenizer = load_tokenizer(cfg)
# Test basic encoding
text = "Hello, world!"
tokens = tokenizer.encode(text)
assert isinstance(tokens, list)
assert len(tokens) > 0
# Test basic decoding
decoded = tokenizer.decode(tokens)
assert isinstance(decoded, str)
# Test token properties are accessible
assert hasattr(tokenizer, "eos_token_id")
assert hasattr(tokenizer, "bos_token_id")
assert isinstance(tokenizer.eos_token_id, int)
assert isinstance(tokenizer.bos_token_id, int)
if __name__ == "__main__":
unittest.main()