Compare commits
7 Commits
mistral-su
...
sdpa-cp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cbcc795bb3 | ||
|
|
e34b6f4dfe | ||
|
|
f8f87321bd | ||
|
|
7a88de4fa8 | ||
|
|
aced809989 | ||
|
|
ae73123eae | ||
|
|
10d1e44943 |
2
.github/workflows/multi-gpu-e2e.yml
vendored
2
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
- 'setup.py'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/multi-gpu-e2e.yml'
|
||||
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
|
||||
- 'src/axolotl/core/trainers/mixins/context_parallel.py'
|
||||
- 'src/axolotl/utils/distributed.py'
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
|
||||
@@ -75,7 +75,7 @@ quartodoc:
|
||||
- title: Context Managers
|
||||
desc: Context managers for altering trainer behaviors
|
||||
contents:
|
||||
- utils.ctx_managers.sequence_parallel
|
||||
- utils.ctx_managers.context_parallel
|
||||
- title: Prompt Strategies
|
||||
desc: Prompt formatting strategies
|
||||
contents:
|
||||
@@ -274,7 +274,7 @@ website:
|
||||
- docs/unsloth.qmd
|
||||
- docs/torchao.qmd
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
- docs/context_parallelism.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
|
||||
@@ -764,13 +764,13 @@ ddp_timeout:
|
||||
ddp_bucket_cap_mb:
|
||||
ddp_broadcast_buffers:
|
||||
|
||||
# Sequence parallelism
|
||||
# Context parallelism
|
||||
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
||||
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
||||
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||
# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.
|
||||
sequence_parallel_degree:
|
||||
# See https://docs.axolotl.ai/docs/context_parallelism.html for more details.
|
||||
context_parallel_degree:
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
# Must evenly divide the number of KV heads in your model.
|
||||
heads_k_stride: 1
|
||||
|
||||
@@ -18,7 +18,7 @@ Axolotl supports several methods for multi-GPU training:
|
||||
|
||||
- DeepSpeed (recommended)
|
||||
- FSDP (Fully Sharded Data Parallel)
|
||||
- Sequence parallelism
|
||||
- Context parallelism
|
||||
- FSDP + QLoRA
|
||||
|
||||
## DeepSpeed {#sec-deepspeed}
|
||||
@@ -80,14 +80,14 @@ fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
## Sequence parallelism {#sec-sequence-parallelism}
|
||||
## Context parallelism {#sec-sequence-parallelism}
|
||||
|
||||
We support sequence parallelism (SP) via the
|
||||
We support context parallelism (SP) via the
|
||||
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
|
||||
allows one to split up sequences across GPUs, which is useful in the event that a
|
||||
single sequence causes OOM errors during model training.
|
||||
|
||||
See our [dedicated guide](sequence_parallelism.qmd) for more information.
|
||||
See our [dedicated guide](context_parallelism.qmd) for more information.
|
||||
|
||||
### FSDP + QLoRA {#sec-fsdp-qlora}
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
---
|
||||
title: Sequence Parallelism
|
||||
title: Context Parallelism
|
||||
description: Train with long sequences split across multiple GPUs.
|
||||
---
|
||||
|
||||
Sequence parallelism is a technique that splits sequences across multiple GPUs,
|
||||
Context parallelism is a technique that splits sequences across multiple GPUs,
|
||||
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
|
||||
GPU processes a different portion of the sequence, and the results are aggregated
|
||||
through a ring communication pattern.
|
||||
|
||||
## When to Use Sequence Parallelism
|
||||
## When to Use Context Parallelism
|
||||
|
||||
Use sequence parallelism when:
|
||||
Use context parallelism when:
|
||||
|
||||
- You need to train with sequence lengths that don't fit into a single GPU's memory
|
||||
- You have multiple GPUs available
|
||||
@@ -18,11 +18,11 @@ Use sequence parallelism when:
|
||||
|
||||
## Configuration
|
||||
|
||||
To enable sequence parallelism, add the following to your configuration file:
|
||||
To enable context parallelism, add the following to your configuration file:
|
||||
|
||||
```yaml
|
||||
# Set to a divisor (> 1) of the number of GPUs available
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
context_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
@@ -30,23 +30,23 @@ heads_k_stride: 1
|
||||
ring_attn_func:
|
||||
```
|
||||
|
||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
The `context_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||
- With 4 GPUs, valid values would be 2 or 4
|
||||
|
||||
## Implementation Details
|
||||
|
||||
When sequence parallelism is enabled:
|
||||
When context parallelism is enabled:
|
||||
|
||||
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
|
||||
1. Each sequence is divided into equal chunks across the GPUs in a context parallel group
|
||||
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
|
||||
3. Position IDs are adjusted to maintain proper relative positions
|
||||
4. The trainer uses special ring communication patterns for attention operations
|
||||
|
||||
## Requirements
|
||||
|
||||
To use sequence parallelism, you need:
|
||||
To use context parallelism, you need:
|
||||
|
||||
- Multiple GPUs (at least 2)
|
||||
- The `ring-flash-attn` package. Install with:
|
||||
@@ -66,7 +66,7 @@ sequence_len: 8192
|
||||
|
||||
...
|
||||
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
context_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
@@ -79,22 +79,22 @@ ring_attn_func:
|
||||
This will train the Llama 3 8B model with 8K context length, with each sequence split
|
||||
into 2 subsequences of length 4096 across 2 GPUs.
|
||||
|
||||
## Sample Packing with Sequence Parallelism
|
||||
## Sample Packing with Context Parallelism
|
||||
|
||||
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
|
||||
Context parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
|
||||
|
||||
1. Samples are first packed together
|
||||
2. The packed sequences are then divided across GPUs in the sequence parallel group
|
||||
2. The packed sequences are then divided across GPUs in the context parallel group
|
||||
3. Position IDs are automatically adjusted to maintain proper relative positions
|
||||
|
||||
## Effect on Batch Size
|
||||
|
||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||
When using context parallelism, your effective global batch size is **divided** by the `context_parallel_degree`. This happens because:
|
||||
|
||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||
- Each group of `context_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||
- The number of batches processed per step decreases
|
||||
|
||||
For example:
|
||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- With 8 GPUs and no context parallelism: 8 different batches processed per step
|
||||
- With 8 GPUs and `context_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
@@ -20,7 +20,6 @@ datasets==3.6.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.18.1
|
||||
hf_xet==1.1.2
|
||||
mistral-common[hf-hub]==1.6.0
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
|
||||
@@ -73,7 +73,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
flash_attention=False,
|
||||
sequence_parallel_degree=None,
|
||||
context_parallel_degree=None,
|
||||
deepspeed=None,
|
||||
fsdp=None,
|
||||
fsdp_config=None,
|
||||
|
||||
@@ -305,8 +305,8 @@ def load_model_and_tokenizer(
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model, tokenizer, and processor specified in the
|
||||
given `axolotl` config.
|
||||
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
||||
config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
@@ -54,7 +54,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
context_parallel=self.cfg.context_parallel_degree > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOContextParallelTrainer, AxolotlGRPOTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .relora import ReLoRATrainer
|
||||
from .trl import (
|
||||
|
||||
@@ -7,11 +7,13 @@ from __future__ import annotations
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial, wraps
|
||||
from typing import Callable, Literal, Optional
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
from axolotl.utils.ctx_managers.context_parallel.distributed import get_context_parallel_manager
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from torch import nn
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
DataLoader,
|
||||
@@ -65,6 +67,32 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
# SPDA device mesh init
|
||||
import torch.distributed as dist
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
mesh_shape = (
|
||||
world_size // 2,
|
||||
2,
|
||||
)
|
||||
self.world_mesh = dist.DeviceMesh(
|
||||
"cuda",
|
||||
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
||||
mesh_dim_names=("dp", "cp"),
|
||||
)
|
||||
|
||||
def training_step(
|
||||
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
|
||||
) -> torch.Tensor:
|
||||
ctx_manager = get_context_parallel_manager(
|
||||
world_mesh=self.world_mesh,
|
||||
model=model,
|
||||
)
|
||||
to_shard = {k: v for k, v in inputs.items() if v.ndim > 1}
|
||||
with ctx_manager(list(to_shard.values())):
|
||||
super().training_step(model, inputs, num_items_in_batch)
|
||||
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
|
||||
@@ -8,7 +8,7 @@ from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
from axolotl.core.trainers.grpo.trainer import (
|
||||
AxolotlGRPOSequenceParallelTrainer,
|
||||
AxolotlGRPOContextParallelTrainer,
|
||||
AxolotlGRPOTrainer,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -23,10 +23,10 @@ class GRPOStrategy:
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(
|
||||
cls, sequence_parallel: bool
|
||||
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
|
||||
if sequence_parallel:
|
||||
return AxolotlGRPOSequenceParallelTrainer
|
||||
cls, context_parallel: bool
|
||||
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOContextParallelTrainer]:
|
||||
if context_parallel:
|
||||
return AxolotlGRPOContextParallelTrainer
|
||||
return AxolotlGRPOTrainer
|
||||
|
||||
@classmethod
|
||||
@@ -69,8 +69,8 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
||||
if cfg.context_parallel_degree > 1:
|
||||
grpo_args_kwargs["context_parallel_degree"] = cfg.context_parallel_degree
|
||||
|
||||
if trl.reward_weights:
|
||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||
|
||||
@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""Axolotl GRPO Config for GRPO training"""
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
context_parallel_degree: int | None = None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Repeat random sampler (similar to the one implemented in
|
||||
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
|
||||
sequence parallelism functionality; i.e., duplicating data across ranks in the same
|
||||
sequence parallel group.
|
||||
context parallelism functionality; i.e., duplicating data across ranks in the same
|
||||
context parallel group.
|
||||
"""
|
||||
|
||||
from typing import Iterator, Sized
|
||||
@@ -10,26 +10,26 @@ import torch
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
|
||||
class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
"""Sampler for GRPO training with sequence parallelism.
|
||||
class ContextParallelRepeatRandomSampler(Sampler):
|
||||
"""Sampler for GRPO training with context parallelism.
|
||||
|
||||
This sampler ensures:
|
||||
- Ranks in the same sequence parallel (SP) group receive identical data.
|
||||
- Ranks in the same context parallel (SP) group receive identical data.
|
||||
- Each index is repeated multiple times for sampling different completions.
|
||||
- Entire batches are repeated for reuse in multiple updates.
|
||||
- Data is properly distributed across SP groups.
|
||||
- Data is properly distributed across CP groups.
|
||||
|
||||
In the table below, the values represent dataset indices. Each SP group has
|
||||
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
|
||||
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||
In the table below, the values represent dataset indices. Each CP group has
|
||||
`context_parallel_degree = 2` GPUs working together on the same data. There are 2
|
||||
CP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||
|
||||
Sequence Parallel Groups
|
||||
Context Parallel Groups
|
||||
| SP0 | SP1 |
|
||||
| GPU 0 | GPU 1 | GPU 2 | GPU 3 |
|
||||
global_step step <---> mini_repeat_count=3
|
||||
<----------> batch_size=2 per SP group
|
||||
grad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- SP groups get different data
|
||||
▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each SP group GPU
|
||||
<----------> batch_size=2 per CP group
|
||||
grad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- CP groups get different data
|
||||
▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each CP group GPU
|
||||
|
|
||||
| 1 2 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Repeat same indices for iterations
|
||||
num_iterations=2 ▼ 1 3 [0 0 0 1 1 1] [2 2 2 3 3 3] <- When using gradient accumulation
|
||||
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
rank: Rank of current process.
|
||||
batch_size: Number of samples per batch.
|
||||
repeat_count: How many times to repeat the full sampling process.
|
||||
sequence_parallel_degree: Number of ranks in a sequence parallel group.
|
||||
context_parallel_degree: Number of ranks in a context parallel group.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
seed: Random seed for shuffling.
|
||||
drop_last: Whether to drop the last incomplete batch.
|
||||
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
rank: int,
|
||||
batch_size: int = 1,
|
||||
repeat_count: int = 1,
|
||||
sequence_parallel_degree: int = 1,
|
||||
context_parallel_degree: int = 1,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
@@ -76,16 +76,16 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
|
||||
# Sequence parallelism parameters
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.num_sp_groups = world_size // sequence_parallel_degree
|
||||
self.sp_group_id = rank // sequence_parallel_degree
|
||||
# Context parallelism parameters
|
||||
self.context_parallel_degree = context_parallel_degree
|
||||
self.num_sp_groups = world_size // context_parallel_degree
|
||||
self.sp_group_id = rank // context_parallel_degree
|
||||
|
||||
# Adjust dataset size for distributed sampling
|
||||
self.num_samples = len(self.dataset)
|
||||
self.total_size = self.num_samples
|
||||
|
||||
# Calculate effective number of samples per SP group
|
||||
# Calculate effective number of samples per CP group
|
||||
if (
|
||||
self.drop_last
|
||||
and self.total_size % (self.num_sp_groups * self.batch_size) != 0
|
||||
@@ -125,8 +125,8 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
padding = indices[: self.batch_size - len(indices) % self.batch_size]
|
||||
indices += padding
|
||||
|
||||
# Subsample based on SP group ID
|
||||
# Each SP group gets distinct batches of data
|
||||
# Subsample based on CP group ID
|
||||
# Each CP group gets distinct batches of data
|
||||
batch_indices = []
|
||||
for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
|
||||
start_idx = i + self.sp_group_id * self.batch_size
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
|
||||
"""Axolotl GRPO trainers (with and without context parallelism handling)"""
|
||||
|
||||
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
||||
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
@@ -42,7 +41,7 @@ from trl.trainer.grpo_config import GRPOConfig
|
||||
from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
||||
from trl.trainer.utils import pad
|
||||
|
||||
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
|
||||
from axolotl.core.trainers.grpo.sampler import ContextParallelRepeatRandomSampler
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
|
||||
@@ -59,45 +58,9 @@ class AxolotlGRPOTrainer(
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
def get_train_dataloader(self):
|
||||
if self.train_dataset is None:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
|
||||
train_dataset = self.train_dataset
|
||||
data_collator = self.data_collator
|
||||
if isinstance(train_dataset, datasets.Dataset):
|
||||
train_dataset = self._remove_unused_columns(
|
||||
train_dataset, description="training"
|
||||
)
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(
|
||||
data_collator, description="training"
|
||||
)
|
||||
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size
|
||||
* self.args.steps_per_generation, # < this is the change
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
"persistent_workers": self.args.dataloader_persistent_workers,
|
||||
}
|
||||
|
||||
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_train_sampler()
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = partial(
|
||||
seed_worker,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
rank=self.args.process_index,
|
||||
)
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
||||
|
||||
|
||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||
class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for context parallelism handling"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -134,11 +97,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
|
||||
)
|
||||
|
||||
# Get number of SP groups (number of processes divided by SP degree)
|
||||
# Get number of CP groups (number of processes divided by CP degree)
|
||||
num_processes = self.accelerator.num_processes
|
||||
num_sp_groups = num_processes // self.args.sequence_parallel_degree
|
||||
num_sp_groups = num_processes // self.args.context_parallel_degree
|
||||
|
||||
# Calculate batch size per SP group (not per process)
|
||||
# Calculate batch size per CP group (not per process)
|
||||
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
||||
possible_values = [
|
||||
n_gen
|
||||
@@ -148,7 +111,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"The batch size per SP group ({num_sp_groups} x "
|
||||
f"The batch size per CP group ({num_sp_groups} x "
|
||||
f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
|
||||
f"the number of generations per prompt ({self.num_generations}). Given "
|
||||
"the current configuration, the valid values for the number of "
|
||||
@@ -156,7 +119,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
)
|
||||
|
||||
if self.args.eval_strategy != "no":
|
||||
# If sequence parallelism is enabled, calculate batch size per SP group
|
||||
# If context parallelism is enabled, calculate batch size per CP group
|
||||
sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups # type: ignore[union-attr]
|
||||
possible_values = [
|
||||
n_gen
|
||||
@@ -166,8 +129,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
|
||||
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||
f"With context parallelism (degree {self.args.context_parallel_degree}), "
|
||||
f"the eval batch size per CP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||
f"must be evenly divisible by the number of generations per prompt "
|
||||
f"({self.num_generations}). Given the current eval batch size, "
|
||||
f"the valid values for the number of generations are: {possible_values}."
|
||||
@@ -180,7 +143,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
self.local_world_size = 1
|
||||
|
||||
def train(self, *args, **kwargs):
|
||||
# Initialize the SP group
|
||||
# Initialize the CP group
|
||||
self.sp_group = get_ring_attn_group()
|
||||
self.rank = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
@@ -196,16 +159,16 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
* self.args.gradient_accumulation_steps
|
||||
)
|
||||
|
||||
return SequenceParallelRepeatRandomSampler(
|
||||
return ContextParallelRepeatRandomSampler(
|
||||
dataset=self.train_dataset,
|
||||
mini_repeat_count=self.num_generations,
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
batch_size=effective_batch_size
|
||||
// self.num_generations
|
||||
// self.args.sequence_parallel_degree,
|
||||
// self.args.context_parallel_degree,
|
||||
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
||||
sequence_parallel_degree=self.args.sequence_parallel_degree,
|
||||
context_parallel_degree=self.args.context_parallel_degree,
|
||||
shuffle=True,
|
||||
seed=self.args.seed,
|
||||
drop_last=True,
|
||||
@@ -263,11 +226,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
):
|
||||
self.accelerator.even_batches = False
|
||||
|
||||
# Return unprepared dataloader if using sequence parallelism
|
||||
# Return unprepared dataloader if using context parallelism
|
||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||
# slice each batch along the sequence dimension).
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_degree > 1:
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
@@ -340,21 +303,21 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
if self.accelerator.is_main_process:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate sequence parallel group information
|
||||
if self.args.context_parallel_degree > 1:
|
||||
# Calculate context parallel group information
|
||||
world_size = self.accelerator.num_processes
|
||||
sequence_parallel_degree = self.args.sequence_parallel_degree
|
||||
num_sp_groups = world_size // sequence_parallel_degree
|
||||
context_parallel_degree = self.args.context_parallel_degree
|
||||
num_sp_groups = world_size // context_parallel_degree
|
||||
|
||||
# Since processes in the same SP group have the same prompts, we need to ensure
|
||||
# we only take one copy of each prompt from each SP group
|
||||
# Since processes in the same CP group have the same prompts, we need to ensure
|
||||
# we only take one copy of each prompt from each CP group
|
||||
ordered_set_of_prompts = []
|
||||
for sp_group_id in range(num_sp_groups):
|
||||
# Get the first process from each SP group (typically the group leader)
|
||||
group_leader_rank = sp_group_id * sequence_parallel_degree
|
||||
# Get the first process from each CP group (typically the group leader)
|
||||
group_leader_rank = sp_group_id * context_parallel_degree
|
||||
|
||||
# Extract prompts from this SP group, accounting for num_generations duplicates
|
||||
# We only need prompts from one rank in each SP group
|
||||
# Extract prompts from this CP group, accounting for num_generations duplicates
|
||||
# We only need prompts from one rank in each CP group
|
||||
group_prompts = all_prompts_text[
|
||||
group_leader_rank
|
||||
* len(prompts_text) : (group_leader_rank + 1)
|
||||
@@ -367,7 +330,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||
# prompt individually.
|
||||
ordered_set_of_prompts = all_prompts_text[
|
||||
:: self.num_generations * self.args.sequence_parallel_degree
|
||||
:: self.num_generations * self.args.context_parallel_degree
|
||||
]
|
||||
|
||||
with profiling_context(self, "vLLM.generate"):
|
||||
@@ -384,28 +347,28 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
)
|
||||
else:
|
||||
completion_ids = [None] * (
|
||||
len(all_prompts_text) // self.args.sequence_parallel_degree
|
||||
len(all_prompts_text) // self.args.context_parallel_degree
|
||||
)
|
||||
|
||||
# Broadcast the completions from the main process to all processes
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
|
||||
# Determine the appropriate slice based on sequence parallelism
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
# Determine the appropriate slice based on context parallelism
|
||||
if self.args.context_parallel_degree > 1:
|
||||
# Calculate CP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
# Calculate the start index for this SP group
|
||||
# Calculate the start index for this CP group
|
||||
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
|
||||
|
||||
# All ranks in the same SP group get the same data slice
|
||||
# All ranks in the same CP group get the same data slice
|
||||
process_slice = slice(
|
||||
sp_group_start,
|
||||
sp_group_start + len(prompts),
|
||||
)
|
||||
completion_ids = completion_ids[process_slice]
|
||||
else:
|
||||
# Original behavior for non-sequence parallel case
|
||||
# Original behavior for non-context parallel case
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
@@ -615,20 +578,20 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
advantages = advantages / (std_grouped_rewards + 1e-4)
|
||||
|
||||
# Slice to keep only the local part of the data
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
if self.args.context_parallel_degree > 1:
|
||||
# Calculate CP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
# Calculate the start index for this SP group
|
||||
# Calculate the start index for this CP group
|
||||
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
|
||||
|
||||
# All ranks in the same SP group get the same data slice
|
||||
# All ranks in the same CP group get the same data slice
|
||||
process_slice = slice(
|
||||
sp_group_start,
|
||||
sp_group_start + len(prompts),
|
||||
)
|
||||
else:
|
||||
# Original behavior for non-sequence parallel case
|
||||
# Original behavior for non-context parallel case
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
|
||||
@@ -64,10 +64,6 @@ class TokenizedPromptDataset(Dataset):
|
||||
desc="Strategy Filtering Rows",
|
||||
)
|
||||
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
|
||||
@@ -2,16 +2,8 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from huggingface_hub import hf_hub_download
|
||||
from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.mistral import (
|
||||
MistralTokenizer,
|
||||
)
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AutoTokenizer,
|
||||
@@ -31,622 +23,239 @@ from axolotl.utils.logging import get_logger
|
||||
LOG = get_logger(__name__)
|
||||
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||
|
||||
# Constants
|
||||
LLAMA_TOKENIZER_CLASSES = {
|
||||
"LlamaTokenizer",
|
||||
"LlamaTokenizerFast",
|
||||
"CodeLlamaTokenizer",
|
||||
"CodeLlamaTokenizerFast",
|
||||
}
|
||||
|
||||
FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
|
||||
|
||||
QWEN_DEFAULT_TOKEN = "<|endoftext|>"
|
||||
GPTNEOX_PAD_TOKEN = "[PAD]"
|
||||
CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
|
||||
|
||||
|
||||
class MistralTokenizerWrapper:
|
||||
def modify_tokenizer_files(
|
||||
tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
|
||||
) -> str:
|
||||
"""
|
||||
Wrapper to make MistralTokenizer compatible with Hugging Face tokenizer interface.
|
||||
This provides a bridge between Mistral's native tokenizer and axolotl's expectations.
|
||||
Modify tokenizer files to replace added_tokens strings, save to output directory,
|
||||
and return the path to the modified tokenizer.
|
||||
|
||||
This only works with reserved tokens that were added to the tokenizer, not tokens
|
||||
already part of the vocab.
|
||||
|
||||
Args:
|
||||
tokenizer_path: Path or name of the original tokenizer
|
||||
token_mappings: Dict mapping {token_id (int): new_token_string}
|
||||
output_dir: Directory to save the modified tokenizer
|
||||
|
||||
Returns:
|
||||
Path to the modified tokenizer directory
|
||||
|
||||
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
|
||||
"""
|
||||
# Create the tokenizer directory in output_dir if it doesn't exist
|
||||
tokenizer_dir = os.path.join(output_dir, "tokenizer")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
|
||||
def __init__(self, mistral_tokenizer: "MistralTokenizer", model_id: str):
|
||||
self.mistral_tokenizer = mistral_tokenizer
|
||||
self.model_id = model_id
|
||||
self._system_prompt = None
|
||||
self.padding_side = "right" # Default padding side
|
||||
self.chat_template = None
|
||||
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
|
||||
# Load the tokenizer
|
||||
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
|
||||
|
||||
# Cache token IDs by inspecting the actual tokenizer
|
||||
self._token_ids = self._discover_token_ids()
|
||||
# Save the tokenizer to the output directory
|
||||
temp_tokenizer.save_pretrained(tokenizer_dir)
|
||||
|
||||
# Try to load system prompt if available
|
||||
try:
|
||||
self._system_prompt = self._load_system_prompt(
|
||||
model_id, "SYSTEM_PROMPT.txt"
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.debug(f"Could not load system prompt: {e}")
|
||||
|
||||
def _discover_token_ids(self) -> Dict[str, int]:
|
||||
"""Discover the actual token IDs used by this Mistral tokenizer."""
|
||||
token_ids = {}
|
||||
|
||||
try:
|
||||
if hasattr(self.mistral_tokenizer, "instruct_tokenizer"):
|
||||
instruct_tokenizer = self.mistral_tokenizer.instruct_tokenizer
|
||||
|
||||
# Get BOS token ID from instruct_tokenizer
|
||||
token_ids["bos_token_id"] = getattr(instruct_tokenizer, "BOS", 1)
|
||||
|
||||
# Get token IDs from the underlying Tekkenizer
|
||||
if hasattr(instruct_tokenizer, "tokenizer"):
|
||||
tekkenizer = instruct_tokenizer.tokenizer
|
||||
|
||||
# Get BOS ID from tekkenizer (should match instruct_tokenizer.BOS)
|
||||
if hasattr(tekkenizer, "bos_id"):
|
||||
token_ids["bos_token_id"] = tekkenizer.bos_id
|
||||
|
||||
# Get vocab size to help find EOS token
|
||||
vocab_size = getattr(tekkenizer, "_vocab_size", None)
|
||||
|
||||
# Check special tokens
|
||||
if hasattr(tekkenizer, "_all_special_tokens"):
|
||||
special_tokens = tekkenizer._all_special_tokens
|
||||
keys = (
|
||||
list(special_tokens.keys())
|
||||
if hasattr(special_tokens, "keys")
|
||||
else special_tokens
|
||||
)
|
||||
LOG.debug(f"Special tokens available: {keys}")
|
||||
|
||||
# Try to find EOS token in special tokens
|
||||
if hasattr(special_tokens, "get"):
|
||||
# Common EOS token patterns
|
||||
for eos_pattern in ["</s>", "<|endoftext|>", "eos", "EOS"]:
|
||||
if eos_pattern in special_tokens:
|
||||
token_ids["eos_token_id"] = special_tokens[
|
||||
eos_pattern
|
||||
]
|
||||
break
|
||||
|
||||
# Check special tokens reverse vocab
|
||||
if hasattr(tekkenizer, "_special_tokens_reverse_vocab"):
|
||||
reverse_vocab = tekkenizer._special_tokens_reverse_vocab
|
||||
LOG.debug(f"Reverse special tokens: {reverse_vocab}")
|
||||
|
||||
# Look for common special token IDs
|
||||
for token_id, token_str in reverse_vocab.items():
|
||||
if token_str in ["</s>", "<|endoftext|>"]:
|
||||
token_ids["eos_token_id"] = token_id
|
||||
elif token_str in ["<unk>", "<UNK>"]:
|
||||
token_ids["unk_token_id"] = token_id
|
||||
|
||||
# If we have vocab_size, EOS is often vocab_size - 1 or similar
|
||||
if "eos_token_id" not in token_ids and vocab_size:
|
||||
# Common patterns: EOS could be 2, vocab_size-1, or other values
|
||||
# Let's try a safer approach by checking what tokens decode to
|
||||
for candidate_id in [2, vocab_size - 1, vocab_size - 2]:
|
||||
try:
|
||||
# Try to decode and see if it looks like EOS
|
||||
decoded = tekkenizer.decode([candidate_id])
|
||||
if decoded in ["</s>", "<|endoftext|>", ""]:
|
||||
token_ids["eos_token_id"] = candidate_id
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
LOG.debug(f"Could not discover token IDs: {e}")
|
||||
|
||||
# Set reasonable defaults for any missing token IDs
|
||||
token_ids.setdefault("bos_token_id", 1)
|
||||
token_ids.setdefault("eos_token_id", 2)
|
||||
token_ids.setdefault("unk_token_id", 0)
|
||||
token_ids.setdefault(
|
||||
"pad_token_id", token_ids["eos_token_id"]
|
||||
) # Use EOS as pad
|
||||
|
||||
LOG.info(f"Discovered Mistral token IDs: {token_ids}")
|
||||
return token_ids
|
||||
|
||||
def _load_system_prompt(self, repo_id: str, filename: str) -> str:
|
||||
"""Load system prompt from HuggingFace Hub"""
|
||||
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
with open(file_path, "r") as file:
|
||||
return file.read()
|
||||
|
||||
def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]:
|
||||
"""Encode text to token IDs"""
|
||||
if isinstance(text, str):
|
||||
# For simple string encoding, create a user message
|
||||
messages = []
|
||||
if self._system_prompt and add_special_tokens:
|
||||
messages.append(SystemMessage(content=self._system_prompt))
|
||||
messages.append(UserMessage(content=text))
|
||||
|
||||
tokenized = self.mistral_tokenizer.encode_chat_completion(
|
||||
ChatCompletionRequest(messages=messages)
|
||||
)
|
||||
return tokenized.tokens
|
||||
else:
|
||||
raise ValueError("MistralTokenizer wrapper only supports string input")
|
||||
|
||||
def decode(
|
||||
self,
|
||||
token_ids: Union[List[int], torch.Tensor],
|
||||
skip_special_tokens: bool = True,
|
||||
) -> str:
|
||||
"""Decode token IDs to text"""
|
||||
if isinstance(token_ids, torch.Tensor):
|
||||
token_ids = token_ids.tolist()
|
||||
return self.mistral_tokenizer.decode(token_ids)
|
||||
|
||||
def __call__(self, text: str, **kwargs):
|
||||
"""Make the tokenizer callable like HF tokenizers"""
|
||||
tokens = self.encode(text, **kwargs)
|
||||
return {"input_ids": torch.tensor([tokens])}
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
return self._token_ids["eos_token_id"]
|
||||
|
||||
@property
|
||||
def bos_token_id(self):
|
||||
return self._token_ids["bos_token_id"]
|
||||
|
||||
@property
|
||||
def pad_token_id(self):
|
||||
return self._token_ids["pad_token_id"]
|
||||
|
||||
@property
|
||||
def unk_token_id(self):
|
||||
return self._token_ids["unk_token_id"]
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
return "</s>" # Standard Mistral EOS token
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
return "<s>" # Standard Mistral BOS token
|
||||
|
||||
@property
|
||||
def pad_token(self):
|
||||
return self.eos_token # Use EOS as pad token
|
||||
|
||||
@property
|
||||
def unk_token(self):
|
||||
return "<unk>" # Standard UNK token
|
||||
|
||||
@property
|
||||
def __class__(self):
|
||||
# Create a mock class for compatibility checks
|
||||
class MistralTokenizerWrapperClass:
|
||||
__name__ = "MistralTokenizerWrapper"
|
||||
|
||||
return MistralTokenizerWrapperClass
|
||||
|
||||
def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int:
|
||||
"""Placeholder for special token addition - Mistral tokenizer handles this internally"""
|
||||
LOG.warning(
|
||||
"add_special_tokens called on MistralTokenizer wrapper - this is handled internally"
|
||||
)
|
||||
return 0
|
||||
|
||||
def add_tokens(self, tokens) -> int:
|
||||
"""Placeholder for token addition - Mistral tokenizer handles this internally"""
|
||||
LOG.warning(
|
||||
"add_tokens called on MistralTokenizer wrapper - this is handled internally"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
class TokenizerFileModifier:
|
||||
"""Handles modification of tokenizer files for token overrides."""
|
||||
|
||||
def __init__(
|
||||
self, tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
|
||||
):
|
||||
self.tokenizer_path = tokenizer_path
|
||||
self.token_mappings = token_mappings
|
||||
self.output_dir = output_dir
|
||||
self.tokenizer_dir = os.path.join(output_dir, "tokenizer")
|
||||
|
||||
def modify_and_save(self) -> str:
|
||||
"""Modify tokenizer files and return path to modified tokenizer."""
|
||||
os.makedirs(self.tokenizer_dir, exist_ok=True)
|
||||
|
||||
if is_local_main_process():
|
||||
self._perform_modifications()
|
||||
barrier()
|
||||
|
||||
return self.tokenizer_dir
|
||||
|
||||
def _perform_modifications(self):
|
||||
"""Perform the actual file modifications."""
|
||||
# Load and save tokenizer to output directory
|
||||
temp_tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.tokenizer_path, use_fast=True
|
||||
)
|
||||
temp_tokenizer.save_pretrained(self.tokenizer_dir)
|
||||
|
||||
# Convert token mappings to proper format
|
||||
# Get the token IDs and map them to their new values
|
||||
token_id_mappings = {
|
||||
int(token_id): new_value
|
||||
for token_id, new_value in self.token_mappings.items()
|
||||
int(token_id): new_value for token_id, new_value in token_mappings.items()
|
||||
}
|
||||
|
||||
# Update both tokenizer files
|
||||
self._update_tokenizer_config(token_id_mappings)
|
||||
self._update_tokenizer_json(token_id_mappings)
|
||||
# 1. Update tokenizer_config.json - added_tokens_decoder
|
||||
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
def _update_tokenizer_config(self, token_id_mappings: Dict[int, str]):
|
||||
"""Update tokenizer_config.json with new token mappings."""
|
||||
config_path = os.path.join(self.tokenizer_dir, "tokenizer_config.json")
|
||||
if not os.path.exists(config_path):
|
||||
return
|
||||
# Update added_tokens_decoder
|
||||
if "added_tokens_decoder" in config_data:
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
token_id_str = str(token_id)
|
||||
if token_id_str in config_data["added_tokens_decoder"]:
|
||||
config_data["added_tokens_decoder"][token_id_str][
|
||||
"content"
|
||||
] = new_value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Token ID {token_id_str} not found in added_tokens_decoder"
|
||||
)
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
# Write the updated config back
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
|
||||
if "added_tokens_decoder" in config_data:
|
||||
self._update_added_tokens_decoder(config_data, token_id_mappings)
|
||||
# 2. Update tokenizer.json - added_tokens
|
||||
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
|
||||
if os.path.exists(tokenizer_path):
|
||||
with open(tokenizer_path, "r", encoding="utf-8") as f:
|
||||
tokenizer_data = json.load(f)
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
# Update added_tokens
|
||||
if "added_tokens" in tokenizer_data:
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
|
||||
if token_entry["id"] == token_id:
|
||||
tokenizer_data["added_tokens"][i]["content"] = new_value
|
||||
break
|
||||
else:
|
||||
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
|
||||
raise ValueError(
|
||||
f"Token ID {token_id} not found in added_tokens"
|
||||
)
|
||||
if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]:
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
for entry_val, entry_id in tokenizer_data["model"]["vocab"].items():
|
||||
if entry_id == token_id:
|
||||
del tokenizer_data["model"]["vocab"][entry_val]
|
||||
tokenizer_data["model"]["vocab"][new_value] = token_id
|
||||
break
|
||||
|
||||
def _update_added_tokens_decoder(
|
||||
self, config_data: Dict, token_id_mappings: Dict[int, str]
|
||||
):
|
||||
"""Update the added_tokens_decoder section."""
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
token_id_str = str(token_id)
|
||||
if token_id_str in config_data["added_tokens_decoder"]:
|
||||
config_data["added_tokens_decoder"][token_id_str]["content"] = new_value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Token ID {token_id_str} not found in added_tokens_decoder"
|
||||
)
|
||||
# Write the updated tokenizer data back
|
||||
with open(tokenizer_path, "w", encoding="utf-8") as f:
|
||||
json.dump(tokenizer_data, f, indent=2)
|
||||
|
||||
def _update_tokenizer_json(self, token_id_mappings: Dict[int, str]):
|
||||
"""Update tokenizer.json with new token mappings."""
|
||||
tokenizer_json_path = os.path.join(self.tokenizer_dir, "tokenizer.json")
|
||||
if not os.path.exists(tokenizer_json_path):
|
||||
return
|
||||
|
||||
with open(tokenizer_json_path, "r", encoding="utf-8") as f:
|
||||
tokenizer_data = json.load(f)
|
||||
|
||||
self._update_added_tokens_list(tokenizer_data, token_id_mappings)
|
||||
self._update_vocab_mappings(tokenizer_data, token_id_mappings)
|
||||
|
||||
with open(tokenizer_json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(tokenizer_data, f, indent=2)
|
||||
|
||||
def _update_added_tokens_list(
|
||||
self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
|
||||
):
|
||||
"""Update the added_tokens list in tokenizer.json."""
|
||||
if "added_tokens" not in tokenizer_data:
|
||||
return
|
||||
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
|
||||
if token_entry["id"] == token_id:
|
||||
tokenizer_data["added_tokens"][i]["content"] = new_value
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Token ID {token_id} not found in added_tokens")
|
||||
|
||||
def _update_vocab_mappings(
|
||||
self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
|
||||
):
|
||||
"""Update vocab mappings in tokenizer.json."""
|
||||
if not (tokenizer_data.get("model") and tokenizer_data["model"].get("vocab")):
|
||||
return
|
||||
|
||||
vocab = tokenizer_data["model"]["vocab"]
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
# Find and update the vocab entry
|
||||
for entry_val, entry_id in list(vocab.items()):
|
||||
if entry_id == token_id:
|
||||
del vocab[entry_val]
|
||||
vocab[new_value] = token_id
|
||||
break
|
||||
|
||||
|
||||
class TokenizerConfiguration:
|
||||
"""Handles tokenizer configuration and initialization."""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.model_config = load_model_config(cfg)
|
||||
|
||||
def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
|
||||
"""Load Mistral tokenizer from model configuration."""
|
||||
# Instantiate Mistral tokenizer
|
||||
model_id = self.cfg.base_model
|
||||
mistral_tokenizer = MistralTokenizer.from_hf_hub(model_id)
|
||||
|
||||
# Wrap it for compatibility
|
||||
tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
|
||||
LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
|
||||
|
||||
return tokenizer
|
||||
|
||||
def get_tokenizer_class(self):
|
||||
"""Get the appropriate tokenizer class."""
|
||||
if self.cfg.tokenizer_type:
|
||||
return getattr(transformers, self.cfg.tokenizer_type)
|
||||
return AutoTokenizer
|
||||
|
||||
def get_tokenizer_kwargs(self) -> Dict[str, Any]:
|
||||
"""Build tokenizer initialization kwargs."""
|
||||
kwargs = {}
|
||||
if self.cfg.tokenizer_legacy is not None:
|
||||
kwargs["legacy"] = self.cfg.tokenizer_legacy
|
||||
return kwargs
|
||||
|
||||
def get_tokenizer_path(self) -> str:
|
||||
"""Get the tokenizer path, applying overrides if needed."""
|
||||
tokenizer_path = self.cfg.tokenizer_config
|
||||
|
||||
if self.cfg.added_tokens_overrides:
|
||||
modifier = TokenizerFileModifier(
|
||||
tokenizer_path, self.cfg.added_tokens_overrides, self.cfg.output_dir
|
||||
)
|
||||
tokenizer_path = modifier.modify_and_save()
|
||||
|
||||
return tokenizer_path
|
||||
|
||||
def should_use_fast_tokenizer(self) -> bool:
|
||||
"""Determine if fast tokenizer should be used."""
|
||||
return (
|
||||
self.cfg.tokenizer_use_fast
|
||||
if self.cfg.tokenizer_use_fast is not None
|
||||
else True
|
||||
)
|
||||
|
||||
|
||||
class TokenizerPostProcessor:
|
||||
"""Handles post-processing configuration of loaded tokenizers."""
|
||||
|
||||
def __init__(self, tokenizer, cfg):
|
||||
self.tokenizer = tokenizer
|
||||
self.cfg = cfg
|
||||
self.model_config = load_model_config(cfg)
|
||||
|
||||
def apply_all_configurations(self):
|
||||
"""Apply all post-processing configurations to the tokenizer."""
|
||||
# Skip most configurations for Mistral wrapper
|
||||
if isinstance(self.tokenizer, MistralTokenizerWrapper):
|
||||
self._configure_mistral_wrapper()
|
||||
return
|
||||
|
||||
self._configure_padding_token()
|
||||
self._configure_gptneox_settings()
|
||||
self._configure_mistral_padding()
|
||||
self._configure_qwen_tokens()
|
||||
self._add_special_tokens()
|
||||
self._add_regular_tokens()
|
||||
self._configure_chat_template()
|
||||
|
||||
def _configure_mistral_wrapper(self):
|
||||
"""Apply limited configurations for Mistral wrapper."""
|
||||
# Set padding side if needed
|
||||
if (
|
||||
self.cfg.is_mistral_derived_model
|
||||
and self.cfg.flash_attention
|
||||
and not self.cfg.sample_packing
|
||||
):
|
||||
self.tokenizer.padding_side = "left"
|
||||
|
||||
# Configure chat template for Mistral
|
||||
self._configure_chat_template()
|
||||
|
||||
def _configure_padding_token(self):
|
||||
"""Configure padding token for Llama-based tokenizers."""
|
||||
if (
|
||||
self.tokenizer.__class__.__name__ in LLAMA_TOKENIZER_CLASSES
|
||||
and hasattr(self.tokenizer, "pad_token")
|
||||
and not self.tokenizer.pad_token
|
||||
):
|
||||
self.tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||
|
||||
def _configure_gptneox_settings(self):
|
||||
"""Configure GPTNeoX-specific settings."""
|
||||
if self.tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
self.tokenizer.add_special_tokens({"pad_token": GPTNEOX_PAD_TOKEN})
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
def _configure_mistral_padding(self):
|
||||
"""Configure left padding for Mistral models with Flash Attention."""
|
||||
if (
|
||||
self.cfg.is_mistral_derived_model
|
||||
and self.cfg.flash_attention
|
||||
and not self.cfg.sample_packing
|
||||
):
|
||||
self.tokenizer.padding_side = "left"
|
||||
|
||||
def _configure_qwen_tokens(self):
|
||||
"""Configure special tokens for Qwen models."""
|
||||
if not self.cfg.is_qwen_derived_model:
|
||||
return
|
||||
|
||||
# Set token IDs
|
||||
token_id_attributes = [
|
||||
"bos_token_id",
|
||||
"eos_token_id",
|
||||
"pad_token_id",
|
||||
"unk_token_id",
|
||||
]
|
||||
for attr_name in token_id_attributes:
|
||||
if getattr(self.tokenizer, attr_name) is None:
|
||||
setattr(self.tokenizer, attr_name, self.tokenizer.eod_id)
|
||||
|
||||
# Set token strings
|
||||
token_name_attributes = ["bos_token", "eos_token", "pad_token", "unk_token"]
|
||||
for attr_name in token_name_attributes:
|
||||
if getattr(self.tokenizer, attr_name) is None:
|
||||
setattr(self.tokenizer, attr_name, QWEN_DEFAULT_TOKEN)
|
||||
|
||||
def _add_special_tokens(self):
|
||||
"""Add special tokens from configuration."""
|
||||
if not self.cfg.special_tokens:
|
||||
return
|
||||
|
||||
special_tokens_dict = self.cfg.special_tokens.to_dict()
|
||||
additional_special_tokens = special_tokens_dict.pop(
|
||||
"additional_special_tokens", None
|
||||
)
|
||||
|
||||
self._validate_and_add_special_tokens(special_tokens_dict)
|
||||
self._update_post_processor_if_needed(special_tokens_dict)
|
||||
self._add_additional_special_tokens_if_present(additional_special_tokens)
|
||||
|
||||
def _validate_and_add_special_tokens(self, special_tokens: Dict[str, str]):
|
||||
"""Validate special tokens for adapter training and add them."""
|
||||
lora_modules_to_save = get_linear_embedding_layers(self.model_config.model_type)
|
||||
|
||||
for key, value in special_tokens.items():
|
||||
self._validate_token_for_adapter(key, value, lora_modules_to_save)
|
||||
self.tokenizer.add_special_tokens(
|
||||
{key: AddedToken(value, rstrip=False, lstrip=False, normalized=False)}
|
||||
)
|
||||
|
||||
def _validate_token_for_adapter(
|
||||
self, key: str, value: str, lora_modules_to_save: List[str]
|
||||
):
|
||||
"""Validate a single token for adapter training requirements."""
|
||||
if not self._should_validate_token_for_adapter(
|
||||
key, value, lora_modules_to_save
|
||||
):
|
||||
return
|
||||
|
||||
modules_str = ", ".join(f"`{x}`" for x in lora_modules_to_save)
|
||||
raise ValueError(
|
||||
f"Please set lora_modules_to_save to [{modules_str}] "
|
||||
f"when using an adapter and changing the special tokens."
|
||||
)
|
||||
|
||||
def _should_validate_token_for_adapter(
|
||||
self, key: str, value: str, lora_modules_to_save: List[str]
|
||||
) -> bool:
|
||||
"""Check if token should be validated for adapter configuration."""
|
||||
if key == "pad_token" or not self.cfg.adapter:
|
||||
return False
|
||||
|
||||
current_token = getattr(self.tokenizer, key)
|
||||
token_changed = current_token is None or current_token != value
|
||||
token_is_multi_char = (
|
||||
len(self.tokenizer.encode(value, add_special_tokens=False)) > 2
|
||||
)
|
||||
lora_modules_missing = not self.cfg.lora_modules_to_save or not all(
|
||||
x in self.cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||
)
|
||||
|
||||
return token_changed and token_is_multi_char and lora_modules_missing
|
||||
|
||||
def _update_post_processor_if_needed(self, special_tokens: Dict[str, str]):
|
||||
"""Update post processor for Llama tokenizers when BOS/EOS tokens are added."""
|
||||
has_bos_and_eos = (
|
||||
"bos_token" in special_tokens and "eos_token" in special_tokens
|
||||
)
|
||||
is_fast_llama = (
|
||||
self.tokenizer.__class__.__name__ in FAST_LLAMA_TOKENIZER_CLASSES
|
||||
)
|
||||
|
||||
if is_fast_llama and has_bos_and_eos:
|
||||
self.tokenizer.update_post_processor()
|
||||
|
||||
def _add_additional_special_tokens_if_present(
|
||||
self, additional_special_tokens: Optional[List[str]]
|
||||
):
|
||||
"""Add additional special tokens if they exist."""
|
||||
if additional_special_tokens is not None:
|
||||
self.tokenizer.add_special_tokens(
|
||||
{"additional_special_tokens": additional_special_tokens}
|
||||
)
|
||||
|
||||
def _add_regular_tokens(self):
|
||||
"""Add regular (non-special) tokens from configuration."""
|
||||
if self.cfg.tokens:
|
||||
self.tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
|
||||
for token in self.cfg.tokens
|
||||
]
|
||||
)
|
||||
|
||||
def _configure_chat_template(self):
|
||||
"""Configure chat template if specified."""
|
||||
if not self.cfg.chat_template:
|
||||
LOG.info(
|
||||
"No Chat template selected. Consider adding a chat template for easier inference."
|
||||
)
|
||||
return
|
||||
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=self.cfg,
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
|
||||
if self._should_replace_default_system_message():
|
||||
chat_template_string = chat_template_string.replace(
|
||||
CHATML_DEFAULT_SYSTEM_MESSAGE, self.cfg.default_system_message
|
||||
)
|
||||
|
||||
self.tokenizer.chat_template = chat_template_string
|
||||
|
||||
def _should_replace_default_system_message(self) -> bool:
|
||||
"""Check if default system message should be replaced."""
|
||||
return self.cfg.default_system_message and self.cfg.chat_template == "chatml"
|
||||
barrier()
|
||||
return tokenizer_dir
|
||||
|
||||
|
||||
def load_tokenizer(cfg):
|
||||
"""Load and configure the tokenizer based on the provided config.
|
||||
"""Load and configure the tokenizer based on the provided config."""
|
||||
model_config = load_model_config(cfg)
|
||||
tokenizer_kwargs = {}
|
||||
use_fast = True # this is the default
|
||||
|
||||
This function handles the complete tokenizer loading pipeline:
|
||||
- Check if Mistral tokenizer should be used
|
||||
- Configure tokenizer parameters and get the appropriate class
|
||||
- Handle token file modifications if needed
|
||||
- Initialize the tokenizer with the correct parameters
|
||||
- Apply all post-processing configurations (padding, special tokens, etc.)
|
||||
- Set up chat templates and logging
|
||||
if cfg.tokenizer_use_fast is not None:
|
||||
use_fast = cfg.tokenizer_use_fast
|
||||
if cfg.tokenizer_legacy is not None:
|
||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
tokenizer_cls = AutoTokenizer
|
||||
if cfg.tokenizer_type:
|
||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||
|
||||
Returns:
|
||||
Fully configured tokenizer instance.
|
||||
"""
|
||||
# Configure tokenizer parameters
|
||||
config = TokenizerConfiguration(cfg)
|
||||
# Set base tokenizer path
|
||||
tokenizer_path = cfg.tokenizer_config
|
||||
|
||||
# Check if we should use Mistral tokenizer
|
||||
try:
|
||||
tokenizer = config.load_mistral_tokenizer()
|
||||
except:
|
||||
# Standard tokenizer loading
|
||||
tokenizer_cls = config.get_tokenizer_class()
|
||||
tokenizer_path = config.get_tokenizer_path()
|
||||
use_fast = config.should_use_fast_tokenizer()
|
||||
tokenizer_kwargs = config.get_tokenizer_kwargs()
|
||||
|
||||
# Initialize the tokenizer
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
tokenizer_path,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
# Apply token string overrides if specified
|
||||
if cfg.added_tokens_overrides:
|
||||
# Modify tokenizer files and get path to modified tokenizer
|
||||
tokenizer_path = modify_tokenizer_files(
|
||||
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
|
||||
)
|
||||
|
||||
# Apply all post-processing configurations
|
||||
post_processor = TokenizerPostProcessor(tokenizer, cfg)
|
||||
post_processor.apply_all_configurations()
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
tokenizer_path,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
if (
|
||||
tokenizer.__class__.__name__
|
||||
in [
|
||||
"LlamaTokenizer",
|
||||
"LlamaTokenizerFast",
|
||||
"CodeLlamaTokenizer",
|
||||
"CodeLlamaTokenizerFast",
|
||||
]
|
||||
and hasattr(tokenizer, "pad_token")
|
||||
and not tokenizer.pad_token
|
||||
):
|
||||
# set a pad_token, but use eos_token so we don't add a new token
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Mistral's official FA implementation requires left padding
|
||||
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# Qwen base only has single token, so we need to set the special tokens
|
||||
if cfg.is_qwen_derived_model:
|
||||
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
||||
for attr_name in token_ids:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, tokenizer.eod_id)
|
||||
|
||||
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
|
||||
for attr_name in token_names:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||
|
||||
additional_special_tokens = None
|
||||
if cfg.special_tokens:
|
||||
special_tokens = cfg.special_tokens.to_dict()
|
||||
additional_special_tokens = special_tokens.pop(
|
||||
"additional_special_tokens", None
|
||||
)
|
||||
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
|
||||
for k, val in special_tokens.items():
|
||||
# check if new special token is not already in tokenizer and
|
||||
# is adapter training to make sure lora_modules_to_save is set
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (
|
||||
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
||||
and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
|
||||
and cfg.adapter
|
||||
and (
|
||||
not cfg.lora_modules_to_save
|
||||
or not all(
|
||||
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||
)
|
||||
)
|
||||
and k != "pad_token"
|
||||
):
|
||||
lora_modules_to_save = ", ".join(
|
||||
[f"`{x}`" for x in lora_modules_to_save]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
|
||||
)
|
||||
|
||||
tokenizer.add_special_tokens(
|
||||
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||
)
|
||||
|
||||
# If we add bos_token and eos_token, we need to update the post processor to
|
||||
# handle them correctly.
|
||||
# https://github.com/huggingface/transformers/pull/24132
|
||||
bos_or_eos_in_special_tokens = (
|
||||
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
|
||||
)
|
||||
if (
|
||||
tokenizer.__class__.__name__
|
||||
in (
|
||||
"LlamaTokenizerFast",
|
||||
"CodeLlamaTokenizerFast",
|
||||
)
|
||||
and bos_or_eos_in_special_tokens
|
||||
):
|
||||
tokenizer.update_post_processor()
|
||||
|
||||
if cfg.tokens:
|
||||
tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
|
||||
for token in cfg.tokens
|
||||
]
|
||||
)
|
||||
|
||||
# Additional special tokens are a List, and need to be treated differently than regular special
|
||||
# tokens. We add them after we have called `add_tokens` in case these additional special tokens
|
||||
# are new tokens.
|
||||
#
|
||||
# Usage:
|
||||
#
|
||||
# ```py
|
||||
# special_tokens:
|
||||
# additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
|
||||
# ```
|
||||
if additional_special_tokens is not None:
|
||||
tokenizer.add_special_tokens(
|
||||
{"additional_special_tokens": additional_special_tokens}
|
||||
)
|
||||
|
||||
if is_main_process(use_environ=True):
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
@@ -654,4 +263,19 @@ def load_tokenizer(cfg):
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
if cfg.chat_template:
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
if cfg.default_system_message and cfg.chat_template == "chatml":
|
||||
chat_template_string = chat_template_string.replace(
|
||||
"You are a helpful assistant.", cfg.default_system_message
|
||||
)
|
||||
|
||||
tokenizer.chat_template = chat_template_string
|
||||
else:
|
||||
LOG.info(
|
||||
"No Chat template selected. Consider adding a chat template for easier inference."
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
||||
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
their context parallel version of Flash Attention 2.
|
||||
|
||||
We also provide some patches for accelerate functions to prepare the dataloader for
|
||||
sequence parallelism training.
|
||||
context parallelism training.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
@@ -13,9 +13,9 @@ import inspect
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -63,15 +63,15 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
|
||||
|
||||
def register_ring_attn(
|
||||
sequence_parallel_degree: int,
|
||||
context_parallel_degree: int,
|
||||
heads_k_stride: int | None,
|
||||
ring_attn_func: RingAttnFunc | None,
|
||||
):
|
||||
"""Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||
context_parallel_degree: Context parallelism factor.
|
||||
heads_k_stride: Context parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
||||
packing is enabled, it must be a `varlen` function; otherwise, it must be a
|
||||
@@ -80,28 +80,18 @@ def register_ring_attn(
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
if rank == 0:
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
)
|
||||
assert world_size % sequence_parallel_degree == 0, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
LOG.info(
|
||||
"Enabling ring attention context parallelism: "
|
||||
f"each sequence will be processed across {context_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
# Assign ranks to sequence parallel groups
|
||||
# Assign ranks to context parallel groups
|
||||
group_assignments = {}
|
||||
for i in range(world_size // sequence_parallel_degree):
|
||||
for i in range(world_size // context_parallel_degree):
|
||||
ring_attn_ranks = list(
|
||||
range(
|
||||
i * sequence_parallel_degree,
|
||||
(i + 1) * sequence_parallel_degree,
|
||||
i * context_parallel_degree,
|
||||
(i + 1) * context_parallel_degree,
|
||||
)
|
||||
)
|
||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||
@@ -113,9 +103,7 @@ def register_ring_attn(
|
||||
if rank in ring_attn_ranks:
|
||||
set_ring_attn_group(group)
|
||||
|
||||
# Log the GPU group assignments
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
LOG.info(f"Context parallel group assignments: {group_assignments}")
|
||||
|
||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
@@ -150,7 +138,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
|
||||
|
||||
|
||||
def patch_prepare_data_loader():
|
||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the CP degree.
|
||||
|
||||
Raies:
|
||||
RuntimeError: If source code to patch does not exist.
|
||||
@@ -176,15 +164,15 @@ def patch_prepare_data_loader():
|
||||
patched_function = namespace["prepare_data_loader"]
|
||||
|
||||
accelerate.data_loader.prepare_data_loader = patched_function
|
||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for CP support")
|
||||
|
||||
|
||||
def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
||||
def patch_prepare_device_mesh(context_parallel_degree: int):
|
||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||
that includes sequence parallelism with the specified degree.
|
||||
that includes context parallelism with the specified degree.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree (int): The degree of sequence parallelism to use.
|
||||
context_parallel_degree (int): The degree of context parallelism to use.
|
||||
"""
|
||||
|
||||
def _prepare_device_mesh(self):
|
||||
@@ -199,11 +187,11 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
||||
):
|
||||
return self.state.ds_device_mesh
|
||||
|
||||
# Create device mesh with sequence parallelism
|
||||
# Create device mesh with context parallelism
|
||||
world_size = dist.get_world_size()
|
||||
mesh_shape = (
|
||||
world_size // sequence_parallel_degree,
|
||||
sequence_parallel_degree,
|
||||
world_size // context_parallel_degree,
|
||||
context_parallel_degree,
|
||||
)
|
||||
device_ids = list(range(world_size))
|
||||
|
||||
@@ -221,5 +209,5 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
||||
|
||||
LOG.info(
|
||||
"Successfully patched Accelerator._prepare_device_mesh "
|
||||
f"with sequence_parallel_degree={sequence_parallel_degree}"
|
||||
f"with context_parallel_degree={context_parallel_degree}"
|
||||
)
|
||||
|
||||
@@ -67,10 +67,6 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
LOG.warning("Empty text requested for tokenization.")
|
||||
return empty
|
||||
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
|
||||
@@ -25,14 +25,13 @@ from axolotl.common.datasets import TrainDatasetMeta
|
||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
)
|
||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders import (
|
||||
ModelLoader,
|
||||
load_processor,
|
||||
load_tokenizer,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||
from axolotl.utils.ctx_managers import ContextParallelContextManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
@@ -148,7 +147,7 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
|
||||
|
||||
|
||||
def setup_signal_handler(
|
||||
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
|
||||
cfg: DictDefault, model: PeftModel | PreTrainedModel, safe_serialization: bool
|
||||
):
|
||||
"""
|
||||
Set up signal handler for graceful termination.
|
||||
@@ -202,15 +201,20 @@ def execute_training(
|
||||
)
|
||||
)
|
||||
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
if cfg.context_parallel_degree > 1 and not cfg.sdp_attention:
|
||||
# Models to enter context parallel manager for
|
||||
models = [trainer.model]
|
||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||
models.append(trainer.ref_model)
|
||||
|
||||
# Attention backend
|
||||
backend = "sdp_attention" if cfg.sdp_attention else "flash_attention"
|
||||
|
||||
stack.enter_context(
|
||||
SequenceParallelContextManager(
|
||||
ContextParallelContextManager(
|
||||
models=models,
|
||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||
backend=backend,
|
||||
context_parallel_degree=cfg.context_parallel_degree,
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
heads_k_stride=cfg.heads_k_stride,
|
||||
@@ -224,7 +228,7 @@ def execute_training(
|
||||
def save_trained_model(
|
||||
cfg: DictDefault,
|
||||
trainer: Any,
|
||||
model: PreTrainedModel,
|
||||
model: PeftModel | PreTrainedModel,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
"""
|
||||
@@ -375,7 +379,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
|
||||
def save_initial_configs(
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
model: PreTrainedModel,
|
||||
model: PeftModel | PreTrainedModel,
|
||||
peft_config: PeftConfig | None,
|
||||
processor: ProcessorMixin | None,
|
||||
):
|
||||
@@ -429,7 +433,7 @@ def setup_model_card(cfg: DictDefault):
|
||||
|
||||
def handle_untrained_tokens_fix(
|
||||
cfg: DictDefault,
|
||||
model: PreTrainedModel,
|
||||
model: PeftModel | PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
train_dataset: Dataset,
|
||||
safe_serialization: bool,
|
||||
@@ -472,7 +476,7 @@ def handle_untrained_tokens_fix(
|
||||
|
||||
|
||||
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
||||
Trainer,
|
||||
PeftModel | PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PeftConfig | None,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Init for context manager submodule"""
|
||||
"""Init for context manager submodule."""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
from .context_parallel.manager import ContextParallelContextManager
|
||||
|
||||
from .sequence_parallel import SequenceParallelContextManager
|
||||
__all__ = ["ContextParallelContextManager"]
|
||||
|
||||
146
src/axolotl/utils/ctx_managers/context_parallel/distributed.py
Normal file
146
src/axolotl/utils/ctx_managers/context_parallel/distributed.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# BSD 3-Clause License
|
||||
|
||||
# Copyright 2024 Meta
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice,this list
|
||||
# of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice, this
|
||||
# list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its contributors may
|
||||
# be used to endorse or promote products derived from this software without specific
|
||||
# prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
|
||||
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
|
||||
# SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
|
||||
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
|
||||
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
|
||||
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
|
||||
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
|
||||
# DAMAGE.
|
||||
|
||||
"""
|
||||
Distributed utils for SDPA context parallel implementation. Slightly modified from
|
||||
https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5c2/torchtune/training/_distributed.py.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
from typing import Callable, Generator, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
|
||||
def _get_sdpa_context() -> (
|
||||
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
|
||||
):
|
||||
"""
|
||||
Creates a context manager to confine to flash/efficient/cuDNN attention backends.
|
||||
|
||||
Returns:
|
||||
A context manager function that takes an optional context parallel context.
|
||||
"""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context(cp_context: Union[Generator[None, None, None], None] = None):
|
||||
with contextlib.ExitStack() as stack:
|
||||
if cp_context is not None:
|
||||
stack.enter_context(
|
||||
sdpa_kernel(
|
||||
[
|
||||
SDPBackend.FLASH_ATTENTION,
|
||||
SDPBackend.EFFICIENT_ATTENTION,
|
||||
SDPBackend.CUDNN_ATTENTION,
|
||||
]
|
||||
)
|
||||
)
|
||||
stack.enter_context(cp_context)
|
||||
|
||||
yield
|
||||
|
||||
return context
|
||||
|
||||
|
||||
def get_context_parallel_manager(
|
||||
*,
|
||||
world_mesh: torch.distributed.DeviceMesh,
|
||||
model: nn.Module,
|
||||
) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
|
||||
"""
|
||||
Context manager for applying context parallelism to a model. In addition to applying the
|
||||
standard context manager to patch SDPA and shard model inputs and buffers along the sequence
|
||||
dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.
|
||||
|
||||
Args:
|
||||
world_mesh: Global device mesh.
|
||||
model: Model to apply context parallelism to.
|
||||
|
||||
Returns:
|
||||
A context manager applying context parallelism if enabled is True. Otherwise a context manager
|
||||
disabling the math SDPA backend.
|
||||
|
||||
Raises:
|
||||
ValueError: if enabled is True but world_mesh does not contain a "cp" dimension
|
||||
"""
|
||||
|
||||
if "cp" not in world_mesh.mesh_dim_names:
|
||||
raise ValueError(
|
||||
"Context parallel is enabled but no context parallel device mesh is provided."
|
||||
)
|
||||
# TODO: context parallel for multimodal models requires extra work
|
||||
# if not isinstance(model, TransformerDecoder):
|
||||
# raise ValueError("Context parallel is only supported for text models")
|
||||
# model_buffers = list(model.buffers())
|
||||
|
||||
# def get_all_buffers(module, prefix=""):
|
||||
# buffers = {}
|
||||
# for name, buffer in module.named_buffers(recurse=False):
|
||||
# full_name = f"{prefix}.{name}" if prefix else name
|
||||
# buffers[full_name] = buffer
|
||||
|
||||
# for name, child in module.named_children():
|
||||
# child_prefix = f"{prefix}.{name}" if prefix else name
|
||||
# buffers.update(get_all_buffers(child, child_prefix))
|
||||
|
||||
# return buffers
|
||||
|
||||
# model_buffers = get_all_buffers(model)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context(model_inputs: list[torch.Tensor]):
|
||||
# Create context parallel context if enabled
|
||||
cp_context = None
|
||||
if any([isinstance(input, BlockMask) for input in model_inputs]):
|
||||
raise ValueError(
|
||||
"Context parallel with flex attention is not yet supported"
|
||||
)
|
||||
set_rotate_method("allgather")
|
||||
|
||||
cp_context = context_parallel(
|
||||
world_mesh["cp"],
|
||||
# buffers=model_inputs + model_buffers,
|
||||
buffers=model_inputs,
|
||||
# buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
|
||||
buffer_seq_dims=[1] * len(model_inputs),
|
||||
no_restore_buffers=set(model_inputs),
|
||||
)
|
||||
|
||||
# Create and enter the train context with the optional cp_context
|
||||
sdpa_context = _get_sdpa_context()
|
||||
|
||||
with sdpa_context(cp_context):
|
||||
yield
|
||||
|
||||
return context
|
||||
216
src/axolotl/utils/ctx_managers/context_parallel/manager.py
Normal file
216
src/axolotl/utils/ctx_managers/context_parallel/manager.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""Module for Axolotl trainer context parallelism manager and utilities."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Callable, Literal
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
register_ring_attn,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.context_parallel.distributed import (
|
||||
get_context_parallel_manager,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.context_parallel.utils import (
|
||||
AllGatherWithGrad,
|
||||
apply_context_parallelism,
|
||||
)
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
|
||||
class ContextParallelContextManager:
|
||||
"""Context manager for context parallelism operations.
|
||||
|
||||
This class provides a context that will automatically apply context parallelism
|
||||
during model forward passes using a pre-forward hook, and gather outputs from
|
||||
across the context parallelism group using a post-forward hook.
|
||||
|
||||
Args:
|
||||
models: List of models to apply context parallelism to pre- and post- forward
|
||||
hooks.
|
||||
backend: Which attention backend to use.
|
||||
context_parallel_degree: Number of processes to split sequences over.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||
heads_k_stride: Context parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
models: list[PreTrainedModel],
|
||||
backend: Literal["sdp_attention", "flash_attention"],
|
||||
context_parallel_degree: int,
|
||||
gradient_accumulation_steps: int,
|
||||
ring_attn_func: RingAttnFunc,
|
||||
heads_k_stride: int | None,
|
||||
):
|
||||
self.models = models
|
||||
self.backend = backend
|
||||
self.context_parallel_degree = context_parallel_degree
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.heads_k_stride = heads_k_stride
|
||||
self._register_ring_attn()
|
||||
|
||||
# Store hook handles for removal
|
||||
self.hook_handles: list[RemovableHandle] = []
|
||||
|
||||
if self.backend == "flash_attention":
|
||||
# Set distributed info for local rank
|
||||
self.process_group = get_ring_attn_group()
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
|
||||
# Create a partially applied version of the apply_context_parallelism function
|
||||
self.apply_context_parallelism = functools.partial(
|
||||
apply_context_parallelism,
|
||||
local_rank=self.local_rank,
|
||||
local_world_size=self.local_world_size,
|
||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
# Store original sequence length and padding information
|
||||
self.original_seq_len = 0
|
||||
self.pad_len = 0
|
||||
else:
|
||||
# SPDA device mesh init
|
||||
world_size = dist.get_world_size()
|
||||
mesh_shape = (
|
||||
world_size // self.context_parallel_degree,
|
||||
self.context_parallel_degree,
|
||||
)
|
||||
world_mesh = dist.DeviceMesh(
|
||||
"cuda",
|
||||
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
||||
mesh_dim_names=("dp", "cp"),
|
||||
)
|
||||
|
||||
# SDPA context parallel managers
|
||||
self.context_parallel_managers = []
|
||||
for model in models:
|
||||
ctx_manager = get_context_parallel_manager(
|
||||
world_mesh=world_mesh,
|
||||
model=model,
|
||||
)
|
||||
self.context_parallel_managers.append(ctx_manager)
|
||||
|
||||
def __enter__(self):
|
||||
self._register_model_hooks()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Remove all hooks
|
||||
for handle in self.hook_handles:
|
||||
handle.remove()
|
||||
self.hook_handles = []
|
||||
|
||||
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
|
||||
|
||||
def _register_ring_attn(self):
|
||||
if self.backend == "flash_attention":
|
||||
# Initialize ring attn for context parallelism
|
||||
register_ring_attn(
|
||||
context_parallel_degree=self.context_parallel_degree,
|
||||
heads_k_stride=self.heads_k_stride,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
# Patches for accelerate functionality
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(context_parallel_degree=self.context_parallel_degree)
|
||||
|
||||
def _register_model_hooks(self):
|
||||
# Forward pre-hook to apply context parallelism
|
||||
def cp_flash_pre_hook(_, args, kwargs):
|
||||
# Get parameter names from the model's forward function
|
||||
forward_params = list(
|
||||
inspect.signature(self.models[0].forward).parameters.keys()
|
||||
)
|
||||
|
||||
updated_kwargs = kwargs.copy()
|
||||
for i, arg in enumerate(args):
|
||||
if i < len(forward_params):
|
||||
updated_kwargs[forward_params[i]] = arg
|
||||
|
||||
# Any excess positional arguments are kept as-is
|
||||
remaining_args = args[len(forward_params) :]
|
||||
|
||||
# Apply context parallelism to updated kwargs
|
||||
updated_kwargs, self.original_seq_len, self.pad_len = (
|
||||
self.apply_context_parallelism(updated_kwargs)
|
||||
)
|
||||
|
||||
return remaining_args, updated_kwargs
|
||||
|
||||
# Forward post-hook to gather outputs
|
||||
def cp_flash_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
||||
# Gather the sharded outputs
|
||||
output = self._gather_outputs(output)
|
||||
|
||||
# Remove padding if it was added
|
||||
if self.pad_len > 0:
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
if value.size(1) == self.original_seq_len + self.pad_len:
|
||||
# Slice to remove padding
|
||||
output[key] = value[:, : self.original_seq_len].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
def make_sdpa_pre_hook(model_idx: int) -> Callable:
|
||||
def cp_sdpa_pre_hook(_, args, kwargs):
|
||||
# Get parameter names from the model's forward function
|
||||
forward_params = list(
|
||||
inspect.signature(self.models[0].forward).parameters.keys()
|
||||
)
|
||||
|
||||
updated_kwargs = kwargs.copy()
|
||||
for i, arg in enumerate(args):
|
||||
if i < len(forward_params):
|
||||
updated_kwargs[forward_params[i]] = arg
|
||||
|
||||
# Any excess positional arguments are kept as-is
|
||||
remaining_args = args[len(forward_params) :]
|
||||
|
||||
to_shard = {k: v for k, v in updated_kwargs.items() if v.ndim > 1}
|
||||
|
||||
with self.context_parallel_managers[model_idx](list(to_shard.values())):
|
||||
return remaining_args, updated_kwargs
|
||||
|
||||
return cp_sdpa_pre_hook
|
||||
|
||||
# Register both hooks
|
||||
for i, model in enumerate(self.models):
|
||||
if self.backend == "flash_attention":
|
||||
self.hook_handles.append(
|
||||
model.register_forward_pre_hook(cp_flash_pre_hook, with_kwargs=True)
|
||||
)
|
||||
self.hook_handles.append(
|
||||
model.register_forward_hook(cp_flash_post_hook)
|
||||
)
|
||||
else:
|
||||
self.hook_handles.append(
|
||||
model.register_forward_pre_hook(
|
||||
make_sdpa_pre_hook(i), with_kwargs=True
|
||||
)
|
||||
)
|
||||
|
||||
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
||||
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
output[key] = AllGatherWithGrad.apply(value, self.process_group)
|
||||
|
||||
return output
|
||||
@@ -1,28 +1,15 @@
|
||||
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
"""Utils for context parallel context manager."""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
register_ring_attn,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
from axolotl.monkeypatch.ring_attn.patch import update_ring_attn_params
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
|
||||
# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this
|
||||
# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.
|
||||
def apply_sequence_parallelism(
|
||||
def apply_context_parallelism(
|
||||
batch: dict[str, torch.Tensor],
|
||||
local_rank: int,
|
||||
local_world_size: int,
|
||||
@@ -30,15 +17,15 @@ def apply_sequence_parallelism(
|
||||
ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument
|
||||
) -> tuple[dict[str, torch.Tensor], int, int]:
|
||||
"""
|
||||
Apply sequence parallelism slicing to a batch.
|
||||
Apply context parallelism slicing to a batch.
|
||||
|
||||
Special handling is implemented for integer logits_to_keep, which indicates
|
||||
to only keep the last N tokens in the sequence during generation.
|
||||
to only keep the last N tokens in the input sequence during generation.
|
||||
|
||||
Args:
|
||||
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
|
||||
local_rank: Local rank in the sequence parallel group.
|
||||
local_world_size: World size of the sequence parallel group.
|
||||
local_rank: Local rank in the context parallel group.
|
||||
local_world_size: World size of the context parallel group.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused, but
|
||||
related to above TODO.
|
||||
@@ -133,7 +120,7 @@ def apply_sequence_parallelism(
|
||||
# Update the total sequence length after padding
|
||||
total_seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# Slice batch for sequence parallel
|
||||
# Slice batch for context parallel
|
||||
for key in batch:
|
||||
if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:
|
||||
continue
|
||||
@@ -159,144 +146,6 @@ def apply_sequence_parallelism(
|
||||
return batch, original_seq_len, pad_len
|
||||
|
||||
|
||||
class SequenceParallelContextManager:
|
||||
"""Context manager for sequence parallelism operations.
|
||||
|
||||
This class provides a context that will automatically apply sequence parallelism
|
||||
during model forward passes using a pre-forward hook, and gather outputs from
|
||||
across the sequence parallelism group using a post-forward hook.
|
||||
|
||||
Args:
|
||||
models: List of models to apply sequence parallelism to pre- and post- forward
|
||||
hooks.
|
||||
sequence_parallel_degree: Number of processes to split sequences over.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
models: list[nn.Module],
|
||||
sequence_parallel_degree: int,
|
||||
gradient_accumulation_steps: int,
|
||||
ring_attn_func: RingAttnFunc,
|
||||
heads_k_stride: int | None,
|
||||
):
|
||||
self.models = models
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.heads_k_stride = heads_k_stride
|
||||
self._register_ring_attn()
|
||||
|
||||
# Set distributed info for local rank
|
||||
self.process_group = get_ring_attn_group()
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
|
||||
# Will store hook handles for removal
|
||||
self.hook_handles: list[RemovableHandle] = []
|
||||
|
||||
# Store original sequence length and padding information
|
||||
self.original_seq_len = 0
|
||||
self.pad_len = 0
|
||||
|
||||
# Create a partially applied version of the apply_sequence_parallelism function
|
||||
self.apply_sequence_parallelism = functools.partial(
|
||||
apply_sequence_parallelism,
|
||||
local_rank=self.local_rank,
|
||||
local_world_size=self.local_world_size,
|
||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
self._register_model_hooks()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Remove all hooks
|
||||
for handle in self.hook_handles:
|
||||
handle.remove()
|
||||
self.hook_handles = []
|
||||
|
||||
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
|
||||
|
||||
def _register_ring_attn(self):
|
||||
# Initialize ring attn for sequence parallelism
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=self.sequence_parallel_degree,
|
||||
heads_k_stride=self.heads_k_stride,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
# Patches for accelerate functionality
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(
|
||||
sequence_parallel_degree=self.sequence_parallel_degree
|
||||
)
|
||||
|
||||
def _register_model_hooks(self):
|
||||
# Forward pre-hook to apply sequence parallelism
|
||||
def sequence_parallel_pre_hook(_, args, kwargs):
|
||||
# Get parameter names from the model's forward function
|
||||
forward_params = list(
|
||||
inspect.signature(self.models[0].forward).parameters.keys()
|
||||
)
|
||||
|
||||
updated_kwargs = kwargs.copy()
|
||||
for i, arg in enumerate(args):
|
||||
if i < len(forward_params):
|
||||
updated_kwargs[forward_params[i]] = arg
|
||||
|
||||
# Any excess positional arguments are kept as-is
|
||||
remaining_args = args[len(forward_params) :]
|
||||
|
||||
# Apply sequence parallelism to updated kwargs
|
||||
updated_kwargs, self.original_seq_len, self.pad_len = (
|
||||
self.apply_sequence_parallelism(updated_kwargs)
|
||||
)
|
||||
|
||||
return remaining_args, updated_kwargs
|
||||
|
||||
# Forward post-hook to gather outputs
|
||||
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
||||
# Gather the sharded outputs
|
||||
output = self._gather_outputs(output)
|
||||
|
||||
# Remove padding if it was added
|
||||
if self.pad_len > 0:
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
if value.size(1) == self.original_seq_len + self.pad_len:
|
||||
# Slice to remove padding
|
||||
output[key] = value[:, : self.original_seq_len].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
# Register both hooks
|
||||
for model in self.models:
|
||||
self.hook_handles.append(
|
||||
model.register_forward_pre_hook(
|
||||
sequence_parallel_pre_hook, with_kwargs=True
|
||||
)
|
||||
)
|
||||
self.hook_handles.append(
|
||||
model.register_forward_hook(sequence_parallel_post_hook)
|
||||
)
|
||||
|
||||
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
||||
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
output[key] = AllGatherWithGrad.apply(value, self.process_group)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class AllGatherWithGrad(torch.autograd.Function):
|
||||
"""Custom autograd function for all-gather to preserve gradients."""
|
||||
|
||||
@@ -486,10 +486,6 @@ def get_dataset_wrapper(
|
||||
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
||||
)
|
||||
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
|
||||
if (
|
||||
isinstance(dataset, Dataset)
|
||||
and "input_ids" in dataset.features
|
||||
|
||||
@@ -262,7 +262,7 @@ class AxolotlInputConfig(
|
||||
|
||||
val_set_size: float | None = Field(default=0.0)
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
context_parallel_degree: int | None = None
|
||||
heads_k_stride: int | None = None
|
||||
ring_attn_func: RingAttnFunc | None = None
|
||||
|
||||
@@ -1179,24 +1179,39 @@ class AxolotlInputConfig(
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_grpo_liger_sequence_parallel(cls, data):
|
||||
def check_grpo_liger_context_parallel(cls, data):
|
||||
if (
|
||||
data.get("rl") == "grpo"
|
||||
and data.get("trl", {})
|
||||
and data.get("trl").get("use_liger_loss")
|
||||
and data.get("sequence_parallel_degree", 1) > 1
|
||||
and data.get("context_parallel_degree", 1) > 1
|
||||
):
|
||||
raise ValueError("GRPO + SP + Liger not currently supported")
|
||||
raise ValueError("GRPO + CP + Liger not currently supported")
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_sequence_parallel_degree(self):
|
||||
if not self.sequence_parallel_degree:
|
||||
self.sequence_parallel_degree = 1
|
||||
elif self.sequence_parallel_degree > 1:
|
||||
if not self.flash_attention:
|
||||
def check_context_parallel_degree(self):
|
||||
if not self.context_parallel_degree:
|
||||
self.context_parallel_degree = 1
|
||||
elif self.context_parallel_degree > 1:
|
||||
import torch
|
||||
|
||||
world_size = torch.cuda.device_count()
|
||||
if not world_size >= self.context_parallel_degree:
|
||||
raise ValueError(
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
f"World size ({world_size}) must be greater "
|
||||
f"than or equal to CP degree ({self.context_parallel_degree})"
|
||||
)
|
||||
if not world_size % self.context_parallel_degree == 0:
|
||||
raise ValueError(
|
||||
f"SP degree ({self.context_parallel_degree}) "
|
||||
f"must evenly divide world size ({world_size})"
|
||||
)
|
||||
|
||||
if not (self.flash_attention or self.sdp_attention):
|
||||
raise ValueError(
|
||||
"flash_attention: true or sdp_attention: true "
|
||||
"must be set with context_parallel_degree > 1"
|
||||
)
|
||||
|
||||
if self.sample_packing and self.micro_batch_size > 1:
|
||||
@@ -1205,21 +1220,22 @@ class AxolotlInputConfig(
|
||||
"due to a `ring-flash-attn` requirement"
|
||||
)
|
||||
|
||||
try:
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
if self.flash_attention:
|
||||
try:
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"context_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
|
||||
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
||||
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
||||
# TODO: monkeypatch / callback to average losses correctly across CP ranks
|
||||
# / fix gradient scaling across CP ranks. Losses, grads should be scaled
|
||||
# according to the proportion of non-padding tokens per rank.
|
||||
LOG.warning(
|
||||
"Sequence parallelism (SP) is enabled with "
|
||||
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
|
||||
"Context parallelism (SP) is enabled with "
|
||||
f"context_parallel_degree={self.context_parallel_degree}. "
|
||||
"Please note that logged losses may differ slightly to the non-SP "
|
||||
"losses due to transformers Trainer implementation details. "
|
||||
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||
@@ -1230,7 +1246,7 @@ class AxolotlInputConfig(
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_ring_attn_func(self):
|
||||
if getattr(self, "sequence_parallel_degree", 1) == 1:
|
||||
if getattr(self, "context_parallel_degree", 1) == 1:
|
||||
return self
|
||||
|
||||
if self.ring_attn_func is not None:
|
||||
|
||||
@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.context_parallel_degree
|
||||
)
|
||||
LOG.debug(
|
||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
||||
@@ -479,7 +479,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
# on the agreed on value for sample_packing_eff_est
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||
data_loader_len * cfg.num_epochs * cfg.context_parallel_degree
|
||||
)
|
||||
)
|
||||
|
||||
@@ -502,7 +502,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.context_parallel_degree
|
||||
/ cfg.batch_size
|
||||
)
|
||||
)
|
||||
|
||||
@@ -64,7 +64,7 @@ def fixture_base_cfg():
|
||||
"dataloader_num_workers": 1,
|
||||
"dataloader_pin_memory": True,
|
||||
"dataloader_prefetch_factor": 2,
|
||||
"sequence_parallel_degree": 1,
|
||||
"context_parallel_degree": 1,
|
||||
# Dtype
|
||||
"fp16": False,
|
||||
"bf16": False,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""E2E tests for sequence parallelism"""
|
||||
"""E2E tests for context parallelism"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@@ -12,10 +12,10 @@ from axolotl.utils.dict import DictDefault
|
||||
from ...utils import check_tensorboard
|
||||
|
||||
|
||||
class TestSequenceParallelism:
|
||||
"""Test case for training with sequence parallelism enabled"""
|
||||
class TestContextParallelism:
|
||||
"""Test case for training with context parallelism enabled"""
|
||||
|
||||
def _run_sequence_parallel_test(
|
||||
def _run_context_parallel_test(
|
||||
self,
|
||||
temp_dir,
|
||||
sample_packing=True,
|
||||
@@ -24,7 +24,7 @@ class TestSequenceParallelism:
|
||||
ring_attn_func=None,
|
||||
threshold=2.0,
|
||||
):
|
||||
"""Helper method to run sequence parallel tests with different configurations"""
|
||||
"""Helper method to run context parallel tests with different configurations"""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -66,7 +66,7 @@ class TestSequenceParallelism:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_degree": 2,
|
||||
"ring_attn_func": ring_attn_func,
|
||||
}
|
||||
)
|
||||
@@ -109,7 +109,7 @@ class TestSequenceParallelism:
|
||||
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_training(
|
||||
def test_context_parallel_training(
|
||||
self,
|
||||
temp_dir,
|
||||
sample_packing,
|
||||
@@ -118,8 +118,8 @@ class TestSequenceParallelism:
|
||||
ring_attn_func,
|
||||
threshold,
|
||||
):
|
||||
"""Test sequence parallel training with different configurations"""
|
||||
self._run_sequence_parallel_test(
|
||||
"""Test context parallel training with different configurations"""
|
||||
self._run_context_parallel_test(
|
||||
temp_dir,
|
||||
sample_packing=sample_packing,
|
||||
micro_batch_size=micro_batch_size,
|
||||
|
||||
@@ -296,7 +296,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Tests for sequence parallelism functionality."""
|
||||
"""Tests for context parallelism functionality."""
|
||||
|
||||
# pylint: disable=redefined-outer-name,unused-argument
|
||||
|
||||
@@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import (
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
|
||||
from axolotl.utils.ctx_managers.context_parallel import apply_context_parallelism
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
@@ -54,8 +54,8 @@ def fixture_cfg():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sequence_parallel_batch():
|
||||
"""Create a test batch for sequence parallelism tests."""
|
||||
def context_parallel_batch():
|
||||
"""Create a test batch for context parallelism tests."""
|
||||
batch_size = 1
|
||||
seq_len = 8
|
||||
|
||||
@@ -110,7 +110,7 @@ class TestRingAttention:
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=4,
|
||||
context_parallel_degree=4,
|
||||
heads_k_stride=1,
|
||||
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
||||
)
|
||||
@@ -126,7 +126,7 @@ class TestRingAttention:
|
||||
|
||||
|
||||
class TestConfigValidation:
|
||||
"""Tests for validating sequence parallelism configurations."""
|
||||
"""Tests for validating context parallelism configurations."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_mocks(self, monkeypatch):
|
||||
@@ -155,24 +155,24 @@ class TestConfigValidation:
|
||||
[
|
||||
# Valid configuration
|
||||
(
|
||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||
{"context_parallel_degree": 2, "flash_attention": True},
|
||||
{"context_parallel_degree": 2, "flash_attention": True},
|
||||
True,
|
||||
None,
|
||||
),
|
||||
# Default sequence_parallel_degree
|
||||
({}, {"sequence_parallel_degree": 1}, True, None),
|
||||
# Invalid: sequence_parallel_degree > 1 without flash_attention
|
||||
# Default context_parallel_degree
|
||||
({}, {"context_parallel_degree": 1}, True, None),
|
||||
# Invalid: context_parallel_degree > 1 without flash_attention
|
||||
(
|
||||
{"sequence_parallel_degree": 2, "flash_attention": False},
|
||||
{"context_parallel_degree": 2, "flash_attention": False},
|
||||
None,
|
||||
False,
|
||||
"flash_attention: true must be set",
|
||||
),
|
||||
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
|
||||
# Invalid: context_parallel_degree > 1 with sample_packing and micro_batch_size > 1
|
||||
(
|
||||
{
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"micro_batch_size": 2,
|
||||
@@ -185,32 +185,32 @@ class TestConfigValidation:
|
||||
# Valid: Basic GRPO config
|
||||
(
|
||||
{
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
},
|
||||
{
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": TRLConfig(use_liger_loss=True),
|
||||
},
|
||||
True,
|
||||
"GRPO + SP + Liger not currently supported",
|
||||
"GRPO + CP + Liger not currently supported",
|
||||
),
|
||||
# Invalid: GRPO config with Liger loss
|
||||
(
|
||||
{
|
||||
"rl": "grpo",
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
},
|
||||
None,
|
||||
False,
|
||||
"GRPO + SP + Liger not currently supported",
|
||||
"GRPO + CP + Liger not currently supported",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
@@ -222,10 +222,10 @@ class TestConfigValidation:
|
||||
"grpo_with_liger_loss",
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_config_validation(
|
||||
def test_context_parallel_config_validation(
|
||||
self, base_cfg, config_updates, expected_values, should_pass, error_msg
|
||||
):
|
||||
"""Test various sequence parallelism configuration scenarios."""
|
||||
"""Test various context parallelism configuration scenarios."""
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
# Apply updates to base config
|
||||
@@ -261,7 +261,7 @@ class TestConfigValidation:
|
||||
|
||||
# Apply updates to base config
|
||||
cfg = base_cfg | {
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": sample_packing,
|
||||
}
|
||||
@@ -281,7 +281,7 @@ class TestConfigValidation:
|
||||
|
||||
# Invalid configuration with invalid ring_attn_func
|
||||
cfg = base_cfg | {
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"ring_attn_func": "INVALID_FUNC",
|
||||
}
|
||||
@@ -294,8 +294,8 @@ class TestConfigValidation:
|
||||
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
|
||||
|
||||
|
||||
class TestApplySequenceParallelism:
|
||||
"""Tests for the apply_sequence_parallelism function."""
|
||||
class TestApplyContextParallelism:
|
||||
"""Tests for the apply_context_parallelism function."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_distributed(self, monkeypatch):
|
||||
@@ -324,12 +324,12 @@ class TestApplySequenceParallelism:
|
||||
)
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
def test_world_size_one(self, mock_get_ring_attn_group, context_parallel_batch):
|
||||
"""Test that function returns original batch when world size is 1."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=sequence_parallel_batch,
|
||||
result, _, _ = apply_context_parallelism(
|
||||
batch=context_parallel_batch,
|
||||
local_rank=0,
|
||||
local_world_size=1,
|
||||
gradient_accumulation_steps=1,
|
||||
@@ -337,17 +337,17 @@ class TestApplySequenceParallelism:
|
||||
)
|
||||
|
||||
# Should return the original batch unchanged
|
||||
assert result == sequence_parallel_batch
|
||||
assert result == context_parallel_batch
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
def test_batch_ring_rank0(self, mock_get_ring_attn_group, context_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
batch = context_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
result, _, _ = apply_context_parallelism(
|
||||
batch=batch,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
@@ -366,15 +366,15 @@ class TestApplySequenceParallelism:
|
||||
)
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
def test_batch_ring_rank1(self, mock_get_ring_attn_group, context_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
batch = context_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
result, _, _ = apply_context_parallelism(
|
||||
batch=batch,
|
||||
local_rank=1,
|
||||
local_world_size=2,
|
||||
@@ -386,14 +386,14 @@ class TestApplySequenceParallelism:
|
||||
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
|
||||
|
||||
# TODO(djsaunde): add back once implemented.
|
||||
# def test_batch_zigzag(self, sequence_parallel_batch):
|
||||
# def test_batch_zigzag(self, context_parallel_batch):
|
||||
# """Test BATCH_ZIGZAG sharding pattern."""
|
||||
# batch = sequence_parallel_batch
|
||||
# batch = context_parallel_batch
|
||||
# original_input_ids = batch["input_ids"].clone()
|
||||
# seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# # Test rank 0
|
||||
# result_rank0 = apply_sequence_parallelism(
|
||||
# result_rank0 = apply_context_parallelism(
|
||||
# batch={k: v.clone() for k, v in batch.items()},
|
||||
# local_rank=0,
|
||||
# local_world_size=2,
|
||||
@@ -401,7 +401,7 @@ class TestApplySequenceParallelism:
|
||||
# )
|
||||
|
||||
# # Test rank 1
|
||||
# result_rank1 = apply_sequence_parallelism(
|
||||
# result_rank1 = apply_context_parallelism(
|
||||
# batch={k: v.clone() for k, v in batch.items()},
|
||||
# local_rank=1,
|
||||
# local_world_size=2,
|
||||
@@ -430,17 +430,17 @@ class TestApplySequenceParallelism:
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_partial_application(
|
||||
self, mock_get_ring_attn_group, sequence_parallel_batch
|
||||
self, mock_get_ring_attn_group, context_parallel_batch
|
||||
):
|
||||
"""Test that we can create a partially applied version of the function."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
batch = context_parallel_batch
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
# Create a partially applied function
|
||||
rank0_ring_parallel = functools.partial(
|
||||
apply_sequence_parallelism,
|
||||
apply_context_parallelism,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
@@ -457,16 +457,14 @@ class TestApplySequenceParallelism:
|
||||
original_input_ids[:, : original_input_ids.shape[1] // 2],
|
||||
)
|
||||
|
||||
def test_missing_position_ids(self, sequence_parallel_batch):
|
||||
def test_missing_position_ids(self, context_parallel_batch):
|
||||
"""Test handling of batch without position_ids."""
|
||||
# Create a batch without position_ids
|
||||
batch = {
|
||||
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
|
||||
}
|
||||
batch = {k: v for k, v in context_parallel_batch.items() if k != "position_ids"}
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
# This should run without error even though position_ids is missing
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
result, _, _ = apply_context_parallelism(
|
||||
batch=batch,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
"""Test cases for tokenizer loading."""
|
||||
"""
|
||||
Test cases for the tokenizer loading
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -9,7 +13,9 @@ from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
|
||||
class TestTokenizers:
|
||||
"""Test class for the load_tokenizer fn"""
|
||||
"""
|
||||
test class for the load_tokenizer fn
|
||||
"""
|
||||
|
||||
@enable_hf_offline
|
||||
def test_default_use_fast(self):
|
||||
@@ -149,50 +155,6 @@ class TestTokenizers:
|
||||
):
|
||||
load_tokenizer(cfg)
|
||||
|
||||
def test_mistral_tokenizer_auto_detection(self):
|
||||
"""Test that Mistral models are auto-detected and use MistralTokenizerWrapper"""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
||||
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
|
||||
|
||||
def test_mixtral_tokenizer_auto_detection(self):
|
||||
"""Test that Mixtral models are auto-detected and use MistralTokenizerWrapper"""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "model-hub/Mixtral-8x7B-v0.1",
|
||||
"tokenizer_config": "model-hub/Mixtral-8x7B-v0.1",
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
|
||||
|
||||
def test_mistral_tokenizer_basic_functionality(self):
|
||||
"""Test basic encode/decode functionality of MistralTokenizerWrapper"""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
||||
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
# Test basic encoding
|
||||
text = "Hello, world!"
|
||||
tokens = tokenizer.encode(text)
|
||||
assert isinstance(tokens, list)
|
||||
assert len(tokens) > 0
|
||||
|
||||
# Test basic decoding
|
||||
decoded = tokenizer.decode(tokens)
|
||||
assert isinstance(decoded, str)
|
||||
|
||||
# Test token properties are accessible
|
||||
assert hasattr(tokenizer, "eos_token_id")
|
||||
assert hasattr(tokenizer, "bos_token_id")
|
||||
assert isinstance(tokenizer.eos_token_id, int)
|
||||
assert isinstance(tokenizer.bos_token_id, int)
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user