Distributed/ND-Parallel (#2977)

salman
2025-07-31 20:25:02 +01:00
committed by GitHub
parent 7b68dfafd7
commit 294c7fe7a6
49 changed files with 712 additions and 835 deletions

View File

@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \
pytest -v --durations=10 -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -65,6 +65,9 @@ GPU_CONFIG = f"L40S:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit

View File

@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
context_parallel_size: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,7 +30,7 @@ heads_k_stride: 1
ring_attn_func:
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
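To make the constraint concrete, a minimal sketch of the implied divisibility check (illustrative only, not from the codebase):

```python
# context_parallel_size must evenly divide the number of GPUs.
num_gpus = 8
context_parallel_size = 4

assert num_gpus % context_parallel_size == 0, (
    f"context_parallel_size={context_parallel_size} must divide num_gpus={num_gpus}"
)
num_sp_groups = num_gpus // context_parallel_size  # 2 groups of 4 GPUs each
```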
@@ -66,7 +66,7 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
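A rough sketch of that arithmetic (illustrative only, not from the codebase):

```python
# 8 GPUs, context_parallel_size=4, per-GPU micro_batch_size=2
num_gpus = 8
context_parallel_size = 4
micro_batch_size = 2

num_sp_groups = num_gpus // context_parallel_size      # 2 distinct batches per step
global_batch_size = micro_batch_size * num_sp_groups   # 2 * 2 = 4
global_batch_size_no_cp = micro_batch_size * num_gpus  # 2 * 8 = 16
assert (global_batch_size_no_cp, global_batch_size) == (16, 4)
```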

View File

@@ -20,7 +20,7 @@ min_sample_len: 200_000
sample_packing: true
tiled_mlp: true
sequence_parallel_degree: 8
context_parallel_size: 8
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -13,9 +13,9 @@ packaging==23.2
huggingface_hub>=0.33.0
peft==0.16.0
transformers==4.54.0
transformers==4.54.1
tokenizers>=0.21.1
accelerate==1.9.0
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
datasets==4.0.0
deepspeed>=0.17.0
trl==0.20.0

View File

@@ -72,12 +72,13 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm>=0.10.0"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:

View File

@@ -69,7 +69,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
sequence_parallel_degree=None,
context_parallel_size=None,
deepspeed=None,
fsdp=None,
fsdp_config=None,

View File

@@ -24,9 +24,11 @@ from pathlib import Path
from typing import Any
import torch
from accelerate import PartialState
from transformers import (
TrainerCallback,
)
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.training_args import OptimizerNames
from axolotl.integrations.base import PluginManager
@@ -434,8 +436,30 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_accelerator_config(self, training_args_kwargs: dict):
partial_state = PartialState()
has_pc_attr = (
hasattr(partial_state, "parallelism_config")
and partial_state.parallelism_config
)
has_pc_key = (
"parallelism_config"
in partial_state._shared_state # pylint: disable=protected-access
and partial_state._shared_state[ # pylint: disable=protected-access
"parallelism_config"
]
)
use_configured_state = has_pc_attr or has_pc_key
if self.cfg.accelerator_config:
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
use_configured_state = self.cfg.accelerator_config.pop(
"use_configured_state", use_configured_state
)
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state, **self.cfg.accelerator_config
)
else:
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state,
)
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.activation_offloading is True:

View File

@@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
sequence_parallel=self.cfg.context_parallel_size > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))

View File

@@ -27,6 +27,7 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
DistributedParallelMixin,
OptimizerMixin,
PackingMixin,
RngLoaderMixin,
@@ -50,6 +51,7 @@ class AxolotlTrainer(
RngLoaderMixin,
CheckpointSaveMixin,
ActivationOffloadingMixin,
DistributedParallelMixin,
Trainer,
):
"""Extend the base Trainer for axolotl helpers"""

View File

@@ -8,7 +8,11 @@ import torch
from torch import nn
from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
RngLoaderMixin,
SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
@@ -17,7 +21,12 @@ from axolotl.core.trainers.utils import (
class AxolotlDPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DPOTrainer,
DistributedParallelMixin,
):
"""Extend the base DPOTrainer for axolotl helpers."""

View File

@@ -82,14 +82,14 @@ class GRPOStrategy:
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if cfg.context_parallel_size > 1:
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
if trl.importance_sampling_level is not None:
grpo_args_kwargs["importance_sampling_level"] = (
trl.importance_sampling_level
)
if cfg.sequence_parallel_degree > 1:
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights

View File

@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
sequence_parallel_degree: int | None = None
context_parallel_size: int | None = None

View File

@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
- Data is properly distributed across SP groups.
In the table below, the values represent dataset indices. Each SP group has
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
`context_parallel_size = 2` GPUs working together on the same data. There are 2
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
Sequence Parallel Groups
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: Rank of current process.
batch_size: Number of samples per batch.
repeat_count: How many times to repeat the full sampling process.
sequence_parallel_degree: Number of ranks in a sequence parallel group.
context_parallel_size: Number of ranks in a sequence parallel group.
shuffle: Whether to shuffle the dataset.
seed: Random seed for shuffling.
drop_last: Whether to drop the last incomplete batch.
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: int,
batch_size: int = 1,
repeat_count: int = 1,
sequence_parallel_degree: int = 1,
context_parallel_size: int = 1,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
self.rank = rank
# Sequence parallelism parameters
self.sequence_parallel_degree = sequence_parallel_degree
self.num_sp_groups = world_size // sequence_parallel_degree
self.sp_group_id = rank // sequence_parallel_degree
self.context_parallel_size = context_parallel_size
self.num_sp_groups = world_size // context_parallel_size
self.sp_group_id = rank // context_parallel_size
# Adjust dataset size for distributed sampling
self.num_samples = len(self.dataset)
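As an aside (illustrative only, not part of this diff), the SP group assignment described in the docstring above follows directly from integer division by `context_parallel_size`:

```python
# With world_size=4 and context_parallel_size=2, ranks 0-1 form SP group 0
# and ranks 2-3 form SP group 1, matching the SP0/SP1 layout described above.
world_size, context_parallel_size = 4, 2
sp_group_ids = [rank // context_parallel_size for rank in range(world_size)]
assert sp_group_ids == [0, 0, 1, 1]
```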

View File

@@ -43,7 +43,11 @@ from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
RngLoaderMixin,
SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
@@ -53,7 +57,12 @@ if is_peft_available():
class AxolotlGRPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
GRPOTrainer,
):
"""Extend the base GRPOTrainer for axolotl helpers"""
@@ -100,7 +109,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Get number of SP groups (number of processes divided by SP degree)
num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree
num_sp_groups = num_processes // self.args.context_parallel_size
# Calculate batch size per SP group (not per process)
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
@@ -130,7 +139,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
if self.num_generations not in possible_values:
raise ValueError(
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
f"With sequence parallelism (degree {self.args.context_parallel_size}), "
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"must be evenly divisible by the number of generations per prompt "
f"({self.num_generations}). Given the current eval batch size, "
@@ -167,9 +176,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
rank=self.rank,
batch_size=effective_batch_size
// self.num_generations
// self.args.sequence_parallel_degree,
// self.args.context_parallel_size,
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
sequence_parallel_degree=self.args.sequence_parallel_degree,
context_parallel_size=self.args.context_parallel_size,
shuffle=True,
seed=self.args.seed,
drop_last=True,
@@ -235,7 +244,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
return dataloader
# Otherwise prepare with accelerator
@@ -308,18 +317,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate sequence parallel group information
world_size = self.accelerator.num_processes
sequence_parallel_degree = self.args.sequence_parallel_degree
num_sp_groups = world_size // sequence_parallel_degree
context_parallel_size = self.args.context_parallel_size
num_sp_groups = world_size // context_parallel_size
# Since processes in the same SP group have the same prompts, we need to ensure
# we only take one copy of each prompt from each SP group
ordered_set_of_prompts = []
for sp_group_id in range(num_sp_groups):
# Get the first process from each SP group (typically the group leader)
group_leader_rank = sp_group_id * sequence_parallel_degree
group_leader_rank = sp_group_id * context_parallel_size
# Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group
@@ -335,7 +344,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[
:: self.num_generations * self.args.sequence_parallel_degree
:: self.num_generations * self.args.context_parallel_size
]
with profiling_context(self, "vLLM.generate"):
@@ -352,14 +361,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
)
else:
completion_ids = [None] * (
len(all_prompts_text) // self.args.sequence_parallel_degree
len(all_prompts_text) // self.args.context_parallel_size
)
# Broadcast the completions from the main process to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0)
# Determine the appropriate slice based on sequence parallelism
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size
@@ -583,7 +592,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
advantages = advantages / (std_grouped_rewards + 1e-4)
# Slice to keep only the local part of the data
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size

View File

@@ -5,6 +5,7 @@ import torch
from axolotl.core.trainers.base import AxolotlTrainer
# pylint: disable=too-many-ancestors
class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation"""

View File

@@ -5,6 +5,7 @@
from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .distributed_parallel import DistributedParallelMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin
from .rng_state_loader import RngLoaderMixin

View File

@@ -13,9 +13,11 @@ class CheckpointSaveMixin(Trainer):
def _save_optimizer_and_scheduler(self, output_dir):
try:
super()._save_optimizer_and_scheduler(output_dir)
except NotImplementedError as exc:
LOG.warning(
except (NotImplementedError, KeyError) as exc:
# TODO: fix fsdp2 optimizer saving
LOG.warning_once(
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
"Optimizer and scheduler states were not saved - resuming from checkpoints "
"for this training run will not be possible."
"for this training run will not be possible.",
main_process_only=True,
)

View File

@@ -0,0 +1,20 @@
"""
Mixin for correctly saving fsdp
"""
from transformers import Trainer
class DistributedParallelMixin(Trainer):
"""
Mixin for correctly saving fsdp
"""
def _save(self, output_dir: str | None = None, state_dict=None):
if (
state_dict is None
and self.accelerator.parallelism_config
and self.accelerator.parallelism_config.dp_shard_enabled
):
state_dict = self.accelerator.get_state_dict(self.model)
super()._save(output_dir, state_dict=state_dict)

View File

@@ -8,13 +8,18 @@ from trl import (
RewardTrainer,
)
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class AxolotlORPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
ORPOTrainer,
):
"""
Extend the base ORPOTrainer for axolotl helpers
@@ -24,7 +29,12 @@ class AxolotlORPOTrainer(
class AxolotlKTOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
KTOTrainer,
):
"""
Extend the base KTOTrainer for axolotl helpers
@@ -34,7 +44,12 @@ class AxolotlKTOTrainer(
class AxolotlCPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
CPOTrainer,
):
"""
Extend the base CPOTrainer for axolotl helpers
@@ -44,7 +59,12 @@ class AxolotlCPOTrainer(
class AxolotlRewardTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
RewardTrainer,
):
"""
Extend the base RewardTrainer for axolotl helpers
@@ -54,7 +74,12 @@ class AxolotlRewardTrainer(
class AxolotlPRMTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
PRMTrainer,
):
"""
Extend the base trl.PRMTrainer for axolotl helpers

View File

@@ -21,6 +21,7 @@ from axolotl.core.trainers.base import AxolotlTrainer
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
# pylint: disable=too-many-ancestors
class AxolotlKDTrainer(AxolotlTrainer):
"""
Custom trainer subclass for Knowledge Distillation (KD)

View File

@@ -16,8 +16,6 @@
Module for handling LIGER input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
@@ -30,13 +28,13 @@ class LigerArgs(BaseModel):
Input args for LIGER.
"""
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None
liger_rope: bool | None = None
liger_rms_norm: bool | None = None
liger_layer_norm: bool | None = None
liger_swiglu: bool | None = None
liger_glu_activation: bool | None = None
liger_cross_entropy: bool | None = None
liger_fused_linear_cross_entropy: bool | None = None
@model_validator(mode="before")
@classmethod
@@ -66,3 +64,20 @@ class LigerArgs(BaseModel):
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
)
return data
@model_validator(mode="before")
@classmethod
def check_liger_rms_norm_tensor_parallel(cls, data):
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1:
raise ValueError(
"`liger_rms_norm` is incompatible with tensor parallelism, "
"see https://github.com/linkedin/Liger-Kernel/issues/826"
)
return data
@model_validator(mode="after")
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
# TODO @SalmanMohammadi this is a larger fix - investigate
if self.tensor_parallel_size > 1 and self.liger_fused_linear_cross_entropy:
raise ValueError("Tensor parallelism is not compatible with liger losses.")
return self

View File

@@ -13,7 +13,8 @@ import peft
import torch
import transformers
import transformers.modeling_utils
from accelerate import init_empty_weights
from accelerate import PartialState, init_empty_weights
from accelerate.parallelism_config import ParallelismConfig
from peft import (
PeftConfig,
PeftMixedModel,
@@ -48,10 +49,7 @@ from axolotl.loaders.utils import (
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
get_device_count,
get_device_type,
)
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
@@ -87,6 +85,9 @@ class ModelLoader:
`AutoModelForCausalLM`).
"""
use_parallel_config: bool | None = False
parallelism_config: ParallelismConfig | None = None
def __init__(
self,
cfg: DictDefault,
@@ -183,6 +184,20 @@ class ModelLoader:
def _apply_pre_model_load_setup(self):
"""Apply patches and setup configurations before model loading."""
if self.use_parallel_config is not None:
self.use_parallel_config = (
self.cfg.fsdp_config
or (self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1)
or (
self.cfg.context_parallel_size
and self.cfg.context_parallel_size > 1
)
)
if self.cfg.fsdp_config and self.cfg.fsdp_version != 2:
self.use_parallel_config = False
if self.use_parallel_config:
self._set_parallel_config()
self._set_auto_model_loader()
self._set_device_map_config()
if self.cfg.revision_of_model:
@@ -390,6 +405,86 @@ class ModelLoader:
gc.collect()
torch.cuda.empty_cache()
@staticmethod
def _get_parallel_config_kwargs(
world_size: int,
tensor_parallel_size: int = 1,
context_parallel_size: int = 1,
dp_shard_size: int | None = None,
dp_replicate_size: int | None = None,
is_fsdp: bool = False,
):
pc_kwargs = {}
remaining_world_size = world_size
if tensor_parallel_size and tensor_parallel_size > 1:
pc_kwargs["tp_size"] = tensor_parallel_size
remaining_world_size = remaining_world_size // tensor_parallel_size
if context_parallel_size and context_parallel_size > 1:
pc_kwargs["cp_size"] = context_parallel_size
remaining_world_size = remaining_world_size // context_parallel_size
if dp_shard_size is None and dp_replicate_size in (None, 1):
if remaining_world_size > 1:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if dp_replicate_size and dp_replicate_size > 1:
pc_kwargs["dp_replicate_size"] = dp_replicate_size
remaining_world_size = remaining_world_size // dp_replicate_size
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
if not is_fsdp:
raise ValueError(
"dp_shard_size was configured without a corresponding fsdp_config! "
"Please ensure you have configured FSDP using fsdp_config."
)
pc_kwargs["dp_shard_size"] = dp_shard_size
remaining_world_size = remaining_world_size // dp_shard_size
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
pc_kwargs["dp_replicate_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
if "dp_shard_size" not in pc_kwargs and is_fsdp:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
raise ValueError(
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
f"{pc_kwargs}"
)
return pc_kwargs
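For illustration only (not part of this commit), a minimal sketch of how the decomposition above resolves for 8 GPUs with tensor and context parallelism both set to 2; the leftover factor of 2 is assigned to FSDP sharding via `dp_shard_size`:

```python
# Hypothetical call against the static method above; with dp_shard_size and
# dp_replicate_size unset, the remaining world size becomes dp_shard_size.
kwargs = ModelLoader._get_parallel_config_kwargs(
    world_size=8,
    tensor_parallel_size=2,
    context_parallel_size=2,
)
assert kwargs == {"tp_size": 2, "cp_size": 2, "dp_shard_size": 2}
```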
def _set_parallel_config(self):
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
get_world_size(),
self.cfg.tensor_parallel_size,
self.cfg.context_parallel_size,
self.cfg.dp_shard_size,
self.cfg.dp_replicate_size,
bool(self.cfg.fsdp or self.cfg.fsdp_config),
)
if pc_kwargs:
self.parallelism_config = ParallelismConfig(
**pc_kwargs,
)
device_mesh = self.parallelism_config.build_device_mesh("cuda")
partial_state = PartialState()
# fmt: off
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
self.parallelism_config
)
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
device_mesh
)
# fmt: on
def _set_auto_model_loader(self):
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
@@ -622,6 +717,14 @@ class ModelLoader:
def _build_model(self) -> bool:
"""Load model, with load strategy depending on config."""
skip_move_to_device = False
if self.cfg.tensor_parallel_size > 1:
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
self.model_kwargs["tp_plan"] = "auto"
self.model_kwargs["device_mesh"] = PartialState().device_mesh
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True
@@ -734,6 +837,14 @@ class ModelLoader:
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
# pylint: disable=protected-access
if self.cfg.tensor_parallel_size > 1:
# workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
# TODO(wing): remove once 4.54.1 is released
if self.model._tp_size != self.cfg.tensor_parallel_size:
self.model._tp_size = self.cfg.tensor_parallel_size
self.model._device_mesh = self.model_kwargs["device_mesh"]
return skip_move_to_device
def _set_z3_leaf_modules(self):

View File

@@ -49,6 +49,7 @@ class PatchManager:
def apply_pre_model_load_patches(self):
"""Apply pre-model load patches based on config."""
self._apply_transformers_patches()
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
@@ -64,13 +65,19 @@ class PatchManager:
self._patch_llama_derived_model()
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_sequence_parallel_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_voxtral_patches()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
patch_prepare_from_posids,
)
patch_prepare_from_posids()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
@@ -253,17 +260,6 @@ class PatchManager:
has_remote_code=has_remote_code,
)
def _apply_sequence_parallel_patches(self):
"""Apply sequence parallelism patches."""
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.ring_attn.patch import (
patch_prepare_data_loader,
patch_prepare_device_mesh,
)
patch_prepare_data_loader()
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import (

View File

@@ -249,13 +249,19 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
)
mesh = getattr(accelerator.state, "device_mesh", None)
fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload,
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
"mesh": (
mesh[tuple(accelerator.state.parallelism_config.fsdp_dim_names)]
if mesh is not None
else None
),
}
model_has_params4bit = False
for _, param in model.named_parameters():
# this is a temporary fix whereby loading models with bnb params cannot be moved from

View File

@@ -5,18 +5,14 @@
from .patch import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
register_ring_attn_from_device_mesh,
set_ring_attn_group,
update_ring_attn_params,
)
__all__ = (
"get_ring_attn_group",
"patch_prepare_data_loader",
"patch_prepare_device_mesh",
"register_ring_attn",
"register_ring_attn_from_device_mesh",
"set_ring_attn_group",
"update_ring_attn_params",
)

View File

@@ -8,13 +8,12 @@ We also provide some patches for accelerate functions to prepare the dataloader
sequence parallelism training.
"""
import inspect
import os
from typing import Optional
import accelerate
import torch
import torch.distributed as dist
from torch.distributed import DeviceMesh
try:
from transformers.modeling_flash_attention_utils import _flash_supports_window
@@ -29,39 +28,13 @@ from axolotl.utils.schemas.enums import RingAttnFunc
LOG = get_logger(__name__)
RING_ATTN_GROUP = None
ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // submesh_tp_size"""
NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
submesh_cp_size = 1
if "cp" in torch_device_mesh.mesh_dim_names:
submesh_cp_size = torch_device_mesh["cp"].size()
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // (submesh_tp_size * submesh_cp_size)"""
def get_ring_attn_group() -> dist.ProcessGroup:
"""Getter for ring attention group on this rank."""
if RING_ATTN_GROUP is None:
raise RuntimeError("register_ring_attn() not yet called")
raise RuntimeError("register_ring_attn_from_device_mesh() not yet called")
return RING_ATTN_GROUP
@@ -161,15 +134,17 @@ def create_ring_flash_attention_forward(
]
def register_ring_attn(
sequence_parallel_degree: int,
def register_ring_attn_from_device_mesh(
device_mesh: "DeviceMesh",
context_parallel_dim: tuple[str, ...],
heads_k_stride: int | None,
ring_attn_func: RingAttnFunc | None,
):
"""Create ring attention group and substitute flash attn with ring flash attn.
"""Create ring attention group using DeviceMesh and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
device_mesh: DeviceMesh object containing the parallelism topology.
context_parallel_dim: Name(s) of the context parallel dimension(s) in the device mesh.
heads_k_stride: Sequence parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation.
ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
@@ -177,44 +152,39 @@ def register_ring_attn(
`batch` function.
"""
rank = dist.get_rank()
world_size = dist.get_world_size()
LOG.info(
f"Enabling ring attention sequence parallelism using DeviceMesh "
f"dimension '{context_parallel_dim}'",
main_process_only=True,
)
# Extract the sequence parallel submesh
try:
sequence_mesh = device_mesh[context_parallel_dim]
except (KeyError, IndexError) as e:
raise ValueError(
f"Dimension '{context_parallel_dim}' not found in device_mesh. "
f"Available dimensions: {device_mesh.mesh_dim_names}"
) from e
# Get the process group for context parallelism
sequence_pg = sequence_mesh.get_group()
context_parallel_size = sequence_mesh.size()
if rank == 0:
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
f"Sequence parallel degree: {context_parallel_size}, "
f"mesh shape: {sequence_mesh.mesh.shape}"
)
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
)
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"
)
# Log which ranks are in the current process group
if sequence_pg != dist.GroupMember.WORLD:
ranks_in_group = dist.get_process_group_ranks(sequence_pg)
LOG.info(f"Current sequence parallel group ranks: {ranks_in_group}")
# Assign ranks to sequence parallel groups
group_assignments = {}
for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list(
range(
i * sequence_parallel_degree,
(i + 1) * sequence_parallel_degree,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
# Track which GPUs are in which groups
for r in ring_attn_ranks:
group_assignments[r] = i
if rank in ring_attn_ranks:
set_ring_attn_group(group)
# Log the GPU group assignments
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
# Set the ring attention group
set_ring_attn_group(sequence_pg)
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
# fmt: off
@@ -257,92 +227,3 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
def patch_prepare_data_loader():
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
Raises:
RuntimeError: If source code to patch does not exist.
"""
original_fn = accelerate.data_loader.prepare_data_loader
original_source = inspect.getsource(original_fn)
if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source:
raise RuntimeError(
"SP patch failed - target snippet not found. "
"Check accelerate's version or update the patch."
)
patched_source = original_source.replace(
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
)
items_to_import = []
for item in dir(accelerate.data_loader):
if item in patched_source:
items_to_import.append(item)
# Create a new function from the patched source
namespace = {}
exec( # pylint: disable=exec-used # nosec B102
f"from accelerate.data_loader import ({', '.join(items_to_import)})",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
patched_source, globals(), namespace
)
patched_function = namespace["prepare_data_loader"]
original_fn.__code__ = patched_function.__code__
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree.
Args:
sequence_parallel_degree: The degree of sequence parallelism to use.
fsdp: Whether to use FSDP.
"""
def _prepare_device_mesh(self):
"""Prepare the device mesh for distributed training. The dataloader will
determine how to load data based on the device mesh.
"""
if self.state.torch_tp_plugin:
return self.state.torch_tp_plugin.torch_device_mesh
if (
self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED
and hasattr(self.state, "ds_device_mesh")
):
return self.state.ds_device_mesh
# Create device mesh with sequence parallelism
world_size = dist.get_world_size()
mesh_shape = (
world_size // sequence_parallel_degree,
sequence_parallel_degree,
)
device_ids = list(range(world_size))
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming.
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
# only use "fsdp" and "cp" for the device mesh.
return dist.DeviceMesh(
"cuda",
torch.tensor(device_ids).reshape(mesh_shape),
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
)
# Replace the original method with our new method
# pylint: disable=protected-access
accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh
LOG.info(
"Successfully patched Accelerator._prepare_device_mesh "
f"with sequence_parallel_degree={sequence_parallel_degree}"
)

View File

@@ -0,0 +1,87 @@
"""
Monkey patch to fix transformers.modeling_flash_attention_utils.
see https://github.com/huggingface/transformers/pull/39653/files
"""
import sys
import torch
def _prepare_from_posids(query, key, value, position_ids):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
Cumulative lengths of each example in the batch will be extracted from position_ids.
NOTE: ideally cumulative lengths should be prepared at the data collator stage
Arguments:
query (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
position_ids (`torch.Tensor`):
Int tensor of shape (batch_size, sequence_length) giving each token's position within its packed sequence.
Return:
query (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
indices_q (`torch.Tensor`):
The indices of non-masked tokens from the flattened input target sequence.
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(
position_ids.size(), device=position_ids.device, dtype=torch.int32
),
)
)
# NOTE: With torch compile, this will cause a graph break if you don't set
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
# for some models (e.g. qwen2-vl).
max_length = cu_seq_lens.diff().max().item()
return (
query,
key,
value,
indices_q,
(cu_seq_lens, cu_seq_lens),
(max_length, max_length),
)
def patch_prepare_from_posids():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access
_prepare_from_posids
)
setattr(
sys.modules["transformers.modeling_flash_attention_utils"],
"_prepare_from_posids",
_prepare_from_posids,
)

View File

@@ -205,7 +205,7 @@ def execute_training(
)
)
if cfg.sequence_parallel_degree > 1:
if cfg.context_parallel_size > 1:
models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model)
@@ -213,7 +213,7 @@ def execute_training(
stack.enter_context(
SequenceParallelContextManager(
models=models,
sequence_parallel_degree=cfg.sequence_parallel_degree,
context_parallel_size=cfg.context_parallel_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,

View File

@@ -57,10 +57,10 @@ def gpu_memory_usage(device=0):
@check_cuda_device((0.0, 0.0, 0.0))
def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
smi = gpu_memory_usage_smi(device)
return usage, reserved - usage, max(0, smi - reserved)
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
return active, allocated, reserved
def mps_memory_usage_all():
@@ -92,27 +92,38 @@ def gpu_memory_usage_smi(device=0):
return 0.0
def log_gpu_memory_usage(
log: logging.Logger | logging.LoggerAdapter,
msg: str = "",
device: int | torch.device = 0,
):
def get_gpu_memory_usage(device: int | torch.device = 0):
cur_device_type = str(get_device_type())
if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all()
elif "npu" in cur_device_type and is_torch_npu_available():
usage, cache, misc = npu_memory_usage_all(device)
elif "gpu" in cur_device_type and torch.cuda.is_available():
elif "cuda" in cur_device_type and torch.cuda.is_available():
usage, cache, misc = gpu_memory_usage_all(device)
else:
return 0.0, 0.0, 0.0
return usage, cache, misc
def log_gpu_memory_usage(
log: logging.Logger | logging.LoggerAdapter,
msg: str = "",
device: int | torch.device = 0,
):
try:
active, allocated, reserved = get_gpu_memory_usage(device)
except ValueError:
# likely CPU, ignore
return
cur_device_type = str(get_device_type())
extras = []
if cache > 0:
extras.append(f"+{cache:.03f}GB cache")
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
msg = f"{cur_device_type} memory usage:" if not msg else msg
log.info(
f"{msg} {usage:.03f}GB ({', '.join(extras)})",
if allocated > 0:
extras.append(f"+{allocated:.03f}GB allocated")
if reserved > 0:
extras.append(f"+{reserved:.03f}GB reserved")
msg = f"{cur_device_type} memory active:" if not msg else msg
log.debug(
f"{msg} {active:.03f}GB ({', '.join(extras)})",
stacklevel=2,
)

View File

@@ -35,7 +35,7 @@ from transformers.trainer_utils import (
from trl.models import unwrap_model_for_generation
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.bench import get_gpu_memory_usage, log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.distributed import (
barrier,
@@ -100,7 +100,6 @@ class GPUStatsCallback(
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
def on_step_end(
self,
@@ -109,9 +108,21 @@ class GPUStatsCallback(
control: TrainerControl,
**kwargs,
) -> TrainerControl:
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
if state.global_step > 0:
if self.cfg.use_wandb and state.is_world_process_zero:
try:
active, allocated, reserved = get_gpu_memory_usage()
wandb.log(
{
"memory/max_memory_active": active,
"memory/max_memory_allocated": allocated,
"memory/device_memory_reserved": reserved,
},
step=state.global_step,
)
except ValueError:
pass
log_gpu_memory_usage(LOG, "", self.cfg.device)
return control

View File

@@ -5,6 +5,7 @@ import inspect
import torch
import torch.distributed as dist
from accelerate import PartialState
from torch import nn
from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -12,7 +13,7 @@ from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
register_ring_attn,
register_ring_attn_from_device_mesh,
update_ring_attn_params,
)
from axolotl.utils.schemas.enums import RingAttnFunc
@@ -150,9 +151,18 @@ def apply_sequence_parallelism(
if "num_items_in_batch" in batch:
# Approximation; this needed since num_items_in_batch may be counted across
# all samples in a gradient accumulated batch, not on a per-step basis.
local_valid_tokens = (batch["labels"] != -100).sum()
# All-reduce across sequence parallel ranks to get global token count
cp_group = get_ring_attn_group()
global_valid_tokens = local_valid_tokens.clone()
# We use AVG instead of SUM, since summing appears to over-count the number of tokens and scale the loss down
dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.AVG, group=cp_group)
global_valid_tokens = int(global_valid_tokens.item())
batch["num_items_in_batch"] = (
batch["labels"] != -100
).sum() * gradient_accumulation_steps
global_valid_tokens * gradient_accumulation_steps
)
return batch, original_seq_len, pad_len
@@ -167,7 +177,7 @@ class SequenceParallelContextManager:
Args:
models: List of models to apply sequence parallelism to pre- and post- forward
hooks.
sequence_parallel_degree: Number of processes to split sequences over.
context_parallel_size: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused.
heads_k_stride: Sequence parallelism K head stride size. Passed through to
@@ -179,14 +189,14 @@ class SequenceParallelContextManager:
def __init__(
self,
models: list[nn.Module],
sequence_parallel_degree: int,
context_parallel_size: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc,
heads_k_stride: int | None,
gather_outputs: bool,
):
self.models = models
self.sequence_parallel_degree = sequence_parallel_degree
self.context_parallel_size = context_parallel_size
self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride
@@ -230,8 +240,10 @@ class SequenceParallelContextManager:
def _register_ring_attn(self):
# Initialize ring attn for sequence parallelism
register_ring_attn(
sequence_parallel_degree=self.sequence_parallel_degree,
partial_state = PartialState()
register_ring_attn_from_device_mesh(
device_mesh=partial_state.device_mesh,
context_parallel_dim=("cp",),
heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func,
)

View File

@@ -430,10 +430,11 @@ def save_preprocessed_dataset(
num_shards=cfg.num_dataset_shards_to_save,
)
else:
min_rows_per_proc = 256
os.makedirs(prepared_ds_path, exist_ok=True)
dataset.save_to_disk(
str(prepared_ds_path),
num_proc=min(max(1, len(dataset) // 8), num_workers),
num_proc=min(max(1, len(dataset) // min_rows_per_proc), num_workers),
max_shard_size=None,
num_shards=cfg.num_dataset_shards_to_save,
)

View File

@@ -2,12 +2,15 @@
utils to get GPU info for the current environment
"""
from importlib.metadata import version
from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
)
from accelerate.utils.environment import (
get_gpu_info,
)
from packaging.version import Version, parse
def check_cuda_p2p_ib_support():
@@ -26,3 +29,13 @@ def check_cuda_p2p_ib_support():
except Exception: # pylint: disable=broad-except # nosec
pass
return True
def get_package_version(package: str) -> Version:
version_str = version(package)
return parse(version_str)
def is_package_version_ge(package: str, version_: str) -> bool:
package_version = get_package_version(package)
return package_version >= parse(version_)

View File

@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
import gc
import math
import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
from typing import Iterable, Iterator, Union
@@ -453,7 +454,10 @@ class MultipackBatchSampler(BatchSampler):
_sampled_lens = []
for _ in range(self.num_count_samples):
self._batches = None # Reset cached batches
# log timer for generating batches
start_time = time.time()
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
LOG.debug(f"generate_batches time: {time.time() - start_time}")
len_batches = min(_sampled_lens)
# Gather minimum across all ranks

View File

@@ -651,7 +651,23 @@ class AxolotlInputConfig(
},
)
dp_shard_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of devices to shard across. If not set, will use all available devices."
},
)
dp_replicate_size: int | None = Field(
default=None,
json_schema_extra={"description": "Number of devices to replicate across."},
)
sequence_parallel_degree: int | None = Field(
default=None,
json_schema_extra={
"description": "Deprecated: use `context_parallel_size` instead"
},
)
context_parallel_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."

View File

@@ -673,7 +673,7 @@ class RLValidationMixin:
data.get("rl") == "grpo"
and data.get("trl", {})
and data.get("trl").get("use_liger_loss")
and data.get("sequence_parallel_degree", 1) > 1
and data.get("context_parallel_size", 1) > 1
):
raise ValueError("GRPO + SP + Liger not currently supported")
return data
@@ -880,51 +880,35 @@ class OptimizationValidationMixin:
return self
@model_validator(mode="after")
def check_fsdp_sharded_state_dict_w_safetensors(self):
if (
hasattr(self, "fsdp_config")
and self.fsdp_config
and hasattr(self, "save_safetensors")
and self.save_safetensors
and self.fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
and str(getattr(self, "fsdp_version", "1")) != "2"
):
raise ValueError(
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
)
return self
@model_validator(mode="before")
@classmethod
def check_tensor_parallel_size_update_ds_json(cls, data):
tensor_parallel_size = data.get("tensor_parallel_size")
if tensor_parallel_size is not None and tensor_parallel_size > 1:
if not data.get("deepspeed"):
raise ValueError(
"Tensor parallelism (TP) is only supported with DeepSpeed"
)
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
ds_config = json.load(ds_fin)
should_save = False
if "tensor_parallel" not in ds_config:
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
should_save = True
if (
"gather_16bit_weights_on_model_save"
not in ds_config["zero_optimization"]
):
ds_config["zero_optimization"][
if data.get("deepspeed"):
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
ds_config = json.load(ds_fin)
should_save = False
if "tensor_parallel" not in ds_config:
ds_config["tensor_parallel"] = {
"autotp_size": tensor_parallel_size
}
should_save = True
if (
"gather_16bit_weights_on_model_save"
] = True
should_save = True
if should_save:
temp_dir = tempfile.mkdtemp()
with open(
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
) as ds_fout:
json.dump(ds_config, ds_fout, indent=4)
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
not in ds_config["zero_optimization"]
):
ds_config["zero_optimization"][
"gather_16bit_weights_on_model_save"
] = True
should_save = True
if should_save:
temp_dir = tempfile.mkdtemp()
with open(
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
) as ds_fout:
json.dump(ds_config, ds_fout, indent=4)
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
return data
@@ -1205,13 +1189,18 @@ class ComplexValidationMixin:
return self
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
self.sequence_parallel_degree = 1
elif self.sequence_parallel_degree > 1:
def check_context_parallel_size(self):
if self.sequence_parallel_degree and not self.context_parallel_size:
LOG.warning(
"`sequence_parallel_degree` is deprecated, use `context_parallel_size`"
)
self.context_parallel_size = self.sequence_parallel_degree
if not self.context_parallel_size:
self.context_parallel_size = 1
elif self.context_parallel_size > 1:
if not self.flash_attention:
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"
"flash_attention: true must be set with context_parallel_size > 1"
)
if self.sample_packing and self.micro_batch_size > 1:
@@ -1221,17 +1210,23 @@ class ComplexValidationMixin:
)
try:
import transformers.modeling_flash_attention_utils
# pylint: disable=protected-access
transformers.modeling_flash_attention_utils._flash_supports_window_size = (
transformers.modeling_flash_attention_utils._flash_supports_window
)
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
raise ImportError(
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
"context_parallel_size > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
f"context_parallel_size={self.context_parallel_size}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
@@ -1242,7 +1237,7 @@ class ComplexValidationMixin:
@model_validator(mode="after")
def validate_ring_attn_func(self):
if getattr(self, "sequence_parallel_degree", 1) == 1:
if getattr(self, "context_parallel_size", 1) == 1:
return self
if self.ring_attn_func is not None:
@@ -1259,6 +1254,20 @@ class ComplexValidationMixin:
return self
class DistributedValidationMixin:
"""validation for distributed training."""
@model_validator(mode="after")
def check_tensor_parallel_optimizer(self):
if self.tensor_parallel_size > 1:
if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
raise ValueError(
"tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
)
return self
# pylint: disable=too-many-ancestors
class ValidationMixin(
DatasetValidationMixin,

View File

@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1
)
* cfg.num_epochs
* cfg.sequence_parallel_degree
* cfg.context_parallel_size
* cfg.tensor_parallel_size
)
LOG.debug(
@@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.floor(
data_loader_len
* cfg.num_epochs
* cfg.sequence_parallel_degree
* cfg.context_parallel_size
* cfg.tensor_parallel_size
)
)
@@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.ceil(
len(train_dataset)
* cfg.num_epochs
* cfg.sequence_parallel_degree
* cfg.context_parallel_size
* cfg.tensor_parallel_size
/ cfg.batch_size
)

View File

@@ -64,7 +64,7 @@ def fixture_base_cfg():
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2,
"sequence_parallel_degree": 1,
"context_parallel_size": 1,
"tensor_parallel_size": 1,
# Dtype
"fp16": False,

View File

@@ -67,7 +67,7 @@ class TestSequenceParallelism:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"context_parallel_size": 2,
"ring_attn_func": ring_attn_func,
"save_first_step": False,
}
@@ -105,13 +105,13 @@ class TestSequenceParallelism:
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
# (False, 2, True, "batch_zigzag", 2.5),
(False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
# (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
],
)
def test_sequence_parallel_training(

View File

@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sequence_parallel_degree": 2,
"context_parallel_size": 2,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {

View File

@@ -13,7 +13,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -51,6 +51,7 @@ class TestFP8FSDP2:
"""Test class for FP8 mixed precision with FSDP2 functionality."""
@require_torch_2_7_0
@require_hopper
def test_fp8_fsdp2_smoke(self, temp_dir):
"""Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
cfg = DictDefault(

View File

@@ -0,0 +1,69 @@
"""multigpu e2e test for tensor parallelism."""
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
class TestTensorParallel:
"""Test class for Tensor Parallel functionality."""
@pytest.mark.skip(
reason="TP doesn't work with models with tied weights (embeddings)"
)
@require_torch_2_7_0
def test_fft_sft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"tensor_parallel_size": 2,
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
)
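The skip above stems from the base model's tied input/output embeddings, which the current TP sharding cannot handle. A quick, hypothetical way to check whether a candidate base model is affected:

```python
# Assumes transformers is installed; Qwen/Qwen2.5-0.5B ties embed_tokens and lm_head.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B")
print(config.tie_word_embeddings)  # True -> hits the skip reason above
```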

View File

@@ -1,481 +0,0 @@
"""Tests for sequence parallelism functionality."""
# pylint: disable=redefined-outer-name,unused-argument
import functools
import sys
from unittest.mock import MagicMock, patch
import pytest
import torch
from accelerate.state import PartialState
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig
@pytest.fixture
def partial_state():
"""Create a real PartialState instance for testing."""
state = PartialState()
return state
@pytest.fixture(name="cfg")
def fixture_cfg():
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-3,
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"save_first_step": False,
}
)
return cfg
@pytest.fixture
def sequence_parallel_batch():
"""Create a test batch for sequence parallelism tests."""
batch_size = 1
seq_len = 8
# Create test tensors
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
attention_mask = torch.ones(batch_size, seq_len)
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
labels = input_ids.clone()
# Create test batch
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"labels": labels,
}
return batch
class TestRingAttention:
"""Tests for the ring attention functionality."""
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_get_ring_attn_group_no_registration(
self, mock_world_size, mock_rank, partial_state
):
"""Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
# Setup mocks
mock_world_size.return_value = 4
mock_rank.return_value = 0
# Verify that RuntimeError is raised when no group is registered
with pytest.raises(
RuntimeError, match="register_ring_attn\\(\\) not yet called"
):
get_ring_attn_group()
@patch("torch.distributed.new_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_register_ring_attn(
self, mock_world_size, mock_rank, mock_new_group, partial_state
):
"""Test that ring attention groups are created correctly."""
# Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total
mock_rank.return_value = 3 # GPU #3
mock_group = MagicMock()
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(
sequence_parallel_degree=4,
heads_k_stride=1,
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2
# Verify that new_group was called
mock_new_group.assert_called()
# Clean up
set_ring_attn_group(None)
class TestConfigValidation:
"""Tests for validating sequence parallelism configurations."""
@pytest.fixture(autouse=True)
def setup_mocks(self, monkeypatch):
"""Set up mocks for all tests in this class."""
# Mock the ring_flash_attn module
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
@pytest.fixture
def base_cfg(self):
"""Create a base configuration for testing."""
return DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-3,
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {"pad_token": "<|endoftext|>"},
}
)
@pytest.mark.parametrize(
"config_updates, expected_values, should_pass, error_msg",
[
# Valid configuration
(
{"sequence_parallel_degree": 2, "flash_attention": True},
{"sequence_parallel_degree": 2, "flash_attention": True},
True,
None,
),
# Default sequence_parallel_degree
({}, {"sequence_parallel_degree": 1}, True, None),
# Invalid: sequence_parallel_degree > 1 without flash_attention
(
{"sequence_parallel_degree": 2, "flash_attention": False},
None,
False,
"flash_attention: true must be set",
),
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"micro_batch_size": 2,
"pad_to_sequence_len": True,
},
None,
False,
"micro_batch_size must be set to 1",
),
# Valid: Basic GRPO config
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": TRLConfig(use_liger_loss=True),
},
True,
"GRPO + SP + Liger not currently supported",
),
# Invalid: GRPO config with Liger loss
(
{
"rl": "grpo",
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
None,
False,
"GRPO + SP + Liger not currently supported",
),
],
ids=[
"valid_config",
"default_sp_degree",
"without_flash_attention",
"sample_packing_with_large_batch",
"valid_grpo",
"grpo_with_liger_loss",
],
)
def test_sequence_parallel_config_validation(
self, base_cfg, config_updates, expected_values, should_pass, error_msg
):
"""Test various sequence parallelism configuration scenarios."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg
cfg.update(config_updates)
if should_pass:
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check expected values
for key, value in expected_values.items():
assert getattr(config, key) == value
else:
# Should raise exception
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
assert error_msg in str(excinfo.value)
@pytest.mark.parametrize(
"ring_attn_func, sample_packing, expected_func",
[
(None, True, RingAttnFunc.VARLEN_LLAMA3),
(None, False, RingAttnFunc.BATCH_RING),
],
ids=["default_with_sample_packing", "default_without_sample_packing"],
)
def test_ring_attn_func_validation(
self, base_cfg, ring_attn_func, sample_packing, expected_func
):
"""Test ring_attn_func validation and defaults."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": sample_packing,
}
if ring_attn_func is not None:
cfg["ring_attn_func"] = ring_attn_func
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check ring_attn_func value
assert config.ring_attn_func.value == expected_func
def test_invalid_ring_attn_func(self, base_cfg):
"""Test that an invalid ring_attn_func is rejected."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Invalid configuration with invalid ring_attn_func
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"ring_attn_func": "INVALID_FUNC",
}
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
# Verify error message
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
class TestApplySequenceParallelism:
"""Tests for the apply_sequence_parallelism function."""
@pytest.fixture(autouse=True)
def mock_distributed(self, monkeypatch):
"""Mock torch.distributed functions for testing."""
# Mock is_initialized to return True
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
# Mock get_rank to return 0 by default
monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
# Mock get_world_size to return 2 by default
monkeypatch.setattr(
torch.distributed, "get_world_size", lambda *args, **kwargs: 2
)
# Mock the process group
monkeypatch.setattr(
"axolotl.monkeypatch.ring_attn.get_ring_attn_group",
MagicMock,
)
# Mock update_ring_attn_params
monkeypatch.setattr(
"axolotl.monkeypatch.ring_attn.update_ring_attn_params",
lambda **kwargs: None,
)
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test that function returns original batch when world size is 1."""
mock_get_ring_attn_group.return_value = 0
result, _, _ = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Should return the original batch unchanged
assert result == sequence_parallel_batch
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Check that sequence dimension was sharded correctly
assert result["input_ids"].shape[1] == seq_len // 2
assert result["attention_mask"].shape[1] == seq_len // 2
# Verify content: rank 0 should get the first half of the sequence
assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
assert torch.equal(
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
)
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify content: rank 1 should get the second half of the sequence
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
# TODO(djsaunde): add back once implemented.
# def test_batch_zigzag(self, sequence_parallel_batch):
# """Test BATCH_ZIGZAG sharding pattern."""
# batch = sequence_parallel_batch
# original_input_ids = batch["input_ids"].clone()
# seq_len = batch["input_ids"].size(1)
# # Test rank 0
# result_rank0 = apply_sequence_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=0,
# local_world_size=2,
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
# )
# # Test rank 1
# result_rank1 = apply_sequence_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=1,
# local_world_size=2,
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
# )
# # Checks for both ranks
# assert result_rank0["input_ids"].shape[1] == seq_len // 2
# assert result_rank1["input_ids"].shape[1] == seq_len // 2
# # For a 2-rank system with 8 tokens, check specific zigzag pattern
# # Rank 0 should get chunks [0, 1] and [6, 7]
# # Rank 1 should get chunks [2, 3] and [4, 5]
# if seq_len == 8:
# # Create expected tensors for comparison
# rank0_expected = torch.cat(
# [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
# )
# rank1_expected = torch.cat(
# [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
# )
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_partial_application(
self, mock_get_ring_attn_group, sequence_parallel_batch
):
"""Test that we can create a partially applied version of the function."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
# Create a partially applied function
rank0_ring_parallel = functools.partial(
apply_sequence_parallelism,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Use the partially applied function
result, _, _ = rank0_ring_parallel(batch=batch)
# Verify it works as expected
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
assert torch.equal(
result["input_ids"],
original_input_ids[:, : original_input_ids.shape[1] // 2],
)
def test_missing_position_ids(self, sequence_parallel_batch):
"""Test handling of batch without position_ids."""
# Create a batch without position_ids
batch = {
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
}
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verification should pass
assert "position_ids" in result
assert result["input_ids"].shape[1] == result["position_ids"].shape[1]
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2

View File

@@ -52,6 +52,8 @@ class TestLoadModelUtils:
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"tensor_parallel_size": 1,
"context_parallel_size": 1,
}
)
self.model_loader = ( # pylint: disable=attribute-defined-outside-init

View File

@@ -142,6 +142,10 @@ def is_hopper():
return compute_capability == (9, 0)
def require_hopper(test_case):
return unittest.skipUnless(is_hopper(), "test requires h100/hopper GPU")(test_case)
def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
) -> None:

View File

@@ -171,3 +171,44 @@ class TestModelsUtils:
message_property_mappings={"content": "different_content"},
)
assert "Conflicting message content fields" in str(exc_info.value)
@pytest.mark.parametrize(
"world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected",
[
(16, 2, 2, 2, 2, True, (2, 2, 2, 2)),
(16, 1, 1, None, None, True, (0, 0, 16, 1)),
(16, 2, 2, 2, None, True, (2, 2, 2, 2)),
(16, 2, 2, None, 2, True, (2, 2, 2, 2)),
(16, 1, 1, None, 2, True, (0, 0, 8, 2)),
(2, 1, 1, None, None, True, (0, 0, 2, 1)),
],
)
def test_get_parallel_config_kwargs(
self,
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
expected,
):
res = (
ModelLoader._get_parallel_config_kwargs( # pylint: disable=protected-access
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
)
)
if expected[0] > 1:
assert res["tp_size"] == expected[0]
if expected[1] > 1:
assert res["cp_size"] == expected[1]
if expected[2] > 1:
assert res["dp_shard_size"] == expected[2]
if expected[3] > 1:
assert res["dp_replicate_size"] == expected[3]

View File

@@ -26,32 +26,6 @@ class TestFSDPValidation:
assert cfg.fsdp_version == 2
assert cfg.fsdp_config.fsdp_version is None
def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
},
save_safetensors=True,
)
with pytest.raises(
ValueError,
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
):
validate_config(cfg)
# test w/o prefix too
cfg = min_base_cfg | DictDefault(
fsdp_config={
"state_dict_type": "SHARDED_STATE_DICT",
},
save_safetensors=True,
)
with pytest.raises(
ValueError,
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
):
validate_config(cfg)
def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={