Distributed/ND-Parallel (#2977)

This commit is contained in:
salman
2025-07-31 20:25:02 +01:00
committed by GitHub
parent 7b68dfafd7
commit 294c7fe7a6
49 changed files with 712 additions and 835 deletions

View File

@@ -2,7 +2,7 @@
set -e set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection) # Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \ pytest -v --durations=10 -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \ --ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \ /workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -65,6 +65,9 @@ GPU_CONFIG = f"L40S:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str): def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec import subprocess # nosec
sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
# Propagate errors from subprocess. # Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit exit(exit_code) # pylint: disable=consider-using-sys-exit

View File

@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml ```yaml
# Set to a divisor (> 1) of the number of GPUs available # Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs context_parallel_size: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster. # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1 heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,7 +30,7 @@ heads_k_stride: 1
ring_attn_func: ring_attn_func:
``` ```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example: The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8 - With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4 - With 4 GPUs, valid values would be 2 or 4
@@ -66,7 +66,7 @@ sequence_len: 8192
... ...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster. # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1 heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
## Effect on Batch Size ## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because: When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence) - Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases - The number of batches processed per step decreases
For example: For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step - With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs) - With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4 - If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

View File

@@ -20,7 +20,7 @@ min_sample_len: 200_000
sample_packing: true sample_packing: true
tiled_mlp: true tiled_mlp: true
sequence_parallel_degree: 8 context_parallel_size: 8
plugins: plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -13,9 +13,9 @@ packaging==23.2
huggingface_hub>=0.33.0 huggingface_hub>=0.33.0
peft==0.16.0 peft==0.16.0
transformers==4.54.0 transformers==4.54.1
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.9.0 accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
datasets==4.0.0 datasets==4.0.0
deepspeed>=0.17.0 deepspeed>=0.17.0
trl==0.20.0 trl==0.20.0

View File

@@ -72,12 +72,13 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
else: else:
_install_requires.append("xformers==0.0.31") _install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm>=0.10.0"]
elif (major, minor) >= (2, 6): elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3") _install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126 # since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126") _dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map["vllm"] = ["vllm==0.8.5.post1"] extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5): elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if patch == 0: if patch == 0:

View File

@@ -69,7 +69,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False, load_in_8bit=False,
load_in_4bit=False, load_in_4bit=False,
flash_attention=False, flash_attention=False,
sequence_parallel_degree=None, context_parallel_size=None,
deepspeed=None, deepspeed=None,
fsdp=None, fsdp=None,
fsdp_config=None, fsdp_config=None,

View File

@@ -24,9 +24,11 @@ from pathlib import Path
from typing import Any from typing import Any
import torch import torch
from accelerate import PartialState
from transformers import ( from transformers import (
TrainerCallback, TrainerCallback,
) )
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.training_args import OptimizerNames from transformers.training_args import OptimizerNames
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
@@ -434,8 +436,30 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_accelerator_config(self, training_args_kwargs: dict): def _configure_accelerator_config(self, training_args_kwargs: dict):
partial_state = PartialState()
has_pc_attr = (
hasattr(partial_state, "parallelism_config")
and partial_state.parallelism_config
)
has_pc_key = (
"parallelism_config"
in partial_state._shared_state # pylint: disable=protected-access
and partial_state._shared_state[ # pylint: disable=protected-access
"parallelism_config"
]
)
use_configured_state = has_pc_attr or has_pc_key
if self.cfg.accelerator_config: if self.cfg.accelerator_config:
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config use_configured_state = self.cfg.accelerator_config.pop(
"use_configured_state", use_configured_state
)
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state, **self.cfg.accelerator_config
)
else:
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state,
)
def _configure_gradient_checkpointing(self, training_args_kwargs: dict): def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.activation_offloading is True: if self.cfg.activation_offloading is True:

View File

@@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl is RLType.GRPO: if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class( trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1 sequence_parallel=self.cfg.context_parallel_size > 1
) )
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))

View File

@@ -27,6 +27,7 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import ( from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin, ActivationOffloadingMixin,
CheckpointSaveMixin, CheckpointSaveMixin,
DistributedParallelMixin,
OptimizerMixin, OptimizerMixin,
PackingMixin, PackingMixin,
RngLoaderMixin, RngLoaderMixin,
@@ -50,6 +51,7 @@ class AxolotlTrainer(
RngLoaderMixin, RngLoaderMixin,
CheckpointSaveMixin, CheckpointSaveMixin,
ActivationOffloadingMixin, ActivationOffloadingMixin,
DistributedParallelMixin,
Trainer, Trainer,
): ):
"""Extend the base Trainer for axolotl helpers""" """Extend the base Trainer for axolotl helpers"""

View File

@@ -8,7 +8,11 @@ import torch
from torch import nn from torch import nn
from trl import DPOTrainer from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
RngLoaderMixin,
SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.utils import ( from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_ds_tagging,
@@ -17,7 +21,12 @@ from axolotl.core.trainers.utils import (
class AxolotlDPOTrainer( class AxolotlDPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DPOTrainer,
DistributedParallelMixin,
): ):
"""Extend the base DPOTrainer for axolotl helpers.""" """Extend the base DPOTrainer for axolotl helpers."""

View File

@@ -82,14 +82,14 @@ class GRPOStrategy:
grpo_args_kwargs["log_completions"] = trl.log_completions grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if cfg.context_parallel_size > 1:
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
if trl.importance_sampling_level is not None: if trl.importance_sampling_level is not None:
grpo_args_kwargs["importance_sampling_level"] = ( grpo_args_kwargs["importance_sampling_level"] = (
trl.importance_sampling_level trl.importance_sampling_level
) )
if cfg.sequence_parallel_degree > 1:
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
if trl.reward_weights: if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights grpo_args_kwargs["reward_weights"] = trl.reward_weights

View File

@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training""" """Axolotl GRPO Config for GRPO training"""
sequence_parallel_degree: int | None = None context_parallel_size: int | None = None

View File

@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
- Data is properly distributed across SP groups. - Data is properly distributed across SP groups.
In the table below, the values represent dataset indices. Each SP group has In the table below, the values represent dataset indices. Each SP group has
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2 `context_parallel_size = 2` GPUs working together on the same data. There are 2
SP groups (SP0 and SP1), with `world_size = 4` total GPUs. SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
Sequence Parallel Groups Sequence Parallel Groups
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: Rank of current process. rank: Rank of current process.
batch_size: Number of samples per batch. batch_size: Number of samples per batch.
repeat_count: How many times to repeat the full sampling process. repeat_count: How many times to repeat the full sampling process.
sequence_parallel_degree: Number of ranks in a sequence parallel group. context_parallel_size: Number of ranks in a sequence parallel group.
shuffle: Whether to shuffle the dataset. shuffle: Whether to shuffle the dataset.
seed: Random seed for shuffling. seed: Random seed for shuffling.
drop_last: Whether to drop the last incomplete batch. drop_last: Whether to drop the last incomplete batch.
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: int, rank: int,
batch_size: int = 1, batch_size: int = 1,
repeat_count: int = 1, repeat_count: int = 1,
sequence_parallel_degree: int = 1, context_parallel_size: int = 1,
shuffle: bool = True, shuffle: bool = True,
seed: int = 0, seed: int = 0,
drop_last: bool = False, drop_last: bool = False,
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
self.rank = rank self.rank = rank
# Sequence parallelism parameters # Sequence parallelism parameters
self.sequence_parallel_degree = sequence_parallel_degree self.context_parallel_size = context_parallel_size
self.num_sp_groups = world_size // sequence_parallel_degree self.num_sp_groups = world_size // context_parallel_size
self.sp_group_id = rank // sequence_parallel_degree self.sp_group_id = rank // context_parallel_size
# Adjust dataset size for distributed sampling # Adjust dataset size for distributed sampling
self.num_samples = len(self.dataset) self.num_samples = len(self.dataset)

View File

@@ -43,7 +43,11 @@ from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
RngLoaderMixin,
SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.monkeypatch.ring_attn import get_ring_attn_group from axolotl.monkeypatch.ring_attn import get_ring_attn_group
@@ -53,7 +57,12 @@ if is_peft_available():
class AxolotlGRPOTrainer( class AxolotlGRPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
GRPOTrainer,
): ):
"""Extend the base GRPOTrainer for axolotl helpers""" """Extend the base GRPOTrainer for axolotl helpers"""
@@ -100,7 +109,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Get number of SP groups (number of processes divided by SP degree) # Get number of SP groups (number of processes divided by SP degree)
num_processes = self.accelerator.num_processes num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree num_sp_groups = num_processes // self.args.context_parallel_size
# Calculate batch size per SP group (not per process) # Calculate batch size per SP group (not per process)
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
@@ -130,7 +139,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
if self.num_generations not in possible_values: if self.num_generations not in possible_values:
raise ValueError( raise ValueError(
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), " f"With sequence parallelism (degree {self.args.context_parallel_size}), "
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) " f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"must be evenly divisible by the number of generations per prompt " f"must be evenly divisible by the number of generations per prompt "
f"({self.num_generations}). Given the current eval batch size, " f"({self.num_generations}). Given the current eval batch size, "
@@ -167,9 +176,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
rank=self.rank, rank=self.rank,
batch_size=effective_batch_size batch_size=effective_batch_size
// self.num_generations // self.num_generations
// self.args.sequence_parallel_degree, // self.args.context_parallel_size,
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps, repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
sequence_parallel_degree=self.args.sequence_parallel_degree, context_parallel_size=self.args.context_parallel_size,
shuffle=True, shuffle=True,
seed=self.args.seed, seed=self.args.seed,
drop_last=True, drop_last=True,
@@ -235,7 +244,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension). # slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1: if self.args.context_parallel_size > 1:
return dataloader return dataloader
# Otherwise prepare with accelerator # Otherwise prepare with accelerator
@@ -308,18 +317,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text) all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process: if self.accelerator.is_main_process:
if self.args.sequence_parallel_degree > 1: if self.args.context_parallel_size > 1:
# Calculate sequence parallel group information # Calculate sequence parallel group information
world_size = self.accelerator.num_processes world_size = self.accelerator.num_processes
sequence_parallel_degree = self.args.sequence_parallel_degree context_parallel_size = self.args.context_parallel_size
num_sp_groups = world_size // sequence_parallel_degree num_sp_groups = world_size // context_parallel_size
# Since processes in the same SP group have the same prompts, we need to ensure # Since processes in the same SP group have the same prompts, we need to ensure
# we only take one copy of each prompt from each SP group # we only take one copy of each prompt from each SP group
ordered_set_of_prompts = [] ordered_set_of_prompts = []
for sp_group_id in range(num_sp_groups): for sp_group_id in range(num_sp_groups):
# Get the first process from each SP group (typically the group leader) # Get the first process from each SP group (typically the group leader)
group_leader_rank = sp_group_id * sequence_parallel_degree group_leader_rank = sp_group_id * context_parallel_size
# Extract prompts from this SP group, accounting for num_generations duplicates # Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group # We only need prompts from one rank in each SP group
@@ -335,7 +344,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# num_generations outputs for each one. This is faster than generating outputs for each duplicate # num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually. # prompt individually.
ordered_set_of_prompts = all_prompts_text[ ordered_set_of_prompts = all_prompts_text[
:: self.num_generations * self.args.sequence_parallel_degree :: self.num_generations * self.args.context_parallel_size
] ]
with profiling_context(self, "vLLM.generate"): with profiling_context(self, "vLLM.generate"):
@@ -352,14 +361,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
) )
else: else:
completion_ids = [None] * ( completion_ids = [None] * (
len(all_prompts_text) // self.args.sequence_parallel_degree len(all_prompts_text) // self.args.context_parallel_size
) )
# Broadcast the completions from the main process to all processes # Broadcast the completions from the main process to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0) completion_ids = broadcast_object_list(completion_ids, from_process=0)
# Determine the appropriate slice based on sequence parallelism # Determine the appropriate slice based on sequence parallelism
if self.args.sequence_parallel_degree > 1: if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to) # Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size sp_group_id = self.accelerator.process_index // self.local_world_size
@@ -583,7 +592,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
advantages = advantages / (std_grouped_rewards + 1e-4) advantages = advantages / (std_grouped_rewards + 1e-4)
# Slice to keep only the local part of the data # Slice to keep only the local part of the data
if self.args.sequence_parallel_degree > 1: if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to) # Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size sp_group_id = self.accelerator.process_index // self.local_world_size

View File

@@ -5,6 +5,7 @@ import torch
from axolotl.core.trainers.base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
# pylint: disable=too-many-ancestors
class AxolotlMambaTrainer(AxolotlTrainer): class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation""" """Mamba specific trainer to handle loss calculation"""

View File

@@ -5,6 +5,7 @@
from .activation_checkpointing import ActivationOffloadingMixin from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin from .checkpoints import CheckpointSaveMixin
from .distributed_parallel import DistributedParallelMixin
from .optimizer import OptimizerMixin from .optimizer import OptimizerMixin
from .packing import PackingMixin from .packing import PackingMixin
from .rng_state_loader import RngLoaderMixin from .rng_state_loader import RngLoaderMixin

View File

@@ -13,9 +13,11 @@ class CheckpointSaveMixin(Trainer):
def _save_optimizer_and_scheduler(self, output_dir): def _save_optimizer_and_scheduler(self, output_dir):
try: try:
super()._save_optimizer_and_scheduler(output_dir) super()._save_optimizer_and_scheduler(output_dir)
except NotImplementedError as exc: except (NotImplementedError, KeyError) as exc:
LOG.warning( # TODO: fix fsdp2 optimizer saving
LOG.warning_once(
f"Trainer does not support saving optimizer and scheduler: {exc}\n" f"Trainer does not support saving optimizer and scheduler: {exc}\n"
"Optimizer and scheduler states were not saved - resuming from checkpoints " "Optimizer and scheduler states were not saved - resuming from checkpoints "
"for this training run will not be possible." "for this training run will not be possible.",
main_process_only=True,
) )

View File

@@ -0,0 +1,20 @@
"""
Mixin for correctly saving model weights when FSDP (data-parallel sharding) is enabled
"""
from transformers import Trainer
class DistributedParallelMixin(Trainer):
    """
    Trainer mixin that fetches the model state dict through the accelerator
    before saving when data-parallel sharding is enabled.
    """

    def _save(self, output_dir: str | None = None, state_dict=None):
        """Save the model, resolving the state dict via the accelerator if needed.

        Args:
            output_dir: Forwarded unchanged to the parent `_save`.
            state_dict: Optional state dict supplied by the caller. When it is
                provided it is forwarded as-is; when it is `None` and the
                accelerator reports `dp_shard_enabled`, it is obtained from
                `self.accelerator.get_state_dict(self.model)`.
        """
        # Read `parallelism_config` defensively: the attribute is not guaranteed
        # to exist on every accelerate release, and a missing attribute should
        # fall back to the regular save path instead of raising AttributeError.
        parallelism_config = getattr(self.accelerator, "parallelism_config", None)
        if (
            state_dict is None
            and parallelism_config
            and parallelism_config.dp_shard_enabled
        ):
            # NOTE(review): presumably this returns the full, unsharded weights
            # - confirm against the accelerate documentation.
            state_dict = self.accelerator.get_state_dict(self.model)
        super()._save(output_dir, state_dict=state_dict)

View File

@@ -8,13 +8,18 @@ from trl import (
RewardTrainer, RewardTrainer,
) )
from axolotl.core.trainers.mixins import RngLoaderMixin from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class AxolotlORPOTrainer( class AxolotlORPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
ORPOTrainer,
): ):
""" """
Extend the base ORPOTrainer for axolotl helpers Extend the base ORPOTrainer for axolotl helpers
@@ -24,7 +29,12 @@ class AxolotlORPOTrainer(
class AxolotlKTOTrainer( class AxolotlKTOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
KTOTrainer,
): ):
""" """
Extend the base KTOTrainer for axolotl helpers Extend the base KTOTrainer for axolotl helpers
@@ -34,7 +44,12 @@ class AxolotlKTOTrainer(
class AxolotlCPOTrainer( class AxolotlCPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
CPOTrainer,
): ):
""" """
Extend the base CPOTrainer for axolotl helpers Extend the base CPOTrainer for axolotl helpers
@@ -44,7 +59,12 @@ class AxolotlCPOTrainer(
class AxolotlRewardTrainer( class AxolotlRewardTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
RewardTrainer,
): ):
""" """
Extend the base RewardTrainer for axolotl helpers Extend the base RewardTrainer for axolotl helpers
@@ -54,7 +74,12 @@ class AxolotlRewardTrainer(
class AxolotlPRMTrainer( class AxolotlPRMTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
PRMTrainer,
): ):
""" """
Extend the base trl.PRMTrainer for axolotl helpers Extend the base trl.PRMTrainer for axolotl helpers

View File

@@ -21,6 +21,7 @@ from axolotl.core.trainers.base import AxolotlTrainer
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
# pylint: disable=too-many-ancestors
class AxolotlKDTrainer(AxolotlTrainer): class AxolotlKDTrainer(AxolotlTrainer):
""" """
Custom trainer subclass for Knowledge Distillation (KD) Custom trainer subclass for Knowledge Distillation (KD)

View File

@@ -16,8 +16,6 @@
Module for handling LIGER input arguments. Module for handling LIGER input arguments.
""" """
from typing import Optional
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
@@ -30,13 +28,13 @@ class LigerArgs(BaseModel):
Input args for LIGER. Input args for LIGER.
""" """
liger_rope: Optional[bool] = None liger_rope: bool | None = None
liger_rms_norm: Optional[bool] = None liger_rms_norm: bool | None = None
liger_layer_norm: Optional[bool] = None liger_layer_norm: bool | None = None
liger_swiglu: Optional[bool] = None liger_swiglu: bool | None = None
liger_glu_activation: Optional[bool] = None liger_glu_activation: bool | None = None
liger_cross_entropy: Optional[bool] = None liger_cross_entropy: bool | None = None
liger_fused_linear_cross_entropy: Optional[bool] = None liger_fused_linear_cross_entropy: bool | None = None
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -66,3 +64,20 @@ class LigerArgs(BaseModel):
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`." "You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_liger_rms_norm_tensor_parallel(cls, data):
    """Reject `liger_rms_norm` when tensor parallelism is enabled.

    Args:
        data: Raw input mapping handed to the "before" validator.

    Returns:
        The unchanged `data` mapping.

    Raises:
        ValueError: If `liger_rms_norm` is truthy and `tensor_parallel_size`
            is greater than 1.
    """
    # `data.get(key, 1)` still returns None when the key is present with an
    # explicit null value, and `None > 1` raises TypeError. Treat a
    # missing/None size as 1 (tensor parallelism disabled).
    tensor_parallel_size = data.get("tensor_parallel_size") or 1
    if data.get("liger_rms_norm") and tensor_parallel_size > 1:
        raise ValueError(
            "`liger_rms_norm` is incompatible with tensor parallelism, "
            "see https://github.com/linkedin/Liger-Kernel/issues/826"
        )
    return data
@model_validator(mode="after")
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
    """Reject `liger_fused_linear_cross_entropy` when tensor parallelism is enabled.

    Returns:
        The validated model instance, unchanged.

    Raises:
        ValueError: If `tensor_parallel_size` is greater than 1 and
            `liger_fused_linear_cross_entropy` is truthy.
    """
    # TODO @SalmanMohammadi this is a larger fix - investigate
    # `tensor_parallel_size` is not one of the fields declared alongside the
    # liger flags, so it may be absent (AttributeError) or None (TypeError on
    # comparison). Treat both as a size of 1, i.e. tensor parallelism disabled.
    tensor_parallel_size = getattr(self, "tensor_parallel_size", None) or 1
    if tensor_parallel_size > 1 and self.liger_fused_linear_cross_entropy:
        raise ValueError("Tensor parallelism is not compatible with liger losses.")
    return self

View File

@@ -13,7 +13,8 @@ import peft
import torch import torch
import transformers import transformers
import transformers.modeling_utils import transformers.modeling_utils
from accelerate import init_empty_weights from accelerate import PartialState, init_empty_weights
from accelerate.parallelism_config import ParallelismConfig
from peft import ( from peft import (
PeftConfig, PeftConfig,
PeftMixedModel, PeftMixedModel,
@@ -48,10 +49,7 @@ from axolotl.loaders.utils import (
from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import ( from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
get_device_count,
get_device_type,
)
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType from axolotl.utils.schemas.enums import RLType
@@ -87,6 +85,9 @@ class ModelLoader:
`AutoModelForCausalLM`). `AutoModelForCausalLM`).
""" """
use_parallel_config: bool | None = False
parallelism_config: ParallelismConfig | None = None
def __init__( def __init__(
self, self,
cfg: DictDefault, cfg: DictDefault,
@@ -183,6 +184,20 @@ class ModelLoader:
def _apply_pre_model_load_setup(self): def _apply_pre_model_load_setup(self):
"""Apply patches and setup configurations before model loading.""" """Apply patches and setup configurations before model loading."""
if self.use_parallel_config is not None:
self.use_parallel_config = (
self.cfg.fsdp_config
or (self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1)
or (
self.cfg.context_parallel_size
and self.cfg.context_parallel_size > 1
)
)
if self.cfg.fsdp_config and self.cfg.fsdp_version != 2:
self.use_parallel_config = False
if self.use_parallel_config:
self._set_parallel_config()
self._set_auto_model_loader() self._set_auto_model_loader()
self._set_device_map_config() self._set_device_map_config()
if self.cfg.revision_of_model: if self.cfg.revision_of_model:
@@ -390,6 +405,86 @@ class ModelLoader:
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@staticmethod
def _get_parallel_config_kwargs(
world_size: int,
tensor_parallel_size: int = 1,
context_parallel_size: int = 1,
dp_shard_size: int | None = None,
dp_replicate_size: int | None = None,
is_fsdp: bool = False,
):
pc_kwargs = {}
remaining_world_size = world_size
if tensor_parallel_size and tensor_parallel_size > 1:
pc_kwargs["tp_size"] = tensor_parallel_size
remaining_world_size = remaining_world_size // tensor_parallel_size
if context_parallel_size and context_parallel_size > 1:
pc_kwargs["cp_size"] = context_parallel_size
remaining_world_size = remaining_world_size // context_parallel_size
if dp_shard_size is None and dp_replicate_size in (None, 1):
if remaining_world_size > 1:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if dp_replicate_size and dp_replicate_size > 1:
pc_kwargs["dp_replicate_size"] = dp_replicate_size
remaining_world_size = remaining_world_size // dp_replicate_size
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
if not is_fsdp:
raise ValueError(
"dp_shard_size was configured without a corresponding fsdp_config! "
"Please ensure you have configured FSDP using fsdp_config."
)
pc_kwargs["dp_shard_size"] = dp_shard_size
remaining_world_size = remaining_world_size // dp_shard_size
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
pc_kwargs["dp_replicate_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
if "dp_shard_size" not in pc_kwargs and is_fsdp:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
raise ValueError(
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
f"{pc_kwargs}"
)
return pc_kwargs
def _set_parallel_config(self):
    """Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
    world_size = get_world_size()
    cfg = self.cfg
    pc_kwargs = ModelLoader._get_parallel_config_kwargs(
        world_size,
        tensor_parallel_size=cfg.tensor_parallel_size,
        context_parallel_size=cfg.context_parallel_size,
        dp_shard_size=cfg.dp_shard_size,
        dp_replicate_size=cfg.dp_replicate_size,
        is_fsdp=bool(cfg.fsdp or cfg.fsdp_config),
    )

    # Nothing to configure when no parallel dimension is larger than 1.
    if not pc_kwargs:
        return

    self.parallelism_config = ParallelismConfig(**pc_kwargs)
    device_mesh = self.parallelism_config.build_device_mesh("cuda")

    # Store both objects in PartialState's shared state dict.
    # pylint: disable=protected-access
    shared_state = PartialState()._shared_state
    shared_state["parallelism_config"] = self.parallelism_config
    shared_state["device_mesh"] = device_mesh
def _set_auto_model_loader(self): def _set_auto_model_loader(self):
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM` """Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
(set at `__init__`). When using a multimodal model, `self.auto_model_loader` (set at `__init__`). When using a multimodal model, `self.auto_model_loader`
@@ -622,6 +717,14 @@ class ModelLoader:
def _build_model(self) -> bool: def _build_model(self) -> bool:
"""Load model, with load strategy depending on config.""" """Load model, with load strategy depending on config."""
skip_move_to_device = False skip_move_to_device = False
if self.cfg.tensor_parallel_size > 1:
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
self.model_kwargs["tp_plan"] = "auto"
self.model_kwargs["device_mesh"] = PartialState().device_mesh
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
if self.is_fsdp_enabled: if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading: if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True skip_move_to_device = True
@@ -734,6 +837,14 @@ class ModelLoader:
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
skip_move_to_device = True skip_move_to_device = True
# pylint: disable=protected-access
if self.cfg.tensor_parallel_size > 1:
# workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
# TODO(wing): remove once 4.54.1 is released
if self.model._tp_size != self.cfg.tensor_parallel_size:
self.model._tp_size = self.cfg.tensor_parallel_size
self.model._device_mesh = self.model_kwargs["device_mesh"]
return skip_move_to_device return skip_move_to_device
def _set_z3_leaf_modules(self): def _set_z3_leaf_modules(self):

View File

@@ -49,6 +49,7 @@ class PatchManager:
def apply_pre_model_load_patches(self): def apply_pre_model_load_patches(self):
"""Apply pre-model load patches based on config.""" """Apply pre-model load patches based on config."""
self._apply_transformers_patches()
# self._apply_flex_attention_patches() # self._apply_flex_attention_patches()
self._apply_flash_attention_patches() self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch() self._apply_chunked_cross_entropy_patch()
@@ -64,13 +65,19 @@ class PatchManager:
self._patch_llama_derived_model() self._patch_llama_derived_model()
self._apply_mistral_cross_entropy_patch() self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch() self._apply_self_attention_lora_patch()
self._apply_sequence_parallel_patches()
def apply_post_plugin_pre_model_load_patches(self): def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config.""" """Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type) self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_voxtral_patches() self._apply_voxtral_patches()
def _apply_transformers_patches(self):
    """Call `patch_prepare_from_posids` from axolotl's transformers monkeypatch module."""
    # Imported lazily, like the other patch helpers in this class.
    from axolotl.monkeypatch.transformers import (
        modeling_flash_attention_utils as flash_attn_utils_patch,
    )

    flash_attn_utils_patch.patch_prepare_from_posids()
def apply_post_model_load_patches(self, model: PreTrainedModel): def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance.""" """Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model) self._apply_llama_flash_attn_patches(model)
@@ -253,17 +260,6 @@ class PatchManager:
has_remote_code=has_remote_code, has_remote_code=has_remote_code,
) )
def _apply_sequence_parallel_patches(self):
    """Apply sequence parallelism patches."""
    sp_degree = self.cfg.sequence_parallel_degree
    # Only patch when sequence parallelism is actually enabled (degree > 1).
    if not (sp_degree and sp_degree > 1):
        return

    from axolotl.monkeypatch.ring_attn import patch as ring_attn_patch

    ring_attn_patch.patch_prepare_data_loader()
    ring_attn_patch.patch_prepare_device_mesh(sp_degree, self.cfg.fsdp)
def _apply_tiled_mlp(self, model_type: str): def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp: if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import ( from axolotl.monkeypatch.tiled_mlp import (

View File

@@ -249,13 +249,19 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
auto_wrap_policy=fsdp2_plugin.auto_wrap_policy, auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
) )
mesh = getattr(accelerator.state, "device_mesh", None)
fsdp2_kwargs = { fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward, "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload, "offload_policy": fsdp2_plugin.cpu_offload,
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy` # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(), "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
"mesh": (
mesh[tuple(accelerator.state.parallelism_config.fsdp_dim_names)]
if mesh is not None
else None
),
} }
model_has_params4bit = False model_has_params4bit = False
for _, param in model.named_parameters(): for _, param in model.named_parameters():
# this is a temporary fix whereby loading models with bnb params cannot be moved from # this is a temporary fix whereby loading models with bnb params cannot be moved from

View File

@@ -5,18 +5,14 @@
from .patch import ( from .patch import (
get_ring_attn_group, get_ring_attn_group,
patch_prepare_data_loader, register_ring_attn_from_device_mesh,
patch_prepare_device_mesh,
register_ring_attn,
set_ring_attn_group, set_ring_attn_group,
update_ring_attn_params, update_ring_attn_params,
) )
__all__ = ( __all__ = (
"get_ring_attn_group", "get_ring_attn_group",
"patch_prepare_data_loader", "register_ring_attn_from_device_mesh",
"patch_prepare_device_mesh",
"register_ring_attn",
"set_ring_attn_group", "set_ring_attn_group",
"update_ring_attn_params", "update_ring_attn_params",
) )

View File

@@ -8,13 +8,12 @@ We also provide some patches for accelerate functions to prepare the dataloader
sequence parallelism training. sequence parallelism training.
""" """
import inspect
import os import os
from typing import Optional from typing import Optional
import accelerate
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import DeviceMesh
try: try:
from transformers.modeling_flash_attention_utils import _flash_supports_window from transformers.modeling_flash_attention_utils import _flash_supports_window
@@ -29,39 +28,13 @@ from axolotl.utils.schemas.enums import RingAttnFunc
LOG = get_logger(__name__) LOG = get_logger(__name__)
RING_ATTN_GROUP = None RING_ATTN_GROUP = None
ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // submesh_tp_size"""
NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
submesh_cp_size = 1
if "cp" in torch_device_mesh.mesh_dim_names:
submesh_cp_size = torch_device_mesh["cp"].size()
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // (submesh_tp_size * submesh_cp_size)"""
def get_ring_attn_group() -> dist.ProcessGroup: def get_ring_attn_group() -> dist.ProcessGroup:
"""Getter for ring attention group on this rank.""" """Getter for ring attention group on this rank."""
if RING_ATTN_GROUP is None: if RING_ATTN_GROUP is None:
raise RuntimeError("register_ring_attn() not yet called") raise RuntimeError("register_ring_attn_from_device_mesh() not yet called")
return RING_ATTN_GROUP return RING_ATTN_GROUP
@@ -161,15 +134,17 @@ def create_ring_flash_attention_forward(
] ]
def register_ring_attn( def register_ring_attn_from_device_mesh(
sequence_parallel_degree: int, device_mesh: "DeviceMesh",
context_parallel_dim: tuple[str, ...],
heads_k_stride: int | None, heads_k_stride: int | None,
ring_attn_func: RingAttnFunc | None, ring_attn_func: RingAttnFunc | None,
): ):
"""Create ring attention group and substitute flash attn with ring flash attn. """Create ring attention group using DeviceMesh and substitute flash attn with ring flash attn.
Args: Args:
sequence_parallel_degree: Sequence parallelism factor. device_mesh: DeviceMesh object containing the parallelism topology.
context_parallel_dim: Name of the sequence parallel dimension in the device mesh.
heads_k_stride: Sequence parallelism K head stride size. Passed through to heads_k_stride: Sequence parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation. `varlen_llama3` `ring_flash_attn` implementation.
ring_attn_func: `ring_flash_attn` ring attention implementation. If sample ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
@@ -177,44 +152,39 @@ def register_ring_attn(
`batch` function. `batch` function.
""" """
rank = dist.get_rank() rank = dist.get_rank()
world_size = dist.get_world_size()
LOG.info(
f"Enabling ring attention sequence parallelism using DeviceMesh "
f"dimension '{context_parallel_dim}'",
main_process_only=True,
)
# Extract the sequence parallel submesh
try:
sequence_mesh = device_mesh[context_parallel_dim]
except (KeyError, IndexError) as e:
raise ValueError(
f"Dimension '{context_parallel_dim}' not found in device_mesh. "
f"Available dimensions: {device_mesh.mesh_dim_names}"
) from e
# Get the process group for context parallelism
sequence_pg = sequence_mesh.get_group()
context_parallel_size = sequence_mesh.size()
if rank == 0: if rank == 0:
LOG.info( LOG.info(
"Enabling ring attention sequence parallelism: " f"Sequence parallel degree: {context_parallel_size}, "
f"each sequence will be processed across {sequence_parallel_degree} GPUs" f"mesh shape: {sequence_mesh.mesh.shape}"
) )
assert sequence_parallel_degree <= world_size, ( # Log which ranks are in the current process group
f"sequence_parallel_degree ({sequence_parallel_degree}) " if sequence_pg != dist.GroupMember.WORLD:
f"must be less than or equal to world_size ({world_size})" ranks_in_group = dist.get_process_group_ranks(sequence_pg)
) LOG.info(f"Current sequence parallel group ranks: {ranks_in_group}")
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"
)
# Assign ranks to sequence parallel groups # Set the ring attention group
group_assignments = {} set_ring_attn_group(sequence_pg)
for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list(
range(
i * sequence_parallel_degree,
(i + 1) * sequence_parallel_degree,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
# Track which GPUs are in which groups
for r in ring_attn_ranks:
group_assignments[r] = i
if rank in ring_attn_ranks:
set_ring_attn_group(group)
# Log the GPU group assignments
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3: if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
# fmt: off # fmt: off
@@ -257,92 +227,3 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids) cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device()) cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
def patch_prepare_data_loader():
    """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.

    The function's source is fetched, the target snippet is swapped for the
    replacement snippet, the result is re-executed, and the resulting code
    object is installed on the original function.

    Raises:
        RuntimeError: If source code to patch does not exist.
    """
    target_fn = accelerate.data_loader.prepare_data_loader
    source = inspect.getsource(target_fn)

    if ORIGINAL_PREPARE_DATALOADER_CODE not in source:
        raise RuntimeError(
            "SP patch failed - target snippet not found. "
            "Check accelerate's version or update the patch."
        )

    new_source = source.replace(
        ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
    )

    # Every name from accelerate.data_loader that occurs (as a substring) in the
    # patched source is imported into this module's globals so that the
    # re-executed function can resolve it.
    needed_names = [
        name for name in dir(accelerate.data_loader) if name in new_source
    ]
    import_stmt = f"from accelerate.data_loader import ({', '.join(needed_names)})"
    exec(import_stmt, globals())  # pylint: disable=exec-used  # nosec B102

    # Create a new function from the patched source
    local_ns = {}
    exec(new_source, globals(), local_ns)  # pylint: disable=exec-used  # nosec B102

    # Swap the code object in place so existing references see the patched body.
    target_fn.__code__ = local_ns["prepare_data_loader"].__code__
    LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
    """Replace `Accelerator._prepare_device_mesh` with a version whose device mesh
    includes a sequence parallel dimension of the given degree.

    Args:
        sequence_parallel_degree: The degree of sequence parallelism to use.
        fsdp: Whether to use FSDP.
    """

    def _prepare_device_mesh(self):
        """Prepare the device mesh for distributed training. The dataloader will
        determine how to load data based on the device mesh.
        """
        tp_plugin = self.state.torch_tp_plugin
        if tp_plugin:
            return tp_plugin.torch_device_mesh

        uses_deepspeed = (
            self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED
        )
        if uses_deepspeed and hasattr(self.state, "ds_device_mesh"):
            return self.state.ds_device_mesh

        # Create device mesh with sequence parallelism: ranks are laid out
        # row-major as (world_size // degree, degree).
        world_size = dist.get_world_size()
        rank_grid = torch.tensor(list(range(world_size))).reshape(
            (world_size // sequence_parallel_degree, sequence_parallel_degree)
        )

        # NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
        # parallelism" implementation naming.
        # NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
        # only use "fsdp" and "cp" for the device mesh.
        dim_names = ("fsdp", "cp") if fsdp else ("dp", "cp")
        return dist.DeviceMesh("cuda", rank_grid, mesh_dim_names=dim_names)

    # Replace the original method with our new method
    # pylint: disable=protected-access
    accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh

    LOG.info(
        "Successfully patched Accelerator._prepare_device_mesh "
        f"with sequence_parallel_degree={sequence_parallel_degree}"
    )

View File

@@ -0,0 +1,87 @@
"""
Monkey patch to fix transformers.modeling_flash_attention_utils.
see https://github.com/huggingface/transformers/pull/39653/files
"""
import sys
import torch
def _prepare_from_posids(query, key, value, position_ids):
    """
    Build the arguments needed to call `flash_attn_varlen_func` from position ids.

    The leading dimensions of `query`, `key` and `value` are merged into one, and
    the cumulative sequence lengths are derived from `position_ids`: after
    flattening, every entry equal to 0 is taken as the first token of a sequence.

    NOTE: ideally cumulative lengths should be prepared at the data collator stage

    Arguments:
        query (`torch.Tensor`):
            Query states. Presumably (batch_size, query_length, num_heads, head_dim),
            as in the upstream transformers docs; only the last two dims are kept.
        key (`torch.Tensor`):
            Key states, flattened the same way as `query`.
        value (`torch.Tensor`):
            Value states, flattened the same way as `query`.
        position_ids (`torch.Tensor`):
            Per-token position ids; flattened before use.

    Return:
        query, key, value (`torch.Tensor`):
            The inputs reshaped to (-1, size(-2), size(-1)).
        indices_q (`torch.Tensor`):
            int32 `arange` over the flattened token positions.
        (cu_seqlens_q, cu_seqlens_k) (`tuple[torch.Tensor]`):
            The same int32 tensor twice: start index of every sequence followed by
            the total token count.
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
            The same Python int twice: the largest gap between consecutive entries
            of the cumulative lengths.
    """

    def _merge_leading_dims(states):
        # Keep the last two dims; collapse everything in front of them.
        return states.contiguous().view(-1, states.size(-2), states.size(-1))

    query, key, value = (_merge_leading_dims(t) for t in (query, key, value))

    flat_positions = position_ids.flatten()
    total_tokens = flat_positions.size(0)
    device = flat_positions.device

    indices_q = torch.arange(total_tokens, device=device, dtype=torch.int32)
    sequence_starts = indices_q[flat_positions == 0]
    end_marker = torch.tensor([total_tokens], device=device, dtype=torch.int32)
    cu_seq_lens = torch.cat((sequence_starts, end_marker))

    # NOTE: With torch compile, this will cause a graph break if you don't set
    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
    # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
    # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
    # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
    # for some models (e.g. qwen2-vl).
    max_length = cu_seq_lens.diff().max().item()

    cu_seq_lens_pair = (cu_seq_lens, cu_seq_lens)
    max_length_pair = (max_length, max_length)
    return (query, key, value, indices_q, cu_seq_lens_pair, max_length_pair)
def patch_prepare_from_posids():
    """Point `transformers.modeling_flash_attention_utils._prepare_from_posids`
    at the `_prepare_from_posids` defined in this module.

    The attribute is set both through the imported module and through its
    `sys.modules` entry.
    """
    import transformers.modeling_flash_attention_utils as flash_attn_utils

    target_module = "transformers.modeling_flash_attention_utils"
    attr_name = "_prepare_from_posids"

    # pylint: disable=protected-access
    flash_attn_utils._prepare_from_posids = _prepare_from_posids
    setattr(sys.modules[target_module], attr_name, _prepare_from_posids)

View File

@@ -205,7 +205,7 @@ def execute_training(
) )
) )
if cfg.sequence_parallel_degree > 1: if cfg.context_parallel_size > 1:
models = [trainer.model] models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model: if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model) models.append(trainer.ref_model)
@@ -213,7 +213,7 @@ def execute_training(
stack.enter_context( stack.enter_context(
SequenceParallelContextManager( SequenceParallelContextManager(
models=models, models=models,
sequence_parallel_degree=cfg.sequence_parallel_degree, context_parallel_size=cfg.context_parallel_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps, gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func, ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride, heads_k_stride=cfg.heads_k_stride,

View File

@@ -57,10 +57,10 @@ def gpu_memory_usage(device=0):
@check_cuda_device((0.0, 0.0, 0.0)) @check_cuda_device((0.0, 0.0, 0.0))
def gpu_memory_usage_all(device=0): def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3 active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3 allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
smi = gpu_memory_usage_smi(device) reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
return usage, reserved - usage, max(0, smi - reserved) return active, allocated, reserved
def mps_memory_usage_all(): def mps_memory_usage_all():
@@ -92,27 +92,38 @@ def gpu_memory_usage_smi(device=0):
return 0.0 return 0.0
def log_gpu_memory_usage( def get_gpu_memory_usage(device: int | torch.device = 0):
log: logging.Logger | logging.LoggerAdapter,
msg: str = "",
device: int | torch.device = 0,
):
cur_device_type = str(get_device_type()) cur_device_type = str(get_device_type())
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all() usage, cache, misc = mps_memory_usage_all()
elif "npu" in cur_device_type and is_torch_npu_available(): elif "npu" in cur_device_type and is_torch_npu_available():
usage, cache, misc = npu_memory_usage_all(device) usage, cache, misc = npu_memory_usage_all(device)
elif "gpu" in cur_device_type and torch.cuda.is_available(): elif "cuda" in cur_device_type and torch.cuda.is_available():
usage, cache, misc = gpu_memory_usage_all(device) usage, cache, misc = gpu_memory_usage_all(device)
else: else:
return 0.0, 0.0, 0.0
return usage, cache, misc
def log_gpu_memory_usage(
log: logging.Logger | logging.LoggerAdapter,
msg: str = "",
device: int | torch.device = 0,
):
try:
active, allocated, reserved = get_gpu_memory_usage(device)
except ValueError:
# likely CPU, ignore
return return
cur_device_type = str(get_device_type())
extras = [] extras = []
if cache > 0: if allocated > 0:
extras.append(f"+{cache:.03f}GB cache") extras.append(f"+{allocated:.03f}GB allocated")
if misc > 0: if reserved > 0:
extras.append(f"+{misc:.03f}GB misc") extras.append(f"+{reserved:.03f}GB reserved")
msg = f"{cur_device_type} memory usage:" if not msg else msg msg = f"{cur_device_type} memory active:" if not msg else msg
log.info( log.debug(
f"{msg} {usage:.03f}GB ({', '.join(extras)})", f"{msg} {active:.03f}GB ({', '.join(extras)})",
stacklevel=2, stacklevel=2,
) )

View File

@@ -35,7 +35,7 @@ from transformers.trainer_utils import (
from trl.models import unwrap_model_for_generation from trl.models import unwrap_model_for_generation
from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import get_gpu_memory_usage, log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.distributed import ( from axolotl.utils.distributed import (
barrier, barrier,
@@ -100,7 +100,6 @@ class GPUStatsCallback(
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self.logged = False
def on_step_end( def on_step_end(
self, self,
@@ -109,9 +108,21 @@ class GPUStatsCallback(
control: TrainerControl, control: TrainerControl,
**kwargs, **kwargs,
) -> TrainerControl: ) -> TrainerControl:
if not self.logged and state.global_step > 1: if state.global_step > 0:
log_gpu_memory_usage(LOG, "while training", self.cfg.device) if self.cfg.use_wandb and state.is_world_process_zero:
self.logged = True try:
active, allocated, reserved = get_gpu_memory_usage()
wandb.log(
{
"memory/max_memory_active": active,
"memory/max_memory_allocated": allocated,
"memory/device_memory_reserved": reserved,
},
step=state.global_step,
)
except ValueError:
pass
log_gpu_memory_usage(LOG, "", self.cfg.device)
return control return control

View File

@@ -5,6 +5,7 @@ import inspect
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate import PartialState
from torch import nn from torch import nn
from torch.utils.hooks import RemovableHandle from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -12,7 +13,7 @@ from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import ( from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group, get_ring_attn_group,
register_ring_attn, register_ring_attn_from_device_mesh,
update_ring_attn_params, update_ring_attn_params,
) )
from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.utils.schemas.enums import RingAttnFunc
@@ -150,9 +151,18 @@ def apply_sequence_parallelism(
if "num_items_in_batch" in batch: if "num_items_in_batch" in batch:
# Approximation; this needed since num_items_in_batch may be counted across # Approximation; this needed since num_items_in_batch may be counted across
# all samples in a gradient accumulated batch, not on a per-step basis. # all samples in a gradient accumulated batch, not on a per-step basis.
local_valid_tokens = (batch["labels"] != -100).sum()
# All-reduce across sequence parallel ranks to get global token count
cp_group = get_ring_attn_group()
global_valid_tokens = local_valid_tokens.clone()
# we use AVG instead of SUM as using sum seems to scale down the loss by over-accounting the number of tokens
dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.AVG, group=cp_group)
global_valid_tokens = int(global_valid_tokens.item())
batch["num_items_in_batch"] = ( batch["num_items_in_batch"] = (
batch["labels"] != -100 global_valid_tokens * gradient_accumulation_steps
).sum() * gradient_accumulation_steps )
return batch, original_seq_len, pad_len return batch, original_seq_len, pad_len
@@ -167,7 +177,7 @@ class SequenceParallelContextManager:
Args: Args:
models: List of models to apply sequence parallelism to pre- and post- forward models: List of models to apply sequence parallelism to pre- and post- forward
hooks. hooks.
sequence_parallel_degree: Number of processes to split sequences over. context_parallel_size: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over. gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused. ring_attn_func: Which ring attention function to use. Currently unused.
heads_k_stride: Sequence parallelism K head stride size. Passed through to heads_k_stride: Sequence parallelism K head stride size. Passed through to
@@ -179,14 +189,14 @@ class SequenceParallelContextManager:
def __init__( def __init__(
self, self,
models: list[nn.Module], models: list[nn.Module],
sequence_parallel_degree: int, context_parallel_size: int,
gradient_accumulation_steps: int, gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc, ring_attn_func: RingAttnFunc,
heads_k_stride: int | None, heads_k_stride: int | None,
gather_outputs: bool, gather_outputs: bool,
): ):
self.models = models self.models = models
self.sequence_parallel_degree = sequence_parallel_degree self.context_parallel_size = context_parallel_size
self.gradient_accumulation_steps = gradient_accumulation_steps self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride self.heads_k_stride = heads_k_stride
@@ -230,8 +240,10 @@ class SequenceParallelContextManager:
def _register_ring_attn(self): def _register_ring_attn(self):
# Initialize ring attn for sequence parallelism # Initialize ring attn for sequence parallelism
register_ring_attn( partial_state = PartialState()
sequence_parallel_degree=self.sequence_parallel_degree, register_ring_attn_from_device_mesh(
device_mesh=partial_state.device_mesh,
context_parallel_dim=("cp",),
heads_k_stride=self.heads_k_stride, heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func, ring_attn_func=self.ring_attn_func,
) )

View File

@@ -430,10 +430,11 @@ def save_preprocessed_dataset(
num_shards=cfg.num_dataset_shards_to_save, num_shards=cfg.num_dataset_shards_to_save,
) )
else: else:
min_rows_per_proc = 256
os.makedirs(prepared_ds_path, exist_ok=True) os.makedirs(prepared_ds_path, exist_ok=True)
dataset.save_to_disk( dataset.save_to_disk(
str(prepared_ds_path), str(prepared_ds_path),
num_proc=min(max(1, len(dataset) // 8), num_workers), num_proc=min(max(1, len(dataset) // min_rows_per_proc), num_workers),
max_shard_size=None, max_shard_size=None,
num_shards=cfg.num_dataset_shards_to_save, num_shards=cfg.num_dataset_shards_to_save,
) )

View File

@@ -2,12 +2,15 @@
utils to get GPU info for the current environment utils to get GPU info for the current environment
""" """
from importlib.metadata import version
from accelerate.utils.environment import ( from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support, check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
) )
from accelerate.utils.environment import ( from accelerate.utils.environment import (
get_gpu_info, get_gpu_info,
) )
from packaging.version import Version, parse
def check_cuda_p2p_ib_support(): def check_cuda_p2p_ib_support():
@@ -26,3 +29,13 @@ def check_cuda_p2p_ib_support():
except Exception: # pylint: disable=broad-except # nosec except Exception: # pylint: disable=broad-except # nosec
pass pass
return True return True
def get_package_version(package: str) -> Version:
    """Return the installed version of ``package`` as a parsed ``Version``.

    The raw version string comes from ``importlib.metadata.version`` and is
    parsed with ``packaging.version.parse``.
    """
    return parse(version(package))
def is_package_version_ge(package: str, version_: str) -> bool:
    """Report whether the installed ``package`` is at least ``version_``.

    Args:
        package: Distribution name to look up.
        version_: Minimum version, given as a version string.

    Returns:
        ``True`` when the installed version compares ``>=`` the parsed
        ``version_``.
    """
    installed = get_package_version(package)
    required = parse(version_)
    return installed >= required

View File

@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
import gc import gc
import math import math
import time
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context from multiprocessing import cpu_count, get_context
from typing import Iterable, Iterator, Union from typing import Iterable, Iterator, Union
@@ -453,7 +454,10 @@ class MultipackBatchSampler(BatchSampler):
_sampled_lens = [] _sampled_lens = []
for _ in range(self.num_count_samples): for _ in range(self.num_count_samples):
self._batches = None # Reset cached batches self._batches = None # Reset cached batches
# log timer for generating batches
start_time = time.time()
_sampled_lens.append(len(self.generate_batches(set_stats=False))) _sampled_lens.append(len(self.generate_batches(set_stats=False)))
LOG.debug(f"generate_batches time: {time.time() - start_time}")
len_batches = min(_sampled_lens) len_batches = min(_sampled_lens)
# Gather minimum across all ranks # Gather minimum across all ranks

View File

@@ -651,7 +651,23 @@ class AxolotlInputConfig(
}, },
) )
dp_shard_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of devices to shard across. If not set, will use all available devices."
},
)
dp_replicate_size: int | None = Field(
default=None,
json_schema_extra={"description": "Number of devices to replicate across."},
)
sequence_parallel_degree: int | None = Field( sequence_parallel_degree: int | None = Field(
default=None,
json_schema_extra={
"description": "Deprecated: use `context_parallel_size` instead"
},
)
context_parallel_size: int | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details." "description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."

View File

@@ -673,7 +673,7 @@ class RLValidationMixin:
data.get("rl") == "grpo" data.get("rl") == "grpo"
and data.get("trl", {}) and data.get("trl", {})
and data.get("trl").get("use_liger_loss") and data.get("trl").get("use_liger_loss")
and data.get("sequence_parallel_degree", 1) > 1 and data.get("context_parallel_size", 1) > 1
): ):
raise ValueError("GRPO + SP + Liger not currently supported") raise ValueError("GRPO + SP + Liger not currently supported")
return data return data
@@ -880,51 +880,35 @@ class OptimizationValidationMixin:
return self return self
@model_validator(mode="after")
def check_fsdp_sharded_state_dict_w_safetensors(self):
    """Reject SHARDED_STATE_DICT together with save_safetensors.

    The check only fires when ``fsdp_config`` and ``save_safetensors`` are
    both set, ``fsdp_config["state_dict_type"]`` is ``SHARDED_STATE_DICT``
    and ``fsdp_version`` (treated as "1" when absent) is not "2".

    Raises:
        ValueError: if the incompatible combination is configured.
    """
    fsdp_config = getattr(self, "fsdp_config", None)
    if not fsdp_config:
        return self
    if not getattr(self, "save_safetensors", None):
        return self

    wants_sharded = fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
    is_fsdp2 = str(getattr(self, "fsdp_version", "1")) == "2"
    if wants_sharded and not is_fsdp2:
        raise ValueError(
            "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
        )
    return self
@model_validator(mode="before")
@classmethod
def check_tensor_parallel_size_update_ds_json(cls, data):
    """Inject AutoTP settings into the DeepSpeed JSON when TP is enabled.

    When ``tensor_parallel_size`` is greater than 1 and a ``deepspeed``
    config path is set, the JSON is loaded and, if missing, given a
    ``tensor_parallel.autotp_size`` entry and
    ``zero_optimization.gather_16bit_weights_on_model_save = True``. If
    anything was added, the patched config is written to a fresh temporary
    directory and ``data["deepspeed"]`` is pointed at that copy; the
    original file is never modified.

    Args:
        data: Raw (pre-validation) config mapping.

    Returns:
        The same ``data`` mapping, possibly with ``deepspeed`` rewritten.
    """
    tensor_parallel_size = data.get("tensor_parallel_size")
    ds_path = data.get("deepspeed")
    tp_enabled = tensor_parallel_size is not None and tensor_parallel_size > 1
    if not (tp_enabled and ds_path):
        return data

    # Read fully and close the handle before doing any further work.
    with open(ds_path, "r", encoding="utf-8") as ds_fin:
        ds_config = json.load(ds_fin)

    should_save = False
    if "tensor_parallel" not in ds_config:
        ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
        should_save = True

    # A DeepSpeed config without a "zero_optimization" section used to raise
    # KeyError here; create the section on demand instead.
    zero_optimization = ds_config.setdefault("zero_optimization", {})
    if "gather_16bit_weights_on_model_save" not in zero_optimization:
        zero_optimization["gather_16bit_weights_on_model_save"] = True
        should_save = True

    if should_save:
        # The directory is intentionally not cleaned up here: the returned
        # path has to stay readable after this validator returns.
        patched_path = Path(tempfile.mkdtemp()) / "autotp_ds.json"
        with open(patched_path, "w", encoding="utf-8") as ds_fout:
            json.dump(ds_config, ds_fout, indent=4)
        data["deepspeed"] = str(patched_path)
    return data
@@ -1205,13 +1189,18 @@ class ComplexValidationMixin:
return self return self
@model_validator(mode="after") @model_validator(mode="after")
def check_sequence_parallel_degree(self): def check_context_parallel_size(self):
if not self.sequence_parallel_degree: if self.sequence_parallel_degree and not self.context_parallel_size:
self.sequence_parallel_degree = 1 LOG.warning(
elif self.sequence_parallel_degree > 1: "`sequence_parallel_degree` is deprecated, use `context_parallel_size`"
)
self.context_parallel_size = self.sequence_parallel_degree
if not self.context_parallel_size:
self.context_parallel_size = 1
elif self.context_parallel_size > 1:
if not self.flash_attention: if not self.flash_attention:
raise ValueError( raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1" "flash_attention: true must be set with context_parallel_size > 1"
) )
if self.sample_packing and self.micro_batch_size > 1: if self.sample_packing and self.micro_batch_size > 1:
@@ -1221,17 +1210,23 @@ class ComplexValidationMixin:
) )
try: try:
import transformers.modeling_flash_attention_utils
# pylint: disable=protected-access
transformers.modeling_flash_attention_utils._flash_supports_window_size = (
transformers.modeling_flash_attention_utils._flash_supports_window
)
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception: except ImportError as exception:
raise ImportError( raise ImportError(
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. " "context_parallel_size > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn] " "Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`." "or `pip install ring-flash-attn>=0.1.4`."
) from exception ) from exception
LOG.warning( LOG.warning(
"Sequence parallelism (SP) is enabled with " "Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={self.sequence_parallel_degree}. " f"context_parallel_size={self.context_parallel_size}. "
"Please note that logged losses may differ slightly to the non-SP " "Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. " "losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
@@ -1242,7 +1237,7 @@ class ComplexValidationMixin:
@model_validator(mode="after") @model_validator(mode="after")
def validate_ring_attn_func(self): def validate_ring_attn_func(self):
if getattr(self, "sequence_parallel_degree", 1) == 1: if getattr(self, "context_parallel_size", 1) == 1:
return self return self
if self.ring_attn_func is not None: if self.ring_attn_func is not None:
@@ -1259,6 +1254,20 @@ class ComplexValidationMixin:
return self return self
class DistributedValidationMixin:
    """validation for distributed training."""

    @model_validator(mode="after")
    def check_tensor_parallel_optimizer(self):
        """Reject 8-bit AdamW optimizers when tensor parallelism is enabled.

        Raises:
            ValueError: if ``tensor_parallel_size`` is greater than 1 and the
                optimizer is ``paged_adamw_8bit``, ``adamw_8bit`` or
                ``adamw_bnb_8bit``.
        """
        # An unset (None) size means "no tensor parallelism"; comparing None
        # with an int would raise TypeError instead of validating.
        tensor_parallel_size = getattr(self, "tensor_parallel_size", None) or 1
        if tensor_parallel_size > 1 and self.optimizer in [
            "paged_adamw_8bit",
            "adamw_8bit",
            "adamw_bnb_8bit",
        ]:
            raise ValueError(
                "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
            )
        return self
# pylint: disable=too-many-ancestors # pylint: disable=too-many-ancestors
class ValidationMixin( class ValidationMixin(
DatasetValidationMixin, DatasetValidationMixin,

View File

@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1 - 1
) )
* cfg.num_epochs * cfg.num_epochs
* cfg.sequence_parallel_degree * cfg.context_parallel_size
* cfg.tensor_parallel_size * cfg.tensor_parallel_size
) )
LOG.debug( LOG.debug(
@@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.floor( math.floor(
data_loader_len data_loader_len
* cfg.num_epochs * cfg.num_epochs
* cfg.sequence_parallel_degree * cfg.context_parallel_size
* cfg.tensor_parallel_size * cfg.tensor_parallel_size
) )
) )
@@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.ceil( math.ceil(
len(train_dataset) len(train_dataset)
* cfg.num_epochs * cfg.num_epochs
* cfg.sequence_parallel_degree * cfg.context_parallel_size
* cfg.tensor_parallel_size * cfg.tensor_parallel_size
/ cfg.batch_size / cfg.batch_size
) )

View File

@@ -64,7 +64,7 @@ def fixture_base_cfg():
"dataloader_num_workers": 1, "dataloader_num_workers": 1,
"dataloader_pin_memory": True, "dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2, "dataloader_prefetch_factor": 2,
"sequence_parallel_degree": 1, "context_parallel_size": 1,
"tensor_parallel_size": 1, "tensor_parallel_size": 1,
# Dtype # Dtype
"fp16": False, "fp16": False,

View File

@@ -67,7 +67,7 @@ class TestSequenceParallelism:
"logging_steps": 1, "logging_steps": 1,
"weight_decay": 0.0, "weight_decay": 0.0,
"use_tensorboard": True, "use_tensorboard": True,
"sequence_parallel_degree": 2, "context_parallel_size": 2,
"ring_attn_func": ring_attn_func, "ring_attn_func": ring_attn_func,
"save_first_step": False, "save_first_step": False,
} }
@@ -105,13 +105,13 @@ class TestSequenceParallelism:
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func (True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func (False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
# (False, 2, True, "batch_zigzag", 2.5), # (False, 2, True, "batch_zigzag", 2.5),
(False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func # (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
], ],
ids=[ ids=[
"sample_packing, varlen_llama3 ring_attn_func", "sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func", "no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func", # "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func", # "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
], ],
) )
def test_sequence_parallel_training( def test_sequence_parallel_training(

View File

@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"sequence_parallel_degree": 2, "context_parallel_size": 2,
"flash_attention": True, "flash_attention": True,
"sequence_len": 1024, "sequence_len": 1024,
"special_tokens": { "special_tokens": {

View File

@@ -13,7 +13,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0 from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -51,6 +51,7 @@ class TestFP8FSDP2:
"""Test class for FP8 mixed precision with FSDP2 functionality.""" """Test class for FP8 mixed precision with FSDP2 functionality."""
@require_torch_2_7_0 @require_torch_2_7_0
@require_hopper
def test_fp8_fsdp2_smoke(self, temp_dir): def test_fp8_fsdp2_smoke(self, temp_dir):
"""Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training""" """Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
cfg = DictDefault( cfg = DictDefault(

View File

@@ -0,0 +1,69 @@
"""multigpu e2e test for tensor parallelism."""
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
class TestTensorParallel:
    """Test class for Tensor Parallel functionality."""

    @pytest.mark.skip(
        reason="TP doesn't work with models with tied weights (embeddings)"
    )
    @require_torch_2_7_0
    def test_fft_sft(self, temp_dir):
        """Launch a 2-process full fine-tune with tensor_parallel_size=2.

        Writes the config to ``temp_dir``, runs ``axolotl train`` in a
        subprocess and then checks the logged train loss via tensorboard.
        """
        # pylint: disable=duplicate-code
        alpaca_dataset = {
            "path": "tatsu-lab/alpaca",
            "type": "alpaca",
            "split": "train[:10%]",
        }
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [alpaca_dataset],
                "num_epochs": 1,
                "max_steps": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "tensor_parallel_size": 2,
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "use_tensorboard": True,
                "bf16": True,
            }
        )

        # Serialize the config next to the run outputs.
        run_dir = Path(temp_dir)
        run_dir.mkdir(parents=True, exist_ok=True)
        config_path = run_dir / "config.yaml"
        with open(config_path, "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        launch_cmd = [
            "axolotl",
            "train",
            str(config_path),
            "--num-processes",
            "2",
            "--main-process-port",
            str(get_torch_dist_unique_port()),
        ]
        execute_subprocess_async(launch_cmd)

        check_tensorboard(
            temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
        )

View File

@@ -1,481 +0,0 @@
"""Tests for sequence parallelism functionality."""
# pylint: disable=redefined-outer-name,unused-argument
import functools
import sys
from unittest.mock import MagicMock, patch
import pytest
import torch
from accelerate.state import PartialState
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig
@pytest.fixture
def partial_state():
    """Provide a real (non-mocked) ``PartialState`` instance to a test."""
    return PartialState()
@pytest.fixture(name="cfg")
def fixture_cfg():
    """Build the ``DictDefault`` config exposed to tests as ``cfg``."""
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    special_tokens = {
        "pad_token": "<|endoftext|>",
    }
    return DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "datasets": [alpaca_dataset],
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "learning_rate": 1e-3,
            "output_dir": "./model-out",
            "sequence_len": 512,
            "special_tokens": special_tokens,
            "save_first_step": False,
        }
    )
@pytest.fixture
def sequence_parallel_batch():
    """Create a test batch for sequence parallelism tests.

    The batch holds one sequence of 8 tokens: ``input_ids`` counts up from 0,
    ``labels`` is a copy of ``input_ids``, ``attention_mask`` is all ones and
    ``position_ids`` is ``0..7``.
    """
    batch_size, seq_len = 1, 8
    input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
    return {
        "input_ids": input_ids,
        "attention_mask": torch.ones(batch_size, seq_len),
        "position_ids": torch.arange(seq_len).expand(batch_size, seq_len),
        "labels": input_ids.clone(),
    }
class TestRingAttention:
    """Tests for the ring attention functionality."""

    @patch("torch.distributed.get_rank")
    @patch("torch.distributed.get_world_size")
    def test_get_ring_attn_group_no_registration(
        self, mock_world_size, mock_rank, partial_state
    ):
        """get_ring_attn_group raises RuntimeError when nothing is registered."""
        # Pretend to be rank 0 of a 4-process job.
        mock_world_size.return_value = 4
        mock_rank.return_value = 0

        expected_error = "register_ring_attn\\(\\) not yet called"
        with pytest.raises(RuntimeError, match=expected_error):
            get_ring_attn_group()

    @patch("torch.distributed.new_group")
    @patch("torch.distributed.get_rank")
    @patch("torch.distributed.get_world_size")
    def test_register_ring_attn(
        self, mock_world_size, mock_rank, mock_new_group, partial_state
    ):
        """register_ring_attn creates the expected number of process groups."""
        # Pretend to be GPU #3 out of 8.
        mock_world_size.return_value = 8
        mock_rank.return_value = 3
        mock_new_group.return_value = MagicMock()

        register_ring_attn(
            sequence_parallel_degree=4,
            heads_k_stride=1,
            ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
        )

        # Only the number of new_group calls is checked, not their arguments.
        assert mock_new_group.call_count == 2
        mock_new_group.assert_called()

        # Reset the registered group so later tests are unaffected.
        set_ring_attn_group(None)
class TestConfigValidation:
    """Tests for validating sequence parallelism configurations."""

    @pytest.fixture(autouse=True)
    def setup_mocks(self, monkeypatch):
        """Register a mock ``ring_flash_attn`` module for every test here."""
        monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())

    @pytest.fixture
    def base_cfg(self):
        """Create a base configuration that each test extends."""
        dataset = {"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}
        return DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "datasets": [dataset],
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "learning_rate": 1e-3,
                "output_dir": "./model-out",
                "sequence_len": 512,
                "special_tokens": {"pad_token": "<|endoftext|>"},
            }
        )

    @pytest.mark.parametrize(
        "config_updates, expected_values, should_pass, error_msg",
        [
            pytest.param(
                {"sequence_parallel_degree": 2, "flash_attention": True},
                {"sequence_parallel_degree": 2, "flash_attention": True},
                True,
                None,
                id="valid_config",
            ),
            pytest.param(
                {},
                {"sequence_parallel_degree": 1},
                True,
                None,
                id="default_sp_degree",
            ),
            pytest.param(
                {"sequence_parallel_degree": 2, "flash_attention": False},
                None,
                False,
                "flash_attention: true must be set",
                id="without_flash_attention",
            ),
            pytest.param(
                {
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "sample_packing": True,
                    "micro_batch_size": 2,
                    "pad_to_sequence_len": True,
                },
                None,
                False,
                "micro_batch_size must be set to 1",
                id="sample_packing_with_large_batch",
            ),
            pytest.param(
                {
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": {"use_liger_loss": True},
                },
                {
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": TRLConfig(use_liger_loss=True),
                },
                True,
                "GRPO + SP + Liger not currently supported",
                id="valid_grpo",
            ),
            pytest.param(
                {
                    "rl": "grpo",
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": {"use_liger_loss": True},
                },
                None,
                False,
                "GRPO + SP + Liger not currently supported",
                id="grpo_with_liger_loss",
            ),
        ],
    )
    def test_sequence_parallel_config_validation(
        self, base_cfg, config_updates, expected_values, should_pass, error_msg
    ):
        """Test various sequence parallelism configuration scenarios.

        Passing cases must validate and expose ``expected_values``; failing
        cases must raise ``ValueError`` containing ``error_msg``.
        """
        from axolotl.utils.schemas.config import AxolotlInputConfig

        cfg = base_cfg
        cfg.update(config_updates)

        if not should_pass:
            with pytest.raises(ValueError) as excinfo:
                AxolotlInputConfig(**cfg)
            assert error_msg in str(excinfo.value)
            return

        config = AxolotlInputConfig(**cfg)
        for field_name, wanted in expected_values.items():
            assert getattr(config, field_name) == wanted

    @pytest.mark.parametrize(
        "ring_attn_func, sample_packing, expected_func",
        [
            pytest.param(
                None,
                True,
                RingAttnFunc.VARLEN_LLAMA3,
                id="default_with_sample_packing",
            ),
            pytest.param(
                None,
                False,
                RingAttnFunc.BATCH_RING,
                id="default_without_sample_packing",
            ),
        ],
    )
    def test_ring_attn_func_validation(
        self, base_cfg, ring_attn_func, sample_packing, expected_func
    ):
        """Test ring_attn_func validation and defaults."""
        from axolotl.utils.schemas.config import AxolotlInputConfig

        overrides = {
            "sequence_parallel_degree": 2,
            "flash_attention": True,
            "sample_packing": sample_packing,
        }
        cfg = base_cfg | overrides
        # Only set the key when a value is given, so the default path is used.
        if ring_attn_func is not None:
            cfg["ring_attn_func"] = ring_attn_func

        validated = AxolotlInputConfig(**cfg)
        assert validated.ring_attn_func.value == expected_func

    def test_invalid_ring_attn_func(self, base_cfg):
        """Test that an invalid ring_attn_func is rejected."""
        from axolotl.utils.schemas.config import AxolotlInputConfig

        bad_cfg = base_cfg | {
            "sequence_parallel_degree": 2,
            "flash_attention": True,
            "ring_attn_func": "INVALID_FUNC",
        }

        with pytest.raises(ValueError) as excinfo:
            AxolotlInputConfig(**bad_cfg)

        message = str(excinfo.value)
        assert "Input should be 'varlen_llama3' or 'batch_ring'" in message
class TestApplySequenceParallelism:
    """Tests for the apply_sequence_parallelism function."""

    @pytest.fixture(autouse=True)
    def mock_distributed(self, monkeypatch):
        """Mock torch.distributed functions and ring-attn hooks for testing."""
        # Distributed state: initialized, rank 0, world size 2 by default.
        distributed_stubs = {
            "is_initialized": lambda: True,
            "get_rank": lambda *args, **kwargs: 0,
            "get_world_size": lambda *args, **kwargs: 2,
        }
        for attr_name, stub in distributed_stubs.items():
            monkeypatch.setattr(torch.distributed, attr_name, stub)

        # Process group lookup and ring-attn parameter update are stubbed out.
        monkeypatch.setattr(
            "axolotl.monkeypatch.ring_attn.get_ring_attn_group",
            MagicMock,
        )
        monkeypatch.setattr(
            "axolotl.monkeypatch.ring_attn.update_ring_attn_params",
            lambda **kwargs: None,
        )

    @staticmethod
    def _shard_batch_ring(batch, local_rank, local_world_size):
        """Apply BATCH_RING sequence parallelism and return the sharded batch."""
        sharded, _, _ = apply_sequence_parallelism(
            batch=batch,
            local_rank=local_rank,
            local_world_size=local_world_size,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )
        return sharded

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """Test that function returns original batch when world size is 1."""
        mock_get_ring_attn_group.return_value = 0

        result = self._shard_batch_ring(
            sequence_parallel_batch, local_rank=0, local_world_size=1
        )

        # Should return the original batch unchanged
        assert result == sequence_parallel_batch

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """Test BATCH_RING sharding for rank 0 in a 2-process group."""
        mock_get_ring_attn_group.return_value = 0
        batch = sequence_parallel_batch
        half = batch["input_ids"].size(1) // 2

        result = self._shard_batch_ring(batch, local_rank=0, local_world_size=2)

        # The sequence dimension is cut in half.
        assert result["input_ids"].shape[1] == half
        assert result["attention_mask"].shape[1] == half

        # Rank 0 should get the first half of the sequence.
        assert torch.equal(result["input_ids"], batch["input_ids"][:, :half])
        assert torch.equal(result["position_ids"], batch["position_ids"][:, :half])

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """Test BATCH_RING sharding for rank 1 in a 2-process group."""
        mock_get_ring_attn_group.return_value = 0
        batch = sequence_parallel_batch
        half = batch["input_ids"].size(1) // 2
        original_input_ids = batch["input_ids"].clone()

        result = self._shard_batch_ring(batch, local_rank=1, local_world_size=2)

        # Rank 1 should get the second half of the sequence.
        assert torch.equal(result["input_ids"], original_input_ids[:, half:])

    # TODO(djsaunde): add back a BATCH_ZIGZAG test once implemented. It should
    # shard clones of the batch for ranks 0 and 1 of a 2-process group, check
    # each result is seq_len // 2 long, and for 8 tokens expect rank 0 to hold
    # input_ids[:, :2] + input_ids[:, 6:8] and rank 1 to hold
    # input_ids[:, 2:4] + input_ids[:, 4:6] (concatenated on dim=1).

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_partial_application(
        self, mock_get_ring_attn_group, sequence_parallel_batch
    ):
        """Test that we can create a partially applied version of the function."""
        mock_get_ring_attn_group.return_value = 0
        batch = sequence_parallel_batch
        original_input_ids = batch["input_ids"].clone()
        half = original_input_ids.shape[1] // 2

        # Bind everything except the batch.
        rank0_ring_parallel = functools.partial(
            apply_sequence_parallelism,
            local_rank=0,
            local_world_size=2,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )
        result, _, _ = rank0_ring_parallel(batch=batch)

        assert result["input_ids"].shape[1] == half
        assert torch.equal(result["input_ids"], original_input_ids[:, :half])

    def test_missing_position_ids(self, sequence_parallel_batch):
        """Test handling of batch without position_ids."""
        batch = {
            key: tensor
            for key, tensor in sequence_parallel_batch.items()
            if key != "position_ids"
        }
        original_input_ids = batch["input_ids"].clone()

        # This should run without error even though position_ids is missing
        result = self._shard_batch_ring(batch, local_rank=0, local_world_size=2)

        assert "position_ids" in result
        sharded_len = result["input_ids"].shape[1]
        assert sharded_len == result["position_ids"].shape[1]
        assert sharded_len == original_input_ids.shape[1] // 2

View File

@@ -52,6 +52,8 @@ class TestLoadModelUtils:
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"tensor_parallel_size": 1,
"context_parallel_size": 1,
} }
) )
self.model_loader = ( # pylint: disable=attribute-defined-outside-init self.model_loader = ( # pylint: disable=attribute-defined-outside-init

View File

@@ -142,6 +142,10 @@ def is_hopper():
return compute_capability == (9, 0) return compute_capability == (9, 0)
def require_hopper(test_case):
    """Decorator that skips ``test_case`` unless ``is_hopper()`` is true."""
    skip_unless_hopper = unittest.skipUnless(
        is_hopper(), "test requires h100/hopper GPU"
    )
    return skip_unless_hopper(test_case)
def check_tensorboard( def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
) -> None: ) -> None:

View File

@@ -171,3 +171,44 @@ class TestModelsUtils:
message_property_mappings={"content": "different_content"}, message_property_mappings={"content": "different_content"},
) )
assert "Conflicting message content fields" in str(exc_info.value) assert "Conflicting message content fields" in str(exc_info.value)
@pytest.mark.parametrize(
    "world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected",
    [
        (16, 2, 2, 2, 2, True, (2, 2, 2, 2)),
        (16, 1, 1, None, None, True, (0, 0, 16, 1)),
        (16, 2, 2, 2, None, True, (2, 2, 2, 2)),
        (16, 2, 2, None, 2, True, (2, 2, 2, 2)),
        (16, 1, 1, None, 2, True, (0, 0, 8, 2)),
        (2, 1, 1, None, None, True, (0, 0, 2, 1)),
    ],
)
def test_get_parallel_config_kwargs(
    self,
    world_size,
    tensor_parallel_size,
    context_parallel_size,
    dp_shard_size,
    dp_replicate_size,
    is_fsdp,
    expected,
):
    """Check the sizes returned by ``ModelLoader._get_parallel_config_kwargs``.

    ``expected`` is ordered as (tp, cp, dp_shard, dp_replicate). Only
    entries greater than 1 are asserted against the result; entries of 0 or
    1 are not checked.
    """
    # pylint: disable=protected-access
    res = ModelLoader._get_parallel_config_kwargs(
        world_size,
        tensor_parallel_size,
        context_parallel_size,
        dp_shard_size,
        dp_replicate_size,
        is_fsdp,
    )

    result_keys = ("tp_size", "cp_size", "dp_shard_size", "dp_replicate_size")
    for key, wanted in zip(result_keys, expected):
        if wanted > 1:
            assert res[key] == wanted

View File

@@ -26,32 +26,6 @@ class TestFSDPValidation:
assert cfg.fsdp_version == 2 assert cfg.fsdp_version == 2
assert cfg.fsdp_config.fsdp_version is None assert cfg.fsdp_config.fsdp_version is None
def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
    """SHARDED_STATE_DICT plus save_safetensors must fail validation.

    Checked with the ``fsdp_``-prefixed key first, then without the prefix.
    """
    for state_dict_key in ("fsdp_state_dict_type", "state_dict_type"):
        cfg = min_base_cfg | DictDefault(
            fsdp_config={
                state_dict_key: "SHARDED_STATE_DICT",
            },
            save_safetensors=True,
        )
        with pytest.raises(
            ValueError,
            match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
        ):
            validate_config(cfg)
def test_fsdp_offload_w_8bit_optim(self, min_base_cfg): def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
cfg = min_base_cfg | DictDefault( cfg = min_base_cfg | DictDefault(
fsdp_config={ fsdp_config={