Compare commits

...

9 Commits

Author SHA1 Message Date
Salman Mohammadi
bc2bc688d8 update fsdp2 patch 2025-07-23 16:53:03 +01:00
Wing Lian
b3c04dd9fe workaround for fsdp2 optimizer save failures 2025-07-23 09:38:57 -04:00
Wing Lian
972c719d38 use latest transformers on main with fix 2025-07-23 09:22:36 -04:00
Wing Lian
2c1cb8b300 fix for accelerator state getting reset and missing schema 2025-07-23 08:43:34 -04:00
Wing Lian
cca207eec4 handle none checks 2025-07-22 21:21:45 -04:00
Wing Lian
9a2da4d9f0 update tp validation 2025-07-22 21:20:57 -04:00
Wing Lian
8fe4758e94 make sure to return data for validation 2025-07-22 21:18:39 -04:00
Wing Lian
8c641fdcb4 handle tp load 2025-07-22 21:17:27 -04:00
Wing Lian
5c74bebfd0 use new upstream branches for nd-parallelism 2025-07-22 21:12:22 -04:00
25 changed files with 212 additions and 122 deletions

View File

@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml ```yaml
# Set to a divisor (> 1) of the number of GPUs available # Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs context_parallel_size: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster. # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1 heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,7 +30,7 @@ heads_k_stride: 1
ring_attn_func: ring_attn_func:
``` ```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example: The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8 - With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4 - With 4 GPUs, valid values would be 2 or 4
@@ -66,7 +66,7 @@ sequence_len: 8192
... ...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster. # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1 heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
## Effect on Batch Size ## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because: When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence) - Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases - The number of batches processed per step decreases
For example: For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step - With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs) - With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4 - If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

View File

@@ -13,9 +13,9 @@ packaging==23.2
huggingface_hub>=0.33.0 huggingface_hub>=0.33.0
peft==0.16.0 peft==0.16.0
transformers==4.53.2 transformers @ git+https://github.com/huggingface/transformers.git@82603b6cc284dbdf2b7a7cf070feb6a2c3bb53cf
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.9.0 accelerate @ git+https://github.com/SalmanMohammadi/accelerate.git@device_mesh_parallelism_config
datasets==4.0.0 datasets==4.0.0
deepspeed>=0.17.0 deepspeed>=0.17.0
trl==0.19.1 trl==0.19.1

View File

@@ -70,7 +70,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False, load_in_8bit=False,
load_in_4bit=False, load_in_4bit=False,
flash_attention=False, flash_attention=False,
sequence_parallel_degree=None, context_parallel_size=None,
deepspeed=None, deepspeed=None,
fsdp=None, fsdp=None,
fsdp_config=None, fsdp_config=None,

View File

@@ -27,6 +27,7 @@ import torch
from transformers import ( from transformers import (
TrainerCallback, TrainerCallback,
) )
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.training_args import OptimizerNames from transformers.training_args import OptimizerNames
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
@@ -434,8 +435,18 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_accelerator_config(self, training_args_kwargs: dict): def _configure_accelerator_config(self, training_args_kwargs: dict):
use_configured_state = True
if self.cfg.accelerator_config: if self.cfg.accelerator_config:
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config use_configured_state = self.cfg.accelerator_config.pop(
"use_configured_state", use_configured_state
)
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state, **self.cfg.accelerator_config
)
else:
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=True,
)
def _configure_gradient_checkpointing(self, training_args_kwargs: dict): def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.activation_offloading is True: if self.cfg.activation_offloading is True:

View File

@@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl is RLType.GRPO: if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class( trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1 sequence_parallel=self.cfg.context_parallel_size > 1
) )
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))

View File

@@ -82,8 +82,8 @@ class GRPOStrategy:
grpo_args_kwargs["log_completions"] = trl.log_completions grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if cfg.sequence_parallel_degree > 1: if cfg.context_parallel_size > 1:
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
if trl.reward_weights: if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights grpo_args_kwargs["reward_weights"] = trl.reward_weights

View File

@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training""" """Axolotl GRPO Config for GRPO training"""
sequence_parallel_degree: int | None = None context_parallel_size: int | None = None

View File

@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
- Data is properly distributed across SP groups. - Data is properly distributed across SP groups.
In the table below, the values represent dataset indices. Each SP group has In the table below, the values represent dataset indices. Each SP group has
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2 `context_parallel_size = 2` GPUs working together on the same data. There are 2
SP groups (SP0 and SP1), with `world_size = 4` total GPUs. SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
Sequence Parallel Groups Sequence Parallel Groups
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: Rank of current process. rank: Rank of current process.
batch_size: Number of samples per batch. batch_size: Number of samples per batch.
repeat_count: How many times to repeat the full sampling process. repeat_count: How many times to repeat the full sampling process.
sequence_parallel_degree: Number of ranks in a sequence parallel group. context_parallel_size: Number of ranks in a sequence parallel group.
shuffle: Whether to shuffle the dataset. shuffle: Whether to shuffle the dataset.
seed: Random seed for shuffling. seed: Random seed for shuffling.
drop_last: Whether to drop the last incomplete batch. drop_last: Whether to drop the last incomplete batch.
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: int, rank: int,
batch_size: int = 1, batch_size: int = 1,
repeat_count: int = 1, repeat_count: int = 1,
sequence_parallel_degree: int = 1, context_parallel_size: int = 1,
shuffle: bool = True, shuffle: bool = True,
seed: int = 0, seed: int = 0,
drop_last: bool = False, drop_last: bool = False,
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
self.rank = rank self.rank = rank
# Sequence parallelism parameters # Sequence parallelism parameters
self.sequence_parallel_degree = sequence_parallel_degree self.context_parallel_size = context_parallel_size
self.num_sp_groups = world_size // sequence_parallel_degree self.num_sp_groups = world_size // context_parallel_size
self.sp_group_id = rank // sequence_parallel_degree self.sp_group_id = rank // context_parallel_size
# Adjust dataset size for distributed sampling # Adjust dataset size for distributed sampling
self.num_samples = len(self.dataset) self.num_samples = len(self.dataset)

View File

@@ -100,7 +100,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Get number of SP groups (number of processes divided by SP degree) # Get number of SP groups (number of processes divided by SP degree)
num_processes = self.accelerator.num_processes num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree num_sp_groups = num_processes // self.args.context_parallel_size
# Calculate batch size per SP group (not per process) # Calculate batch size per SP group (not per process)
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
@@ -130,7 +130,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
if self.num_generations not in possible_values: if self.num_generations not in possible_values:
raise ValueError( raise ValueError(
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), " f"With sequence parallelism (degree {self.args.context_parallel_size}), "
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) " f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"must be evenly divisible by the number of generations per prompt " f"must be evenly divisible by the number of generations per prompt "
f"({self.num_generations}). Given the current eval batch size, " f"({self.num_generations}). Given the current eval batch size, "
@@ -167,9 +167,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
rank=self.rank, rank=self.rank,
batch_size=effective_batch_size batch_size=effective_batch_size
// self.num_generations // self.num_generations
// self.args.sequence_parallel_degree, // self.args.context_parallel_size,
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps, repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
sequence_parallel_degree=self.args.sequence_parallel_degree, context_parallel_size=self.args.context_parallel_size,
shuffle=True, shuffle=True,
seed=self.args.seed, seed=self.args.seed,
drop_last=True, drop_last=True,
@@ -235,7 +235,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension). # slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1: if self.args.context_parallel_size > 1:
return dataloader return dataloader
# Otherwise prepare with accelerator # Otherwise prepare with accelerator
@@ -308,18 +308,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text) all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process: if self.accelerator.is_main_process:
if self.args.sequence_parallel_degree > 1: if self.args.context_parallel_size > 1:
# Calculate sequence parallel group information # Calculate sequence parallel group information
world_size = self.accelerator.num_processes world_size = self.accelerator.num_processes
sequence_parallel_degree = self.args.sequence_parallel_degree context_parallel_size = self.args.context_parallel_size
num_sp_groups = world_size // sequence_parallel_degree num_sp_groups = world_size // context_parallel_size
# Since processes in the same SP group have the same prompts, we need to ensure # Since processes in the same SP group have the same prompts, we need to ensure
# we only take one copy of each prompt from each SP group # we only take one copy of each prompt from each SP group
ordered_set_of_prompts = [] ordered_set_of_prompts = []
for sp_group_id in range(num_sp_groups): for sp_group_id in range(num_sp_groups):
# Get the first process from each SP group (typically the group leader) # Get the first process from each SP group (typically the group leader)
group_leader_rank = sp_group_id * sequence_parallel_degree group_leader_rank = sp_group_id * context_parallel_size
# Extract prompts from this SP group, accounting for num_generations duplicates # Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group # We only need prompts from one rank in each SP group
@@ -335,7 +335,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# num_generations outputs for each one. This is faster than generating outputs for each duplicate # num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually. # prompt individually.
ordered_set_of_prompts = all_prompts_text[ ordered_set_of_prompts = all_prompts_text[
:: self.num_generations * self.args.sequence_parallel_degree :: self.num_generations * self.args.context_parallel_size
] ]
with profiling_context(self, "vLLM.generate"): with profiling_context(self, "vLLM.generate"):
@@ -352,14 +352,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
) )
else: else:
completion_ids = [None] * ( completion_ids = [None] * (
len(all_prompts_text) // self.args.sequence_parallel_degree len(all_prompts_text) // self.args.context_parallel_size
) )
# Broadcast the completions from the main process to all processes # Broadcast the completions from the main process to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0) completion_ids = broadcast_object_list(completion_ids, from_process=0)
# Determine the appropriate slice based on sequence parallelism # Determine the appropriate slice based on sequence parallelism
if self.args.sequence_parallel_degree > 1: if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to) # Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size sp_group_id = self.accelerator.process_index // self.local_world_size
@@ -583,7 +583,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
advantages = advantages / (std_grouped_rewards + 1e-4) advantages = advantages / (std_grouped_rewards + 1e-4)
# Slice to keep only the local part of the data # Slice to keep only the local part of the data
if self.args.sequence_parallel_degree > 1: if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to) # Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size sp_group_id = self.accelerator.process_index // self.local_world_size

View File

@@ -13,9 +13,11 @@ class CheckpointSaveMixin(Trainer):
def _save_optimizer_and_scheduler(self, output_dir): def _save_optimizer_and_scheduler(self, output_dir):
try: try:
super()._save_optimizer_and_scheduler(output_dir) super()._save_optimizer_and_scheduler(output_dir)
except NotImplementedError as exc: except (NotImplementedError, KeyError) as exc:
LOG.warning( # TODO: fix fsdp2 optimizer saving
LOG.warning_once(
f"Trainer does not support saving optimizer and scheduler: {exc}\n" f"Trainer does not support saving optimizer and scheduler: {exc}\n"
"Optimizer and scheduler states were not saved - resuming from checkpoints " "Optimizer and scheduler states were not saved - resuming from checkpoints "
"for this training run will not be possible." "for this training run will not be possible.",
main_process_only=True,
) )

View File

@@ -16,8 +16,6 @@
Module for handling LIGER input arguments. Module for handling LIGER input arguments.
""" """
from typing import Optional
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
@@ -30,13 +28,13 @@ class LigerArgs(BaseModel):
Input args for LIGER. Input args for LIGER.
""" """
liger_rope: Optional[bool] = None liger_rope: bool | None = None
liger_rms_norm: Optional[bool] = None liger_rms_norm: bool | None = None
liger_layer_norm: Optional[bool] = None liger_layer_norm: bool | None = None
liger_swiglu: Optional[bool] = None liger_swiglu: bool | None = None
liger_glu_activation: Optional[bool] = None liger_glu_activation: bool | None = None
liger_cross_entropy: Optional[bool] = None liger_cross_entropy: bool | None = None
liger_fused_linear_cross_entropy: Optional[bool] = None liger_fused_linear_cross_entropy: bool | None = None
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -62,3 +60,13 @@ class LigerArgs(BaseModel):
"You cannot have both `liger_glu_activation` and `tiled_mlp` set." "You cannot have both `liger_glu_activation` and `tiled_mlp` set."
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_liger_rms_norm_tensor_parallel(cls, data):
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1:
raise ValueError(
"`liger_rms_norm` is incompatible with tensor parallelism, "
"see https://github.com/linkedin/Liger-Kernel/issues/826"
)
return data

View File

@@ -13,7 +13,8 @@ import peft
import torch import torch
import transformers import transformers
import transformers.modeling_utils import transformers.modeling_utils
from accelerate import init_empty_weights from accelerate import PartialState, init_empty_weights
from accelerate.utils.dataclasses import ParallelismConfig
from peft import ( from peft import (
PeftConfig, PeftConfig,
PeftMixedModel, PeftMixedModel,
@@ -51,6 +52,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import ( from axolotl.utils.distributed import (
get_device_count, get_device_count,
get_device_type, get_device_type,
get_world_size,
) )
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant from axolotl.utils.model_shard_quant import load_sharded_model_quant
@@ -182,6 +184,7 @@ class ModelLoader:
def _apply_pre_model_load_setup(self): def _apply_pre_model_load_setup(self):
"""Apply patches and setup configurations before model loading.""" """Apply patches and setup configurations before model loading."""
self._set_parallel_config()
self._set_auto_model_loader() self._set_auto_model_loader()
self._set_device_map_config() self._set_device_map_config()
if self.cfg.revision_of_model: if self.cfg.revision_of_model:
@@ -389,6 +392,52 @@ class ModelLoader:
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def _set_parallel_config(self):
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
dp_replicate_size = get_world_size()
pc_kwargs = {}
if self.cfg.dp_shard_size and self.cfg.dp_shard_size > 1:
pc_kwargs["dp_shard_size"] = self.cfg.dp_shard_size
dp_replicate_size = dp_replicate_size // self.cfg.dp_shard_size
if self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1:
pc_kwargs["tp_size"] = self.cfg.tensor_parallel_size
dp_replicate_size = dp_replicate_size // self.cfg.tensor_parallel_size
if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1:
pc_kwargs["cp_size"] = self.cfg.context_parallel_size
dp_replicate_size = dp_replicate_size // self.cfg.context_parallel_size
if dp_replicate_size > 1:
pc_kwargs["dp_replicate_size"] = dp_replicate_size
parallelism_config = ParallelismConfig(
**pc_kwargs,
)
mesh_dim_names, mesh_shape = parallelism_config.get_mesh()
device_mesh = torch.distributed.init_device_mesh(
"cuda", mesh_shape, mesh_dim_names=mesh_dim_names
)
submeshes = [
tuple(parallelism_config.dp_dim_names),
tuple(parallelism_config.dp_shard_cp_dim_names),
tuple(parallelism_config.dp_cp_dim_names),
]
submesh_names = [
# create a submesh which is only used for distributing data across data parallel dims (no comms)
"dp",
# create a submesh which is used *just* for FSDP parameter gathering/scattering
# and gradients reduce-scattering
"dp_shard_cp",
# create a submesh which is used for correctly reducing loss across data replica/context parallel
"dp_cp",
]
for submesh, submesh_name in zip(submeshes, submesh_names):
if submesh:
device_mesh[submesh]._flatten( # pylint: disable=protected-access
submesh_name
)
PartialState().parallelism_config = parallelism_config
PartialState().device_mesh = device_mesh
def _set_auto_model_loader(self): def _set_auto_model_loader(self):
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM` """Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
(set at `__init__`). When using a multimodal model, `self.auto_model_loader` (set at `__init__`). When using a multimodal model, `self.auto_model_loader`
@@ -621,6 +670,14 @@ class ModelLoader:
def _build_model(self) -> bool: def _build_model(self) -> bool:
"""Load model, with load strategy depending on config.""" """Load model, with load strategy depending on config."""
skip_move_to_device = False skip_move_to_device = False
if self.cfg.tensor_parallel_size > 1:
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
self.model_kwargs["tp_plan"] = "auto"
self.model_kwargs["device_mesh"] = PartialState().device_mesh
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
if self.is_fsdp_enabled: if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading: if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True skip_move_to_device = True

View File

@@ -261,14 +261,14 @@ class PatchManager:
def _apply_sequence_parallel_patches(self): def _apply_sequence_parallel_patches(self):
"""Apply sequence parallelism patches.""" """Apply sequence parallelism patches."""
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1: if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1:
from axolotl.monkeypatch.ring_attn.patch import ( from axolotl.monkeypatch.ring_attn.patch import (
patch_prepare_data_loader, patch_prepare_data_loader,
patch_prepare_device_mesh, patch_prepare_device_mesh,
) )
patch_prepare_data_loader() patch_prepare_data_loader()
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp) patch_prepare_device_mesh(self.cfg.context_parallel_size, self.cfg.fsdp)
def _apply_tiled_mlp(self, model_type: str): def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp: if self.cfg.tiled_mlp:

View File

@@ -254,6 +254,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
"offload_policy": fsdp2_plugin.cpu_offload, "offload_policy": fsdp2_plugin.cpu_offload,
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy` # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(), "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
"mesh": accelerator.torch_device_mesh[tuple(accelerator.parallelism_config.model_shard_dim_names)],
} }
model_has_params4bit = False model_has_params4bit = False

View File

@@ -162,14 +162,14 @@ def create_ring_flash_attention_forward(
def register_ring_attn( def register_ring_attn(
sequence_parallel_degree: int, context_parallel_size: int,
heads_k_stride: int | None, heads_k_stride: int | None,
ring_attn_func: RingAttnFunc | None, ring_attn_func: RingAttnFunc | None,
): ):
"""Create ring attention group and substitute flash attn with ring flash attn. """Create ring attention group and substitute flash attn with ring flash attn.
Args: Args:
sequence_parallel_degree: Sequence parallelism factor. context_parallel_size: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed through to heads_k_stride: Sequence parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation. `varlen_llama3` `ring_flash_attn` implementation.
ring_attn_func: `ring_flash_attn` ring attention implementation. If sample ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
@@ -182,25 +182,25 @@ def register_ring_attn(
if rank == 0: if rank == 0:
LOG.info( LOG.info(
"Enabling ring attention sequence parallelism: " "Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs" f"each sequence will be processed across {context_parallel_size} GPUs"
) )
assert sequence_parallel_degree <= world_size, ( assert context_parallel_size <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) " f"context_parallel_size ({context_parallel_size}) "
f"must be less than or equal to world_size ({world_size})" f"must be less than or equal to world_size ({world_size})"
) )
assert world_size % sequence_parallel_degree == 0, ( assert world_size % context_parallel_size == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) " f"context_parallel_size ({context_parallel_size}) "
f"must evenly divide world_size ({world_size})" f"must evenly divide world_size ({world_size})"
) )
# Assign ranks to sequence parallel groups # Assign ranks to sequence parallel groups
group_assignments = {} group_assignments = {}
for i in range(world_size // sequence_parallel_degree): for i in range(world_size // context_parallel_size):
ring_attn_ranks = list( ring_attn_ranks = list(
range( range(
i * sequence_parallel_degree, i * context_parallel_size,
(i + 1) * sequence_parallel_degree, (i + 1) * context_parallel_size,
) )
) )
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
@@ -299,12 +299,12 @@ def patch_prepare_data_loader():
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support") LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False): def patch_prepare_device_mesh(context_parallel_size: int, fsdp: bool = False):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree. that includes sequence parallelism with the specified degree.
Args: Args:
sequence_parallel_degree: The degree of sequence parallelism to use. context_parallel_size: The degree of sequence parallelism to use.
fsdp: Whether to use FSDP. fsdp: Whether to use FSDP.
""" """
@@ -323,8 +323,8 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False)
# Create device mesh with sequence parallelism # Create device mesh with sequence parallelism
world_size = dist.get_world_size() world_size = dist.get_world_size()
mesh_shape = ( mesh_shape = (
world_size // sequence_parallel_degree, world_size // context_parallel_size,
sequence_parallel_degree, context_parallel_size,
) )
device_ids = list(range(world_size)) device_ids = list(range(world_size))
@@ -344,5 +344,5 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False)
LOG.info( LOG.info(
"Successfully patched Accelerator._prepare_device_mesh " "Successfully patched Accelerator._prepare_device_mesh "
f"with sequence_parallel_degree={sequence_parallel_degree}" f"with context_parallel_size={context_parallel_size}"
) )

View File

@@ -202,7 +202,7 @@ def execute_training(
) )
) )
if cfg.sequence_parallel_degree > 1: if cfg.context_parallel_size > 1:
models = [trainer.model] models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model: if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model) models.append(trainer.ref_model)
@@ -210,7 +210,7 @@ def execute_training(
stack.enter_context( stack.enter_context(
SequenceParallelContextManager( SequenceParallelContextManager(
models=models, models=models,
sequence_parallel_degree=cfg.sequence_parallel_degree, context_parallel_size=cfg.context_parallel_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps, gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func, ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride, heads_k_stride=cfg.heads_k_stride,

View File

@@ -167,7 +167,7 @@ class SequenceParallelContextManager:
Args: Args:
models: List of models to apply sequence parallelism to pre- and post- forward models: List of models to apply sequence parallelism to pre- and post- forward
hooks. hooks.
sequence_parallel_degree: Number of processes to split sequences over. context_parallel_size: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over. gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused. ring_attn_func: Which ring attention function to use. Currently unused.
heads_k_stride: Sequence parallelism K head stride size. Passed through to heads_k_stride: Sequence parallelism K head stride size. Passed through to
@@ -179,14 +179,14 @@ class SequenceParallelContextManager:
def __init__( def __init__(
self, self,
models: list[nn.Module], models: list[nn.Module],
sequence_parallel_degree: int, context_parallel_size: int,
gradient_accumulation_steps: int, gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc, ring_attn_func: RingAttnFunc,
heads_k_stride: int | None, heads_k_stride: int | None,
gather_outputs: bool, gather_outputs: bool,
): ):
self.models = models self.models = models
self.sequence_parallel_degree = sequence_parallel_degree self.context_parallel_size = context_parallel_size
self.gradient_accumulation_steps = gradient_accumulation_steps self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride self.heads_k_stride = heads_k_stride
@@ -231,7 +231,7 @@ class SequenceParallelContextManager:
def _register_ring_attn(self): def _register_ring_attn(self):
# Initialize ring attn for sequence parallelism # Initialize ring attn for sequence parallelism
register_ring_attn( register_ring_attn(
sequence_parallel_degree=self.sequence_parallel_degree, context_parallel_size=self.context_parallel_size,
heads_k_stride=self.heads_k_stride, heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func, ring_attn_func=self.ring_attn_func,
) )

View File

@@ -644,7 +644,19 @@ class AxolotlInputConfig(
}, },
) )
dp_shard_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of devices to shard across. If not set, will use all available devices."
},
)
sequence_parallel_degree: int | None = Field( sequence_parallel_degree: int | None = Field(
default=None,
json_schema_extra={
"description": "Deprecated: use `context_parallel_size` instead"
},
)
context_parallel_size: int | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details." "description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."

View File

View File

@@ -686,7 +686,7 @@ class RLValidationMixin:
data.get("rl") == "grpo" data.get("rl") == "grpo"
and data.get("trl", {}) and data.get("trl", {})
and data.get("trl").get("use_liger_loss") and data.get("trl").get("use_liger_loss")
and data.get("sequence_parallel_degree", 1) > 1 and data.get("context_parallel_size", 1) > 1
): ):
raise ValueError("GRPO + SP + Liger not currently supported") raise ValueError("GRPO + SP + Liger not currently supported")
return data return data
@@ -913,31 +913,30 @@ class OptimizationValidationMixin:
def check_tensor_parallel_size_update_ds_json(cls, data): def check_tensor_parallel_size_update_ds_json(cls, data):
tensor_parallel_size = data.get("tensor_parallel_size") tensor_parallel_size = data.get("tensor_parallel_size")
if tensor_parallel_size is not None and tensor_parallel_size > 1: if tensor_parallel_size is not None and tensor_parallel_size > 1:
if not data.get("deepspeed"): if data.get("deepspeed"):
raise ValueError( with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
"Tensor parallelism (TP) is only supported with DeepSpeed" ds_config = json.load(ds_fin)
) should_save = False
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin: if "tensor_parallel" not in ds_config:
ds_config = json.load(ds_fin) ds_config["tensor_parallel"] = {
should_save = False "autotp_size": tensor_parallel_size
if "tensor_parallel" not in ds_config: }
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size} should_save = True
should_save = True if (
if (
"gather_16bit_weights_on_model_save"
not in ds_config["zero_optimization"]
):
ds_config["zero_optimization"][
"gather_16bit_weights_on_model_save" "gather_16bit_weights_on_model_save"
] = True not in ds_config["zero_optimization"]
should_save = True ):
if should_save: ds_config["zero_optimization"][
temp_dir = tempfile.mkdtemp() "gather_16bit_weights_on_model_save"
with open( ] = True
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8" should_save = True
) as ds_fout: if should_save:
json.dump(ds_config, ds_fout, indent=4) temp_dir = tempfile.mkdtemp()
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json") with open(
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
) as ds_fout:
json.dump(ds_config, ds_fout, indent=4)
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
return data return data
@@ -1235,13 +1234,13 @@ class ComplexValidationMixin:
return self return self
@model_validator(mode="after") @model_validator(mode="after")
def check_sequence_parallel_degree(self): def check_context_parallel_size(self):
if not self.sequence_parallel_degree: if not self.context_parallel_size:
self.sequence_parallel_degree = 1 self.context_parallel_size = 1
elif self.sequence_parallel_degree > 1: elif self.context_parallel_size > 1:
if not self.flash_attention: if not self.flash_attention:
raise ValueError( raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1" "flash_attention: true must be set with context_parallel_size > 1"
) )
if self.sample_packing and self.micro_batch_size > 1: if self.sample_packing and self.micro_batch_size > 1:
@@ -1254,14 +1253,14 @@ class ComplexValidationMixin:
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception: except ImportError as exception:
raise ImportError( raise ImportError(
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. " "context_parallel_size > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn] " "Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`." "or `pip install ring-flash-attn>=0.1.4`."
) from exception ) from exception
LOG.warning( LOG.warning(
"Sequence parallelism (SP) is enabled with " "Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={self.sequence_parallel_degree}. " f"context_parallel_size={self.context_parallel_size}. "
"Please note that logged losses may differ slightly to the non-SP " "Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. " "losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
@@ -1272,7 +1271,7 @@ class ComplexValidationMixin:
@model_validator(mode="after") @model_validator(mode="after")
def validate_ring_attn_func(self): def validate_ring_attn_func(self):
if getattr(self, "sequence_parallel_degree", 1) == 1: if getattr(self, "context_parallel_size", 1) == 1:
return self return self
if self.ring_attn_func is not None: if self.ring_attn_func is not None:

View File

@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1 - 1
) )
* cfg.num_epochs * cfg.num_epochs
* cfg.sequence_parallel_degree * cfg.context_parallel_size
* cfg.tensor_parallel_size * cfg.tensor_parallel_size
) )
LOG.debug( LOG.debug(
@@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.floor( math.floor(
data_loader_len data_loader_len
* cfg.num_epochs * cfg.num_epochs
* cfg.sequence_parallel_degree * cfg.context_parallel_size
* cfg.tensor_parallel_size * cfg.tensor_parallel_size
) )
) )
@@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.ceil( math.ceil(
len(train_dataset) len(train_dataset)
* cfg.num_epochs * cfg.num_epochs
* cfg.sequence_parallel_degree * cfg.context_parallel_size
* cfg.tensor_parallel_size * cfg.tensor_parallel_size
/ cfg.batch_size / cfg.batch_size
) )

View File

@@ -64,7 +64,7 @@ def fixture_base_cfg():
"dataloader_num_workers": 1, "dataloader_num_workers": 1,
"dataloader_pin_memory": True, "dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2, "dataloader_prefetch_factor": 2,
"sequence_parallel_degree": 1, "context_parallel_size": 1,
"tensor_parallel_size": 1, "tensor_parallel_size": 1,
# Dtype # Dtype
"fp16": False, "fp16": False,

View File

@@ -67,7 +67,7 @@ class TestSequenceParallelism:
"logging_steps": 1, "logging_steps": 1,
"weight_decay": 0.0, "weight_decay": 0.0,
"use_tensorboard": True, "use_tensorboard": True,
"sequence_parallel_degree": 2, "context_parallel_size": 2,
"ring_attn_func": ring_attn_func, "ring_attn_func": ring_attn_func,
"save_first_step": False, "save_first_step": False,
} }

View File

@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"sequence_parallel_degree": 2, "context_parallel_size": 2,
"flash_attention": True, "flash_attention": True,
"sequence_len": 1024, "sequence_len": 1024,
"special_tokens": { "special_tokens": {

View File

@@ -111,7 +111,7 @@ class TestRingAttention:
# Call register_ring_attn with size 4 # Call register_ring_attn with size 4
register_ring_attn( register_ring_attn(
sequence_parallel_degree=4, context_parallel_size=4,
heads_k_stride=1, heads_k_stride=1,
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3, ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
) )
@@ -156,24 +156,24 @@ class TestConfigValidation:
[ [
# Valid configuration # Valid configuration
( (
{"sequence_parallel_degree": 2, "flash_attention": True}, {"context_parallel_size": 2, "flash_attention": True},
{"sequence_parallel_degree": 2, "flash_attention": True}, {"context_parallel_size": 2, "flash_attention": True},
True, True,
None, None,
), ),
# Default sequence_parallel_degree # Default context_parallel_size
({}, {"sequence_parallel_degree": 1}, True, None), ({}, {"context_parallel_size": 1}, True, None),
# Invalid: sequence_parallel_degree > 1 without flash_attention # Invalid: context_parallel_size > 1 without flash_attention
( (
{"sequence_parallel_degree": 2, "flash_attention": False}, {"context_parallel_size": 2, "flash_attention": False},
None, None,
False, False,
"flash_attention: true must be set", "flash_attention: true must be set",
), ),
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1 # Invalid: context_parallel_size > 1 with sample_packing and micro_batch_size > 1
( (
{ {
"sequence_parallel_degree": 2, "context_parallel_size": 2,
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"micro_batch_size": 2, "micro_batch_size": 2,
@@ -186,13 +186,13 @@ class TestConfigValidation:
# Valid: Basic GRPO config # Valid: Basic GRPO config
( (
{ {
"sequence_parallel_degree": 2, "context_parallel_size": 2,
"flash_attention": True, "flash_attention": True,
"micro_batch_size": 2, "micro_batch_size": 2,
"trl": {"use_liger_loss": True}, "trl": {"use_liger_loss": True},
}, },
{ {
"sequence_parallel_degree": 2, "context_parallel_size": 2,
"flash_attention": True, "flash_attention": True,
"micro_batch_size": 2, "micro_batch_size": 2,
"trl": TRLConfig(use_liger_loss=True), "trl": TRLConfig(use_liger_loss=True),
@@ -204,7 +204,7 @@ class TestConfigValidation:
( (
{ {
"rl": "grpo", "rl": "grpo",
"sequence_parallel_degree": 2, "context_parallel_size": 2,
"flash_attention": True, "flash_attention": True,
"micro_batch_size": 2, "micro_batch_size": 2,
"trl": {"use_liger_loss": True}, "trl": {"use_liger_loss": True},
@@ -262,7 +262,7 @@ class TestConfigValidation:
# Apply updates to base config # Apply updates to base config
cfg = base_cfg | { cfg = base_cfg | {
"sequence_parallel_degree": 2, "context_parallel_size": 2,
"flash_attention": True, "flash_attention": True,
"sample_packing": sample_packing, "sample_packing": sample_packing,
} }
@@ -282,7 +282,7 @@ class TestConfigValidation:
# Invalid configuration with invalid ring_attn_func # Invalid configuration with invalid ring_attn_func
cfg = base_cfg | { cfg = base_cfg | {
"sequence_parallel_degree": 2, "context_parallel_size": 2,
"flash_attention": True, "flash_attention": True,
"ring_attn_func": "INVALID_FUNC", "ring_attn_func": "INVALID_FUNC",
} }