diff --git a/cicd/multigpu.sh b/cicd/multigpu.sh index 4fd5672be..3ec4456b9 100755 --- a/cicd/multigpu.sh +++ b/cicd/multigpu.sh @@ -2,7 +2,7 @@ set -e # Only run two tests at a time to avoid OOM on GPU (with coverage collection) -pytest -v -n2 \ +pytest -v --durations=10 -n2 \ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \ --ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \ /workspace/axolotl/tests/e2e/multigpu/ \ diff --git a/cicd/single_gpu.py b/cicd/single_gpu.py index 6955af013..eb34e1748 100644 --- a/cicd/single_gpu.py +++ b/cicd/single_gpu.py @@ -65,6 +65,9 @@ GPU_CONFIG = f"L40S:{N_GPUS}" def run_cmd(cmd: str, run_folder: str): import subprocess # nosec + sp_env = os.environ.copy() + sp_env["AXOLOTL_DATASET_PROCESSES"] = "8" + # Propagate errors from subprocess. - if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec + if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec exit(exit_code) # pylint: disable=consider-using-sys-exit diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd index b98206135..d1933a145 100644 --- a/docs/sequence_parallelism.qmd +++ b/docs/sequence_parallelism.qmd @@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file: ```yaml # Set to a divisor (> 1) of the number of GPUs available -sequence_parallel_degree: 4 # Split sequences across 4 GPUs +context_parallel_size: 4 # Split sequences across 4 GPUs # Optional; strides across the key dimension. Larger values use more memory but should make training faster. heads_k_stride: 1 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to @@ -30,7 +30,7 @@ heads_k_stride: 1 ring_attn_func: ``` -The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example: +The `context_parallel_size` should be a divisor of the total number of GPUs. For example: - With 8 GPUs, valid values would be 2, 4, or 8 - With 4 GPUs, valid values would be 2 or 4 @@ -66,7 +66,7 @@ sequence_len: 8192 ... -sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU +context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU # Optional; strides across the key dimension. Larger values use more memory but should make training faster. heads_k_stride: 1 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to @@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality. ## Effect on Batch Size -When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because: +When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. 
This happens because: -- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence) +- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence) - The number of batches processed per step decreases For example: - With 8 GPUs and no sequence parallelism: 8 different batches processed per step -- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs) +- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs) - If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4 diff --git a/examples/alst/llama3-8b-deepspeed-alst.yaml b/examples/alst/llama3-8b-deepspeed-alst.yaml index dc82fa3be..dea23c5ee 100644 --- a/examples/alst/llama3-8b-deepspeed-alst.yaml +++ b/examples/alst/llama3-8b-deepspeed-alst.yaml @@ -20,7 +20,7 @@ min_sample_len: 200_000 sample_packing: true tiled_mlp: true -sequence_parallel_degree: 8 +context_parallel_size: 8 plugins: - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin diff --git a/requirements.txt b/requirements.txt index ae433193f..4fc662a87 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,9 +13,9 @@ packaging==23.2 huggingface_hub>=0.33.0 peft==0.16.0 -transformers==4.54.0 +transformers==4.54.1 tokenizers>=0.21.1 -accelerate==1.9.0 +accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152 datasets==4.0.0 deepspeed>=0.17.0 trl==0.20.0 diff --git a/setup.py b/setup.py index 6576c44e5..de6f19e56 100644 --- a/setup.py +++ b/setup.py @@ -72,12 +72,13 @@ def parse_requirements(extras_require_map): extras_require_map.pop("vllm") else: _install_requires.append("xformers==0.0.31") + extras_require_map["vllm"] = ["vllm>=0.10.0"] elif (major, minor) >= (2, 6): _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers==0.0.29.post3") # since we only support 2.6.0+cu126 _dependency_links.append("https://download.pytorch.org/whl/cu126") - extras_require_map["vllm"] = ["vllm==0.8.5.post1"] + extras_require_map.pop("vllm") elif (major, minor) >= (2, 5): _install_requires.pop(_install_requires.index(xformers_version)) if patch == 0: diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 422593a48..31fad1b29 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -69,7 +69,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: load_in_8bit=False, load_in_4bit=False, flash_attention=False, - sequence_parallel_degree=None, + context_parallel_size=None, deepspeed=None, fsdp=None, fsdp_config=None, diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 0a37d2766..32b228e21 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -24,9 +24,11 @@ from pathlib import Path from typing import Any import torch +from accelerate import PartialState from transformers import ( TrainerCallback, ) +from transformers.trainer_pt_utils import AcceleratorConfig from transformers.training_args import OptimizerNames from axolotl.integrations.base import PluginManager @@ -434,8 +436,30 @@ class TrainerBuilderBase(abc.ABC): training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode def _configure_accelerator_config(self, training_args_kwargs: dict): + partial_state = PartialState() + has_pc_attr = ( 
+ hasattr(partial_state, "parallelism_config") + and partial_state.parallelism_config + ) + has_pc_key = ( + "parallelism_config" + in partial_state._shared_state # pylint: disable=protected-access + and partial_state._shared_state[ # pylint: disable=protected-access + "parallelism_config" + ] + ) + use_configured_state = has_pc_attr or has_pc_key if self.cfg.accelerator_config: - training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config + use_configured_state = self.cfg.accelerator_config.pop( + "use_configured_state", use_configured_state + ) + training_args_kwargs["accelerator_config"] = AcceleratorConfig( + use_configured_state=use_configured_state, **self.cfg.accelerator_config + ) + else: + training_args_kwargs["accelerator_config"] = AcceleratorConfig( + use_configured_state=use_configured_state, + ) def _configure_gradient_checkpointing(self, training_args_kwargs: dict): if self.cfg.activation_offloading is True: diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index e60b0e958..8cc6eeebf 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl is RLType.GRPO: trainer_cls = GRPOStrategy.get_trainer_class( - sequence_parallel=self.cfg.sequence_parallel_degree > 1 + sequence_parallel=self.cfg.context_parallel_size > 1 ) trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 3dfaf47ce..e3818ca7c 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -27,6 +27,7 @@ from typing_extensions import override from axolotl.core.trainers.mixins import ( ActivationOffloadingMixin, CheckpointSaveMixin, + DistributedParallelMixin, OptimizerMixin, PackingMixin, RngLoaderMixin, @@ -50,6 +51,7 @@ class AxolotlTrainer( RngLoaderMixin, CheckpointSaveMixin, ActivationOffloadingMixin, + DistributedParallelMixin, Trainer, ): """Extend the base Trainer for axolotl helpers""" diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index 762e0a331..b3067bb46 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -8,7 +8,11 @@ import torch from torch import nn from trl import DPOTrainer -from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin +from axolotl.core.trainers.mixins import ( + DistributedParallelMixin, + RngLoaderMixin, + SchedulerMixin, +) from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, @@ -17,7 +21,12 @@ from axolotl.core.trainers.utils import ( class AxolotlDPOTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DPOTrainer, + DistributedParallelMixin, ): """Extend the base DPOTrainer for axolotl helpers.""" diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 5f8e4a8b3..839c20c2e 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -82,14 +82,14 @@ class GRPOStrategy: grpo_args_kwargs["log_completions"] = trl.log_completions grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print + if cfg.context_parallel_size > 1: + grpo_args_kwargs["context_parallel_size"] = 
cfg.context_parallel_size + if trl.importance_sampling_level is not None: grpo_args_kwargs["importance_sampling_level"] = ( trl.importance_sampling_level ) - if cfg.sequence_parallel_degree > 1: - grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree - if trl.reward_weights: grpo_args_kwargs["reward_weights"] = trl.reward_weights diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py index 5c8b1a33b..2ea52998e 100644 --- a/src/axolotl/core/trainers/grpo/args.py +++ b/src/axolotl/core/trainers/grpo/args.py @@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): """Axolotl GRPO Config for GRPO training""" - sequence_parallel_degree: int | None = None + context_parallel_size: int | None = None diff --git a/src/axolotl/core/trainers/grpo/sampler.py b/src/axolotl/core/trainers/grpo/sampler.py index ebc6e19e2..df679a6d2 100644 --- a/src/axolotl/core/trainers/grpo/sampler.py +++ b/src/axolotl/core/trainers/grpo/sampler.py @@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler): - Data is properly distributed across SP groups. In the table below, the values represent dataset indices. Each SP group has - `sequence_parallel_degree = 2` GPUs working together on the same data. There are 2 + `context_parallel_size = 2` GPUs working together on the same data. There are 2 SP groups (SP0 and SP1), with `world_size = 4` total GPUs. Sequence Parallel Groups @@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler): rank: Rank of current process. batch_size: Number of samples per batch. repeat_count: How many times to repeat the full sampling process. - sequence_parallel_degree: Number of ranks in a sequence parallel group. + context_parallel_size: Number of ranks in a sequence parallel group. shuffle: Whether to shuffle the dataset. seed: Random seed for shuffling. drop_last: Whether to drop the last incomplete batch. 
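As a minimal standalone sketch of the rank-to-group mapping the sampler docstring above describes (assuming the same `world_size = 4`, `context_parallel_size = 2` layout; illustrative values, not part of the patch):

```python
# Illustrative only: ranks in the same sequence-parallel (SP) group receive
# identical dataset indices; this mirrors `rank // context_parallel_size`
# in SequenceParallelRepeatRandomSampler.
world_size = 4
context_parallel_size = 2

num_sp_groups = world_size // context_parallel_size  # 2 groups: SP0 and SP1
for rank in range(world_size):
    sp_group_id = rank // context_parallel_size
    print(f"rank {rank} -> SP group {sp_group_id}")
# rank 0 -> SP group 0
# rank 1 -> SP group 0
# rank 2 -> SP group 1
# rank 3 -> SP group 1
```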
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler): rank: int, batch_size: int = 1, repeat_count: int = 1, - sequence_parallel_degree: int = 1, + context_parallel_size: int = 1, shuffle: bool = True, seed: int = 0, drop_last: bool = False, @@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler): self.rank = rank # Sequence parallelism parameters - self.sequence_parallel_degree = sequence_parallel_degree - self.num_sp_groups = world_size // sequence_parallel_degree - self.sp_group_id = rank // sequence_parallel_degree + self.context_parallel_size = context_parallel_size + self.num_sp_groups = world_size // context_parallel_size + self.sp_group_id = rank // context_parallel_size # Adjust dataset size for distributed sampling self.num_samples = len(self.dataset) diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 70b3cf3b5..49caa6406 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -43,7 +43,11 @@ from trl.trainer.grpo_trainer import RewardFunc, nanstd from trl.trainer.utils import pad from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler -from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin +from axolotl.core.trainers.mixins import ( + DistributedParallelMixin, + RngLoaderMixin, + SchedulerMixin, +) from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin from axolotl.monkeypatch.ring_attn import get_ring_attn_group @@ -53,7 +57,12 @@ if is_peft_available(): class AxolotlGRPOTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + GRPOTrainer, ): """Extend the base GRPOTrainer for axolotl helpers""" @@ -100,7 +109,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # Get number of SP groups (number of processes divided by SP degree) num_processes = self.accelerator.num_processes - num_sp_groups = num_processes // self.args.sequence_parallel_degree + num_sp_groups = num_processes // self.args.context_parallel_size # Calculate batch size per SP group (not per process) sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups @@ -130,7 +139,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): if self.num_generations not in possible_values: raise ValueError( - f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), " + f"With sequence parallelism (degree {self.args.context_parallel_size}), " f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) " f"must be evenly divisible by the number of generations per prompt " f"({self.num_generations}). 
Given the current eval batch size, " @@ -167,9 +176,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): rank=self.rank, batch_size=effective_batch_size // self.num_generations - // self.args.sequence_parallel_degree, + // self.args.context_parallel_size, repeat_count=self.num_iterations * self.args.gradient_accumulation_steps, - sequence_parallel_degree=self.args.sequence_parallel_degree, + context_parallel_size=self.args.context_parallel_size, shuffle=True, seed=self.args.seed, drop_last=True, @@ -235,7 +244,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., # slice each batch along the sequence dimension). - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: return dataloader # Otherwise prepare with accelerator @@ -308,18 +317,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # Generate completions using vLLM: gather all prompts and use them in a single call in the main process all_prompts_text = gather_object(prompts_text) if self.accelerator.is_main_process: - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: # Calculate sequence parallel group information world_size = self.accelerator.num_processes - sequence_parallel_degree = self.args.sequence_parallel_degree - num_sp_groups = world_size // sequence_parallel_degree + context_parallel_size = self.args.context_parallel_size + num_sp_groups = world_size // context_parallel_size # Since processes in the same SP group have the same prompts, we need to ensure # we only take one copy of each prompt from each SP group ordered_set_of_prompts = [] for sp_group_id in range(num_sp_groups): # Get the first process from each SP group (typically the group leader) - group_leader_rank = sp_group_id * sequence_parallel_degree + group_leader_rank = sp_group_id * context_parallel_size # Extract prompts from this SP group, accounting for num_generations duplicates # We only need prompts from one rank in each SP group @@ -335,7 +344,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # num_generations outputs for each one. This is faster than generating outputs for each duplicate # prompt individually. 
ordered_set_of_prompts = all_prompts_text[ - :: self.num_generations * self.args.sequence_parallel_degree + :: self.num_generations * self.args.context_parallel_size ] with profiling_context(self, "vLLM.generate"): @@ -352,14 +361,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): ) else: completion_ids = [None] * ( - len(all_prompts_text) // self.args.sequence_parallel_degree + len(all_prompts_text) // self.args.context_parallel_size ) # Broadcast the completions from the main process to all processes completion_ids = broadcast_object_list(completion_ids, from_process=0) # Determine the appropriate slice based on sequence parallelism - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: # Calculate SP group ID (which group of ranks this rank belongs to) sp_group_id = self.accelerator.process_index // self.local_world_size @@ -583,7 +592,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): advantages = advantages / (std_grouped_rewards + 1e-4) # Slice to keep only the local part of the data - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: # Calculate SP group ID (which group of ranks this rank belongs to) sp_group_id = self.accelerator.process_index // self.local_world_size diff --git a/src/axolotl/core/trainers/mamba.py b/src/axolotl/core/trainers/mamba.py index 38792e389..b475b26d9 100644 --- a/src/axolotl/core/trainers/mamba.py +++ b/src/axolotl/core/trainers/mamba.py @@ -5,6 +5,7 @@ import torch from axolotl.core.trainers.base import AxolotlTrainer +# pylint: disable=too-many-ancestors class AxolotlMambaTrainer(AxolotlTrainer): """Mamba specific trainer to handle loss calculation""" diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py index 453810aac..b54577765 100644 --- a/src/axolotl/core/trainers/mixins/__init__.py +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -5,6 +5,7 @@ from .activation_checkpointing import ActivationOffloadingMixin from .checkpoints import CheckpointSaveMixin +from .distributed_parallel import DistributedParallelMixin from .optimizer import OptimizerMixin from .packing import PackingMixin from .rng_state_loader import RngLoaderMixin diff --git a/src/axolotl/core/trainers/mixins/checkpoints.py b/src/axolotl/core/trainers/mixins/checkpoints.py index 8f994d78c..4042ef9f1 100644 --- a/src/axolotl/core/trainers/mixins/checkpoints.py +++ b/src/axolotl/core/trainers/mixins/checkpoints.py @@ -13,9 +13,11 @@ class CheckpointSaveMixin(Trainer): def _save_optimizer_and_scheduler(self, output_dir): try: super()._save_optimizer_and_scheduler(output_dir) - except NotImplementedError as exc: - LOG.warning( + except (NotImplementedError, KeyError) as exc: + # TODO: fix fsdp2 optimizer saving + LOG.warning_once( f"Trainer does not support saving optimizer and scheduler: {exc}\n" "Optimizer and scheduler states were not saved - resuming from checkpoints " - "for this training run will not be possible." 
+ "for this training run will not be possible.", + main_process_only=True, ) diff --git a/src/axolotl/core/trainers/mixins/distributed_parallel.py b/src/axolotl/core/trainers/mixins/distributed_parallel.py new file mode 100644 index 000000000..d0f0f53df --- /dev/null +++ b/src/axolotl/core/trainers/mixins/distributed_parallel.py @@ -0,0 +1,20 @@ +""" +Mixin for correctly saving fsdp +""" + +from transformers import Trainer + + +class DistributedParallelMixin(Trainer): + """ + Mixin for correctly saving fsdp + """ + + def _save(self, output_dir: str | None = None, state_dict=None): + if ( + state_dict is None + and self.accelerator.parallelism_config + and self.accelerator.parallelism_config.dp_shard_enabled + ): + state_dict = self.accelerator.get_state_dict(self.model) + super()._save(output_dir, state_dict=state_dict) diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py index cb97f37d7..c5f19a6fe 100644 --- a/src/axolotl/core/trainers/trl.py +++ b/src/axolotl/core/trainers/trl.py @@ -8,13 +8,18 @@ from trl import ( RewardTrainer, ) -from axolotl.core.trainers.mixins import RngLoaderMixin +from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin from axolotl.core.trainers.mixins.scheduler import SchedulerMixin class AxolotlORPOTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + ORPOTrainer, ): """ Extend the base ORPOTrainer for axolotl helpers @@ -24,7 +29,12 @@ class AxolotlORPOTrainer( class AxolotlKTOTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + KTOTrainer, ): """ Extend the base KTOTrainer for axolotl helpers @@ -34,7 +44,12 @@ class AxolotlKTOTrainer( class AxolotlCPOTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + CPOTrainer, ): """ Extend the base CPOTrainer for axolotl helpers @@ -44,7 +59,12 @@ class AxolotlCPOTrainer( class AxolotlRewardTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + RewardTrainer, ): """ Extend the base RewardTrainer for axolotl helpers @@ -54,7 +74,12 @@ class AxolotlRewardTrainer( class AxolotlPRMTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + PRMTrainer, ): """ Extend the base trl.PRMTrainer for axolotl helpers diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 7ec43333a..c454b2a2c 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -21,6 +21,7 @@ from axolotl.core.trainers.base import AxolotlTrainer from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss +# pylint: disable=too-many-ancestors class AxolotlKDTrainer(AxolotlTrainer): """ Custom trainer subclass for Knowledge Distillation (KD) diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py index 
0460bdbf5..d5bb10cfd 100644 --- a/src/axolotl/integrations/liger/args.py +++ b/src/axolotl/integrations/liger/args.py @@ -16,8 +16,6 @@ Module for handling LIGER input arguments. """ -from typing import Optional - from pydantic import BaseModel, model_validator from axolotl.utils.logging import get_logger @@ -30,13 +28,13 @@ class LigerArgs(BaseModel): Input args for LIGER. """ - liger_rope: Optional[bool] = None - liger_rms_norm: Optional[bool] = None - liger_layer_norm: Optional[bool] = None - liger_swiglu: Optional[bool] = None - liger_glu_activation: Optional[bool] = None - liger_cross_entropy: Optional[bool] = None - liger_fused_linear_cross_entropy: Optional[bool] = None + liger_rope: bool | None = None + liger_rms_norm: bool | None = None + liger_layer_norm: bool | None = None + liger_swiglu: bool | None = None + liger_glu_activation: bool | None = None + liger_cross_entropy: bool | None = None + liger_fused_linear_cross_entropy: bool | None = None @model_validator(mode="before") @classmethod @@ -66,3 +64,20 @@ class LigerArgs(BaseModel): "You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`." ) return data + + @model_validator(mode="before") + @classmethod + def check_liger_rms_norm_tensor_parallel(cls, data): + if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1: + raise ValueError( + "`liger_rms_norm` is incompatible with tensor parallelism, " + "see https://github.com/linkedin/Liger-Kernel/issues/826" + ) + return data + + @model_validator(mode="after") + def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self): + # TODO @SalmanMohammadi this is a larger fix - investigate + if self.tensor_parallel_size > 1 and self.liger_fused_linear_cross_entropy: + raise ValueError("Tensor parallelism is not compatible with liger losses.") + return self diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 4fc005457..05039c9ee 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -13,7 +13,8 @@ import peft import torch import transformers import transformers.modeling_utils -from accelerate import init_empty_weights +from accelerate import PartialState, init_empty_weights +from accelerate.parallelism_config import ParallelismConfig from peft import ( PeftConfig, PeftMixedModel, @@ -48,10 +49,7 @@ from axolotl.loaders.utils import ( from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import ( - get_device_count, - get_device_type, -) +from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size from axolotl.utils.logging import get_logger from axolotl.utils.model_shard_quant import load_sharded_model_quant from axolotl.utils.schemas.enums import RLType @@ -87,6 +85,9 @@ class ModelLoader: `AutoModelForCausalLM`). 
""" + use_parallel_config: bool | None = False + parallelism_config: ParallelismConfig | None = None + def __init__( self, cfg: DictDefault, @@ -183,6 +184,20 @@ class ModelLoader: def _apply_pre_model_load_setup(self): """Apply patches and setup configurations before model loading.""" + if self.use_parallel_config is not None: + self.use_parallel_config = ( + self.cfg.fsdp_config + or (self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1) + or ( + self.cfg.context_parallel_size + and self.cfg.context_parallel_size > 1 + ) + ) + if self.cfg.fsdp_config and self.cfg.fsdp_version != 2: + self.use_parallel_config = False + + if self.use_parallel_config: + self._set_parallel_config() self._set_auto_model_loader() self._set_device_map_config() if self.cfg.revision_of_model: @@ -390,6 +405,86 @@ class ModelLoader: gc.collect() torch.cuda.empty_cache() + @staticmethod + def _get_parallel_config_kwargs( + world_size: int, + tensor_parallel_size: int = 1, + context_parallel_size: int = 1, + dp_shard_size: int | None = None, + dp_replicate_size: int | None = None, + is_fsdp: bool = False, + ): + pc_kwargs = {} + remaining_world_size = world_size + + if tensor_parallel_size and tensor_parallel_size > 1: + pc_kwargs["tp_size"] = tensor_parallel_size + remaining_world_size = remaining_world_size // tensor_parallel_size + + if context_parallel_size and context_parallel_size > 1: + pc_kwargs["cp_size"] = context_parallel_size + remaining_world_size = remaining_world_size // context_parallel_size + + if dp_shard_size is None and dp_replicate_size in (None, 1): + if remaining_world_size > 1: + pc_kwargs["dp_shard_size"] = remaining_world_size + remaining_world_size = 1 + + if dp_replicate_size and dp_replicate_size > 1: + pc_kwargs["dp_replicate_size"] = dp_replicate_size + remaining_world_size = remaining_world_size // dp_replicate_size + + if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1: + if not is_fsdp: + raise ValueError( + "dp_shard_size was configured without a corresponding fsdp_config! " + "Please ensure you have configured FSDP using fsdp_config." 
+ ) + pc_kwargs["dp_shard_size"] = dp_shard_size + remaining_world_size = remaining_world_size // dp_shard_size + if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs: + pc_kwargs["dp_replicate_size"] = remaining_world_size + remaining_world_size = 1 + + if remaining_world_size > 1: + if "dp_shard_size" not in pc_kwargs and is_fsdp: + pc_kwargs["dp_shard_size"] = remaining_world_size + remaining_world_size = 1 + + if remaining_world_size > 1: + raise ValueError( + f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n" + f"{pc_kwargs}" + ) + + return pc_kwargs + + def _set_parallel_config(self): + """Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator""" + pc_kwargs = ModelLoader._get_parallel_config_kwargs( + get_world_size(), + self.cfg.tensor_parallel_size, + self.cfg.context_parallel_size, + self.cfg.dp_shard_size, + self.cfg.dp_replicate_size, + bool(self.cfg.fsdp or self.cfg.fsdp_config), + ) + + if pc_kwargs: + self.parallelism_config = ParallelismConfig( + **pc_kwargs, + ) + device_mesh = self.parallelism_config.build_device_mesh("cuda") + partial_state = PartialState() + # fmt: off + partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access + self.parallelism_config + ) + partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access + device_mesh + ) + # fmt: on + def _set_auto_model_loader(self): """Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM` (set at `__init__`). When using a multimodal model, `self.auto_model_loader` @@ -622,6 +717,14 @@ class ModelLoader: def _build_model(self) -> bool: """Load model, with load strategy depending on config.""" skip_move_to_device = False + + if self.cfg.tensor_parallel_size > 1: + self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size + self.model_kwargs["tp_plan"] = "auto" + self.model_kwargs["device_mesh"] = PartialState().device_mesh + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] # not compatible with `tp_plan` + if self.is_fsdp_enabled: if self.cfg.fsdp_config.cpu_ram_efficient_loading: skip_move_to_device = True @@ -734,6 +837,14 @@ class ModelLoader: if is_deepspeed_zero3_enabled(): skip_move_to_device = True + # pylint: disable=protected-access + if self.cfg.tensor_parallel_size > 1: + # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh + # TODO(wing): remove once 4.54.1 is released + if self.model._tp_size != self.cfg.tensor_parallel_size: + self.model._tp_size = self.cfg.tensor_parallel_size + self.model._device_mesh = self.model_kwargs["device_mesh"] + return skip_move_to_device def _set_z3_leaf_modules(self): diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 186681521..9eb779113 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -49,6 +49,7 @@ class PatchManager: def apply_pre_model_load_patches(self): """Apply pre-model load patches based on config.""" + self._apply_transformers_patches() # self._apply_flex_attention_patches() self._apply_flash_attention_patches() self._apply_chunked_cross_entropy_patch() @@ -64,13 +65,19 @@ class PatchManager: self._patch_llama_derived_model() self._apply_mistral_cross_entropy_patch() self._apply_self_attention_lora_patch() - self._apply_sequence_parallel_patches() def apply_post_plugin_pre_model_load_patches(self): """Apply post plugin-pre_model_load load patches based on config.""" 
        self._apply_tiled_mlp(self.cfg.model_config_type)
         self._apply_voxtral_patches()
 
+    def _apply_transformers_patches(self):
+        from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
+            patch_prepare_from_posids,
+        )
+
+        patch_prepare_from_posids()
+
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
         self._apply_llama_flash_attn_patches(model)
@@ -253,17 +260,6 @@ class PatchManager:
             has_remote_code=has_remote_code,
         )
 
-    def _apply_sequence_parallel_patches(self):
-        """Apply sequence parallelism patches."""
-        if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
-            from axolotl.monkeypatch.ring_attn.patch import (
-                patch_prepare_data_loader,
-                patch_prepare_device_mesh,
-            )
-
-            patch_prepare_data_loader()
-            patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
-
     def _apply_tiled_mlp(self, model_type: str):
         if self.cfg.tiled_mlp:
             from axolotl.monkeypatch.tiled_mlp import (
diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index 803659232..af262d18f 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -249,13 +249,19 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
             auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
         )
 
+    mesh = getattr(accelerator.state, "device_mesh", None)
+
     fsdp2_kwargs = {
         "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
         "offload_policy": fsdp2_plugin.cpu_offload,
         # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
         "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
+        "mesh": (
+            mesh[tuple(accelerator.state.parallelism_config.fsdp_dim_names)]
+            if mesh is not None
+            else None
+        ),
     }
-
     model_has_params4bit = False
     for _, param in model.named_parameters():
         # this is a temporary fix whereby loading models with bnb params cannot be moved from
diff --git a/src/axolotl/monkeypatch/ring_attn/__init__.py b/src/axolotl/monkeypatch/ring_attn/__init__.py
index 5833b9ce4..736378b16 100644
--- a/src/axolotl/monkeypatch/ring_attn/__init__.py
+++ b/src/axolotl/monkeypatch/ring_attn/__init__.py
@@ -5,18 +5,14 @@
 from .patch import (
     get_ring_attn_group,
-    patch_prepare_data_loader,
-    patch_prepare_device_mesh,
-    register_ring_attn,
+    register_ring_attn_from_device_mesh,
     set_ring_attn_group,
     update_ring_attn_params,
 )
 
 __all__ = (
     "get_ring_attn_group",
-    "patch_prepare_data_loader",
-    "patch_prepare_device_mesh",
-    "register_ring_attn",
+    "register_ring_attn_from_device_mesh",
     "set_ring_attn_group",
     "update_ring_attn_params",
 )
diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py
index 9c9ba4553..934687a16 100644
--- a/src/axolotl/monkeypatch/ring_attn/patch.py
+++ b/src/axolotl/monkeypatch/ring_attn/patch.py
@@ -8,13 +8,12 @@
 We also provide some patches for accelerate functions to prepare the dataloader for
 sequence parallelism training.
""" -import inspect import os from typing import Optional -import accelerate import torch import torch.distributed as dist +from torch.distributed import DeviceMesh try: from transformers.modeling_flash_attention_utils import _flash_supports_window @@ -29,39 +28,13 @@ from axolotl.utils.schemas.enums import RingAttnFunc LOG = get_logger(__name__) - RING_ATTN_GROUP = None -ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1 - submesh_dp_size = 1 - submesh_tp_size = 1 - if "tp" in torch_device_mesh.mesh_dim_names: - submesh_tp_size = torch_device_mesh["tp"].size() - if "dp" in torch_device_mesh.mesh_dim_names: - submesh_dp_size = torch_device_mesh["dp"].size() - if "fsdp" in torch_device_mesh.mesh_dim_names: - submesh_fsdp_size = torch_device_mesh["fsdp"].size() - process_index = process_index // submesh_tp_size""" - -NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1 - submesh_dp_size = 1 - submesh_tp_size = 1 - submesh_cp_size = 1 - if "cp" in torch_device_mesh.mesh_dim_names: - submesh_cp_size = torch_device_mesh["cp"].size() - if "tp" in torch_device_mesh.mesh_dim_names: - submesh_tp_size = torch_device_mesh["tp"].size() - if "dp" in torch_device_mesh.mesh_dim_names: - submesh_dp_size = torch_device_mesh["dp"].size() - if "fsdp" in torch_device_mesh.mesh_dim_names: - submesh_fsdp_size = torch_device_mesh["fsdp"].size() - process_index = process_index // (submesh_tp_size * submesh_cp_size)""" - def get_ring_attn_group() -> dist.ProcessGroup: """Getter for ring attention group on this rank.""" if RING_ATTN_GROUP is None: - raise RuntimeError("register_ring_attn() not yet called") + raise RuntimeError("register_ring_attn_from_device_mesh() not yet called") return RING_ATTN_GROUP @@ -161,15 +134,17 @@ def create_ring_flash_attention_forward( ] -def register_ring_attn( - sequence_parallel_degree: int, +def register_ring_attn_from_device_mesh( + device_mesh: "DeviceMesh", + context_parallel_dim: tuple[str, ...], heads_k_stride: int | None, ring_attn_func: RingAttnFunc | None, ): - """Create ring attention group and substitute flash attn with ring flash attn. + """Create ring attention group using DeviceMesh and substitute flash attn with ring flash attn. Args: - sequence_parallel_degree: Sequence parallelism factor. + device_mesh: DeviceMesh object containing the parallelism topology. + context_parallel_dim: Name of the sequence parallel dimension in the device mesh. heads_k_stride: Sequence parallelism K head stride size. Passed through to `varlen_llama3` `ring_flash_attn` implementation. ring_attn_func: `ring_flash_attn` ring attention implemention. If sample @@ -177,44 +152,39 @@ def register_ring_attn( `batch` function. """ rank = dist.get_rank() - world_size = dist.get_world_size() + + LOG.info( + f"Enabling ring attention sequence parallelism using DeviceMesh " + f"dimension '{context_parallel_dim}'", + main_process_only=True, + ) + + # Extract the sequence parallel submesh + try: + sequence_mesh = device_mesh[context_parallel_dim] + except (KeyError, IndexError) as e: + raise ValueError( + f"Dimension '{context_parallel_dim}' not found in device_mesh. 
" + f"Available dimensions: {device_mesh.mesh_dim_names}" + ) from e + + # Get the process group for context parallelism + sequence_pg = sequence_mesh.get_group() + context_parallel_size = sequence_mesh.size() if rank == 0: LOG.info( - "Enabling ring attention sequence parallelism: " - f"each sequence will be processed across {sequence_parallel_degree} GPUs" + f"Sequence parallel degree: {context_parallel_size}, " + f"mesh shape: {sequence_mesh.mesh.shape}" ) - assert sequence_parallel_degree <= world_size, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " - f"must be less than or equal to world_size ({world_size})" - ) - assert world_size % sequence_parallel_degree == 0, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " - f"must evenly divide world_size ({world_size})" - ) + # Log which ranks are in the current process group + if sequence_pg != dist.GroupMember.WORLD: + ranks_in_group = dist.get_process_group_ranks(sequence_pg) + LOG.info(f"Current sequence parallel group ranks: {ranks_in_group}") - # Assign ranks to sequence parallel groups - group_assignments = {} - for i in range(world_size // sequence_parallel_degree): - ring_attn_ranks = list( - range( - i * sequence_parallel_degree, - (i + 1) * sequence_parallel_degree, - ) - ) - group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") - - # Track which GPUs are in which groups - for r in ring_attn_ranks: - group_assignments[r] = i - - if rank in ring_attn_ranks: - set_ring_attn_group(group) - - # Log the GPU group assignments - if rank == 0: - LOG.info(f"Sequence parallel group assignments: {group_assignments}") + # Set the ring attention group + set_ring_attn_group(sequence_pg) if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3: # fmt: off @@ -257,92 +227,3 @@ def update_ring_attn_params(position_ids: torch.Tensor | None): cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids) cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device()) update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) - - -def patch_prepare_data_loader(): - """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree. - - Raises: - RuntimeError: If source code to patch does not exist. - """ - original_fn = accelerate.data_loader.prepare_data_loader - original_source = inspect.getsource(original_fn) - - if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source: - raise RuntimeError( - "SP patch failed - target snippet not found. " - "Check accelerate's version or update the patch." - ) - - patched_source = original_source.replace( - ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE - ) - - items_to_import = [] - for item in dir(accelerate.data_loader): - if item in patched_source: - items_to_import.append(item) - - # Create a new function from the patched source - namespace = {} - exec( # pylint: disable=exec-used # nosec B102 - f"from accelerate.data_loader import ({', '.join(items_to_import)})", - globals(), - ) - exec( # pylint: disable=exec-used # nosec B102 - patched_source, globals(), namespace - ) - - patched_function = namespace["prepare_data_loader"] - original_fn.__code__ = patched_function.__code__ - - LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support") - - -def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False): - """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh - that includes sequence parallelism with the specified degree. 
- - Args: - sequence_parallel_degree: The degree of sequence parallelism to use. - fsdp: Whether to use FSDP. - """ - - def _prepare_device_mesh(self): - """Prepare the device mesh for distributed training. The dataloader will - determine how to load data based on the device mesh. - """ - if self.state.torch_tp_plugin: - return self.state.torch_tp_plugin.torch_device_mesh - if ( - self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED - and hasattr(self.state, "ds_device_mesh") - ): - return self.state.ds_device_mesh - - # Create device mesh with sequence parallelism - world_size = dist.get_world_size() - mesh_shape = ( - world_size // sequence_parallel_degree, - sequence_parallel_degree, - ) - device_ids = list(range(world_size)) - - # NOTE: We use "cp" instead of "sp" to match the PyTorch native "context - # parallelism" implementation naming. - # NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we - # only use "fsdp" and "cp" for the device mesh. - return dist.DeviceMesh( - "cuda", - torch.tensor(device_ids).reshape(mesh_shape), - mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"), - ) - - # Replace the original method with our new method - # pylint: disable=protected-access - accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh - - LOG.info( - "Successfully patched Accelerator._prepare_device_mesh " - f"with sequence_parallel_degree={sequence_parallel_degree}" - ) diff --git a/src/axolotl/monkeypatch/transformers/__init__.py b/src/axolotl/monkeypatch/transformers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py b/src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py new file mode 100644 index 000000000..1bd8ac6bc --- /dev/null +++ b/src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py @@ -0,0 +1,87 @@ +""" +Monkey patch to fix transformers.modeling_flash_attention_utils. + +see https://github.com/huggingface/transformers/pull/39653/files +""" + +import sys + +import torch + + +def _prepare_from_posids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cumulative lengths of each examples in the batch will be extracted from position_ids. + NOTE: ideally cumulative lengths should be prepared at the data collator stage + Arguments: + query (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + Return: + query (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. 
+ (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query = query.contiguous().view(-1, query.size(-2), query.size(-1)) + key = key.contiguous().view(-1, key.size(-2), key.size(-1)) + value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + + position_ids = position_ids.flatten() + indices_q = torch.arange( + position_ids.size(0), device=position_ids.device, dtype=torch.int32 + ) + + cu_seq_lens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor( + position_ids.size(), device=position_ids.device, dtype=torch.int32 + ), + ) + ) + # NOTE: With torch compile, this will cause a graph break if you don't set + # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call + # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. + # This is a limitation of flash attention API, as the function `flash_attn_varlen_func` + # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`. + # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 + # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing + # for some models (e.g. qwen2-vl). + max_length = cu_seq_lens.diff().max().item() + return ( + query, + key, + value, + indices_q, + (cu_seq_lens, cu_seq_lens), + (max_length, max_length), + ) + + +def patch_prepare_from_posids(): + import transformers.modeling_flash_attention_utils + + transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access + _prepare_from_posids + ) + setattr( + sys.modules["transformers.modeling_flash_attention_utils"], + "_prepare_from_posids", + _prepare_from_posids, + ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index b507c2c7b..41f184abc 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -205,7 +205,7 @@ def execute_training( ) ) - if cfg.sequence_parallel_degree > 1: + if cfg.context_parallel_size > 1: models = [trainer.model] if hasattr(trainer, "ref_model") and trainer.ref_model: models.append(trainer.ref_model) @@ -213,7 +213,7 @@ def execute_training( stack.enter_context( SequenceParallelContextManager( models=models, - sequence_parallel_degree=cfg.sequence_parallel_degree, + context_parallel_size=cfg.context_parallel_size, gradient_accumulation_steps=cfg.gradient_accumulation_steps, ring_attn_func=cfg.ring_attn_func, heads_k_stride=cfg.heads_k_stride, diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py index dae53eddf..dd3a85b8c 100644 --- a/src/axolotl/utils/bench.py +++ b/src/axolotl/utils/bench.py @@ -57,10 +57,10 @@ def gpu_memory_usage(device=0): @check_cuda_device((0.0, 0.0, 0.0)) def gpu_memory_usage_all(device=0): - usage = torch.cuda.memory_allocated(device) / 1024.0**3 - reserved = torch.cuda.memory_reserved(device) / 1024.0**3 - smi = gpu_memory_usage_smi(device) - return usage, reserved - usage, max(0, smi - reserved) + active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3 + allocated = torch.cuda.max_memory_allocated(device) / 
1024.0**3 + reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3 + return active, allocated, reserved def mps_memory_usage_all(): @@ -92,27 +92,38 @@ def gpu_memory_usage_smi(device=0): return 0.0 -def log_gpu_memory_usage( - log: logging.Logger | logging.LoggerAdapter, - msg: str = "", - device: int | torch.device = 0, -): +def get_gpu_memory_usage(device: int | torch.device = 0): cur_device_type = str(get_device_type()) if torch.backends.mps.is_available(): usage, cache, misc = mps_memory_usage_all() elif "npu" in cur_device_type and is_torch_npu_available(): usage, cache, misc = npu_memory_usage_all(device) - elif "gpu" in cur_device_type and torch.cuda.is_available(): + elif "cuda" in cur_device_type and torch.cuda.is_available(): usage, cache, misc = gpu_memory_usage_all(device) else: + return 0.0, 0.0, 0.0 + + return usage, cache, misc + + +def log_gpu_memory_usage( + log: logging.Logger | logging.LoggerAdapter, + msg: str = "", + device: int | torch.device = 0, +): + try: + active, allocated, reserved = get_gpu_memory_usage(device) + except ValueError: + # likely CPU, ignore return + cur_device_type = str(get_device_type()) extras = [] - if cache > 0: - extras.append(f"+{cache:.03f}GB cache") - if misc > 0: - extras.append(f"+{misc:.03f}GB misc") - msg = f"{cur_device_type} memory usage:" if not msg else msg - log.info( - f"{msg} {usage:.03f}GB ({', '.join(extras)})", + if allocated > 0: + extras.append(f"+{allocated:.03f}GB allocated") + if reserved > 0: + extras.append(f"+{reserved:.03f}GB reserved") + msg = f"{cur_device_type} memory active:" if not msg else msg + log.debug( + f"{msg} {active:.03f}GB ({', '.join(extras)})", stacklevel=2, ) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index c64d8d351..63799c734 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -35,7 +35,7 @@ from transformers.trainer_utils import ( from trl.models import unwrap_model_for_generation from axolotl.utils import is_comet_available, is_mlflow_available -from axolotl.utils.bench import log_gpu_memory_usage +from axolotl.utils.bench import get_gpu_memory_usage, log_gpu_memory_usage from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.distributed import ( barrier, @@ -100,7 +100,6 @@ class GPUStatsCallback( def __init__(self, cfg): self.cfg = cfg - self.logged = False def on_step_end( self, @@ -109,9 +108,21 @@ class GPUStatsCallback( control: TrainerControl, **kwargs, ) -> TrainerControl: - if not self.logged and state.global_step > 1: - log_gpu_memory_usage(LOG, "while training", self.cfg.device) - self.logged = True + if state.global_step > 0: + if self.cfg.use_wandb and state.is_world_process_zero: + try: + active, allocated, reserved = get_gpu_memory_usage() + wandb.log( + { + "memory/max_memory_active": active, + "memory/max_memory_allocated": allocated, + "memory/device_memory_reserved": reserved, + }, + step=state.global_step, + ) + except ValueError: + pass + log_gpu_memory_usage(LOG, "", self.cfg.device) return control diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 1ac805a73..949c76f49 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -5,6 +5,7 @@ import inspect import torch import torch.distributed as dist +from accelerate import PartialState from torch import nn from torch.utils.hooks import RemovableHandle 
from transformers.modeling_outputs import CausalLMOutputWithPast @@ -12,7 +13,7 @@ from transformers.utils import ModelOutput from axolotl.monkeypatch.ring_attn import ( get_ring_attn_group, - register_ring_attn, + register_ring_attn_from_device_mesh, update_ring_attn_params, ) from axolotl.utils.schemas.enums import RingAttnFunc @@ -150,9 +151,18 @@ def apply_sequence_parallelism( if "num_items_in_batch" in batch: # Approximation; this needed since num_items_in_batch may be counted across # all samples in a gradient accumulated batch, not on a per-step basis. + local_valid_tokens = (batch["labels"] != -100).sum() + + # All-reduce across sequence parallel ranks to get global token count + cp_group = get_ring_attn_group() + global_valid_tokens = local_valid_tokens.clone() + # we use AVG instead of SUM as using sum seems to scale down the loss by over-accounting the number of tokens + dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.AVG, group=cp_group) + global_valid_tokens = int(global_valid_tokens.item()) + batch["num_items_in_batch"] = ( - batch["labels"] != -100 - ).sum() * gradient_accumulation_steps + global_valid_tokens * gradient_accumulation_steps + ) return batch, original_seq_len, pad_len @@ -167,7 +177,7 @@ class SequenceParallelContextManager: Args: models: List of models to apply sequence parallelism to pre- and post- forward hooks. - sequence_parallel_degree: Number of processes to split sequences over. + context_parallel_size: Number of processes to split sequences over. gradient_accumulation_steps: Number of steps to accumulate gradients over. ring_attn_func: Which ring attention function to use. Currently unused. heads_k_stride: Sequence parallelism K head stride size. Passed through to @@ -179,14 +189,14 @@ class SequenceParallelContextManager: def __init__( self, models: list[nn.Module], - sequence_parallel_degree: int, + context_parallel_size: int, gradient_accumulation_steps: int, ring_attn_func: RingAttnFunc, heads_k_stride: int | None, gather_outputs: bool, ): self.models = models - self.sequence_parallel_degree = sequence_parallel_degree + self.context_parallel_size = context_parallel_size self.gradient_accumulation_steps = gradient_accumulation_steps self.ring_attn_func = ring_attn_func self.heads_k_stride = heads_k_stride @@ -230,8 +240,10 @@ class SequenceParallelContextManager: def _register_ring_attn(self): # Initialize ring attn for sequence parallelism - register_ring_attn( - sequence_parallel_degree=self.sequence_parallel_degree, + partial_state = PartialState() + register_ring_attn_from_device_mesh( + device_mesh=partial_state.device_mesh, + context_parallel_dim=("cp",), heads_k_stride=self.heads_k_stride, ring_attn_func=self.ring_attn_func, ) diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index 7877e5abf..21c8e472b 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -430,10 +430,11 @@ def save_preprocessed_dataset( num_shards=cfg.num_dataset_shards_to_save, ) else: + min_rows_per_proc = 256 os.makedirs(prepared_ds_path, exist_ok=True) dataset.save_to_disk( str(prepared_ds_path), - num_proc=min(max(1, len(dataset) // 8), num_workers), + num_proc=min(max(1, len(dataset) // min_rows_per_proc), num_workers), max_shard_size=None, num_shards=cfg.num_dataset_shards_to_save, ) diff --git a/src/axolotl/utils/environment.py b/src/axolotl/utils/environment.py index 1cc609a68..3c83c87cb 100644 --- a/src/axolotl/utils/environment.py +++ b/src/axolotl/utils/environment.py @@ -2,12 +2,15 @@ utils 
to get GPU info for the current environment """ +from importlib.metadata import version + from accelerate.utils.environment import ( check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support, ) from accelerate.utils.environment import ( get_gpu_info, ) +from packaging.version import Version, parse def check_cuda_p2p_ib_support(): @@ -26,3 +29,13 @@ def check_cuda_p2p_ib_support(): except Exception: # pylint: disable=broad-except # nosec pass return True + + +def get_package_version(package: str) -> Version: + version_str = version(package) + return parse(version_str) + + +def is_package_version_ge(package: str, version_: str) -> bool: + package_version = get_package_version(package) + return package_version >= parse(version_) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index ee8640f41..af62c0a4f 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput. import gc import math +import time from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count, get_context from typing import Iterable, Iterator, Union @@ -453,7 +454,10 @@ class MultipackBatchSampler(BatchSampler): _sampled_lens = [] for _ in range(self.num_count_samples): self._batches = None # Reset cached batches + # log timer for generating batches + start_time = time.time() _sampled_lens.append(len(self.generate_batches(set_stats=False))) + LOG.debug(f"generate_batches time: {time.time() - start_time}") len_batches = min(_sampled_lens) # Gather minimum across all ranks diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index f8746692c..1d089ba41 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -651,7 +651,23 @@ class AxolotlInputConfig( }, ) + dp_shard_size: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of devices to shard across. If not set, will use all available devices." + }, + ) + dp_replicate_size: int | None = Field( + default=None, + json_schema_extra={"description": "Number of devices to replicate across."}, + ) sequence_parallel_degree: int | None = Field( + default=None, + json_schema_extra={ + "description": "Deprecated: use `context_parallel_size` instead" + }, + ) + context_parallel_size: int | None = Field( default=None, json_schema_extra={ "description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details." 
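A minimal sketch of the invariant behind these new schema fields, assuming the world-size factorization that `_get_parallel_config_kwargs` performs earlier in this diff (values are illustrative):

```python
# Illustrative only: the configured parallelism sizes must factor the world
# size exactly; any leftover factor becomes dp_shard_size when FSDP is set.
world_size = 8
tensor_parallel_size = 2
context_parallel_size = 2

dp_shard_size = world_size // (tensor_parallel_size * context_parallel_size)  # -> 2
assert tensor_parallel_size * context_parallel_size * dp_shard_size == world_size
```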
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index ee8640f41..af62c0a4f 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
 
 import gc
 import math
+import time
 from concurrent.futures import ProcessPoolExecutor
 from multiprocessing import cpu_count, get_context
 from typing import Iterable, Iterator, Union
@@ -453,7 +454,10 @@ class MultipackBatchSampler(BatchSampler):
             _sampled_lens = []
             for _ in range(self.num_count_samples):
                 self._batches = None  # Reset cached batches
+                # Time batch generation for the debug log below
+                start_time = time.time()
                 _sampled_lens.append(len(self.generate_batches(set_stats=False)))
+                LOG.debug(f"generate_batches time: {time.time() - start_time}")
             len_batches = min(_sampled_lens)
 
             # Gather minimum across all ranks
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index f8746692c..1d089ba41 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -651,7 +651,23 @@ class AxolotlInputConfig(
         },
     )
 
+    dp_shard_size: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of devices to shard across. If not set, will use all available devices."
+        },
+    )
+    dp_replicate_size: int | None = Field(
+        default=None,
+        json_schema_extra={"description": "Number of devices to replicate across."},
+    )
     sequence_parallel_degree: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Deprecated: use `context_parallel_size` instead"
+        },
+    )
+    context_parallel_size: int | None = Field(
         default=None,
         json_schema_extra={
             "description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 063690c59..502c18e7d 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -673,7 +673,7 @@ class RLValidationMixin:
             data.get("rl") == "grpo"
             and data.get("trl", {})
             and data.get("trl").get("use_liger_loss")
-            and data.get("sequence_parallel_degree", 1) > 1
+            and data.get("context_parallel_size", 1) > 1
         ):
             raise ValueError("GRPO + SP + Liger not currently supported")
         return data
@@ -880,51 +880,35 @@ class OptimizationValidationMixin:
 
         return self
 
-    @model_validator(mode="after")
-    def check_fsdp_sharded_state_dict_w_safetensors(self):
-        if (
-            hasattr(self, "fsdp_config")
-            and self.fsdp_config
-            and hasattr(self, "save_safetensors")
-            and self.save_safetensors
-            and self.fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
-            and str(getattr(self, "fsdp_version", "1")) != "2"
-        ):
-            raise ValueError(
-                "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
-            )
-        return self
-
     @model_validator(mode="before")
     @classmethod
     def check_tensor_parallel_size_update_ds_json(cls, data):
         tensor_parallel_size = data.get("tensor_parallel_size")
         if tensor_parallel_size is not None and tensor_parallel_size > 1:
-            if not data.get("deepspeed"):
-                raise ValueError(
-                    "Tensor parallelism (TP) is only supported with DeepSpeed"
-                )
-            with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
-                ds_config = json.load(ds_fin)
-            should_save = False
-            if "tensor_parallel" not in ds_config:
-                ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
-                should_save = True
-            if (
-                "gather_16bit_weights_on_model_save"
-                not in ds_config["zero_optimization"]
-            ):
-                ds_config["zero_optimization"][
+            if data.get("deepspeed"):
+                with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
+                    ds_config = json.load(ds_fin)
+                should_save = False
+                if "tensor_parallel" not in ds_config:
+                    ds_config["tensor_parallel"] = {
+                        "autotp_size": tensor_parallel_size
+                    }
+                    should_save = True
+                if (
                     "gather_16bit_weights_on_model_save"
-                ] = True
-                should_save = True
-            if should_save:
-                temp_dir = tempfile.mkdtemp()
-                with open(
-                    Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
-                ) as ds_fout:
-                    json.dump(ds_config, ds_fout, indent=4)
-                data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
+                    not in ds_config["zero_optimization"]
+                ):
+                    ds_config["zero_optimization"][
+                        "gather_16bit_weights_on_model_save"
+                    ] = True
+                    should_save = True
+                if should_save:
+                    temp_dir = tempfile.mkdtemp()
+                    with open(
+                        Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
+                    ) as ds_fout:
+                        json.dump(ds_config, ds_fout, indent=4)
+                    data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
 
         return data
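With the relaxed validator above, TP no longer hard-requires DeepSpeed; when a DeepSpeed JSON is supplied, it is rewritten to roughly the following shape before training. A sketch of the injected fields (the `autotp_size` value is illustrative):

```python
# What the validator adds to a user-provided DeepSpeed config when missing:
ds_config = {
    "zero_optimization": {
        # Ensure full 16-bit weights are gathered when the model is saved.
        "gather_16bit_weights_on_model_save": True,
    },
    "tensor_parallel": {"autotp_size": 2},  # mirrors tensor_parallel_size
}
```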
"flash_attention: true must be set with context_parallel_size > 1" ) if self.sample_packing and self.micro_batch_size > 1: @@ -1221,17 +1210,23 @@ class ComplexValidationMixin: ) try: + import transformers.modeling_flash_attention_utils + + # pylint: disable=protected-access + transformers.modeling_flash_attention_utils._flash_supports_window_size = ( + transformers.modeling_flash_attention_utils._flash_supports_window + ) import ring_flash_attn # noqa: F401 # pylint:disable=unused-import except ImportError as exception: raise ImportError( - "sequence_parallel_degree > 1 but ring_flash_attn is not installed. " + "context_parallel_size > 1 but ring_flash_attn is not installed. " "Please install it with `pip install axolotl[ring-flash-attn] " "or `pip install ring-flash-attn>=0.1.4`." ) from exception LOG.warning( "Sequence parallelism (SP) is enabled with " - f"sequence_parallel_degree={self.sequence_parallel_degree}. " + f"context_parallel_size={self.context_parallel_size}. " "Please note that logged losses may differ slightly to the non-SP " "losses due to transformers Trainer implementation details. " "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " @@ -1242,7 +1237,7 @@ class ComplexValidationMixin: @model_validator(mode="after") def validate_ring_attn_func(self): - if getattr(self, "sequence_parallel_degree", 1) == 1: + if getattr(self, "context_parallel_size", 1) == 1: return self if self.ring_attn_func is not None: @@ -1259,6 +1254,20 @@ class ComplexValidationMixin: return self +class DistributedValidationMixin: + """validation for distributed training.""" + + @model_validator(mode="after") + def check_tensor_parallel_optimizer(self): + if self.tensor_parallel_size > 1: + if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]: + raise ValueError( + "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers" + ) + + return self + + # pylint: disable=too-many-ancestors class ValidationMixin( DatasetValidationMixin, diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 8371b2dd7..90ae1a889 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): - 1 ) * cfg.num_epochs - * cfg.sequence_parallel_degree + * cfg.context_parallel_size * cfg.tensor_parallel_size ) LOG.debug( @@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): math.floor( data_loader_len * cfg.num_epochs - * cfg.sequence_parallel_degree + * cfg.context_parallel_size * cfg.tensor_parallel_size ) ) @@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): math.ceil( len(train_dataset) * cfg.num_epochs - * cfg.sequence_parallel_degree + * cfg.context_parallel_size * cfg.tensor_parallel_size / cfg.batch_size ) diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py index 040152beb..5f1aec8ff 100644 --- a/tests/core/test_builders.py +++ b/tests/core/test_builders.py @@ -64,7 +64,7 @@ def fixture_base_cfg(): "dataloader_num_workers": 1, "dataloader_pin_memory": True, "dataloader_prefetch_factor": 2, - "sequence_parallel_degree": 1, + "context_parallel_size": 1, "tensor_parallel_size": 1, # Dtype "fp16": False, diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py index 80098e684..a005e6742 100644 --- a/tests/e2e/multigpu/patched/test_sp.py +++ b/tests/e2e/multigpu/patched/test_sp.py @@ -67,7 
diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py
index 040152beb..5f1aec8ff 100644
--- a/tests/core/test_builders.py
+++ b/tests/core/test_builders.py
@@ -64,7 +64,7 @@ def fixture_base_cfg():
         "dataloader_num_workers": 1,
         "dataloader_pin_memory": True,
         "dataloader_prefetch_factor": 2,
-        "sequence_parallel_degree": 1,
+        "context_parallel_size": 1,
         "tensor_parallel_size": 1,
         # Dtype
         "fp16": False,
diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py
index 80098e684..a005e6742 100644
--- a/tests/e2e/multigpu/patched/test_sp.py
+++ b/tests/e2e/multigpu/patched/test_sp.py
@@ -67,7 +67,7 @@ class TestSequenceParallelism:
             "logging_steps": 1,
             "weight_decay": 0.0,
             "use_tensorboard": True,
-            "sequence_parallel_degree": 2,
+            "context_parallel_size": 2,
             "ring_attn_func": ring_attn_func,
             "save_first_step": False,
         }
@@ -105,13 +105,13 @@ class TestSequenceParallelism:
             (True, 1, True, None, 2.5),  # defaults to varlen_llama3 ring_attn_func
             (False, 2, True, None, 2.5),  # defaults to batch_ring ring_attn_func
             # (False, 2, True, "batch_zigzag", 2.5),
-            (False, 2, False, None, 2.65),  # defaults to batch_ring ring_attn_func
+            # (False, 2, False, None, 2.65),  # defaults to batch_ring ring_attn_func
         ],
         ids=[
             "sample_packing, varlen_llama3 ring_attn_func",
             "no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
             # "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
-            "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
+            # "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
         ],
     )
     def test_sequence_parallel_training(
diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py
index d022ae2d9..92e0f7040 100644
--- a/tests/e2e/multigpu/solo/test_grpo.py
+++ b/tests/e2e/multigpu/solo/test_grpo.py
@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
             "lora_alpha": 16,
             "lora_dropout": 0.05,
             "lora_target_linear": True,
-            "sequence_parallel_degree": 2,
+            "context_parallel_size": 2,
             "flash_attention": True,
             "sequence_len": 1024,
             "special_tokens": {
diff --git a/tests/e2e/multigpu/test_fp8_fsdp2.py b/tests/e2e/multigpu/test_fp8_fsdp2.py
index 6423f5e2e..f7fa29a31 100644
--- a/tests/e2e/multigpu/test_fp8_fsdp2.py
+++ b/tests/e2e/multigpu/test_fp8_fsdp2.py
@@ -13,7 +13,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
 
 from axolotl.utils.dict import DictDefault
 
-from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
+from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
 
 AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
 
@@ -51,6 +51,7 @@ class TestFP8FSDP2:
     """Test class for FP8 mixed precision with FSDP2 functionality."""
 
     @require_torch_2_7_0
+    @require_hopper
     def test_fp8_fsdp2_smoke(self, temp_dir):
         """Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
         cfg = DictDefault(
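`require_hopper` (added in `tests/e2e/utils.py` further down in this patch) keys off the CUDA compute capability; a quick way to check what it will see on the current machine (a sketch, assuming a CUDA build of torch):

```python
import torch

if torch.cuda.is_available():
    # Hopper (H100-class) devices report compute capability (9, 0).
    print(torch.cuda.get_device_capability() == (9, 0))
```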
diff --git a/tests/e2e/multigpu/test_tp.py b/tests/e2e/multigpu/test_tp.py
new file mode 100644
index 000000000..87a1c6339
--- /dev/null
+++ b/tests/e2e/multigpu/test_tp.py
@@ -0,0 +1,69 @@
+"""multigpu e2e test for tensor parallelism."""
+
+from pathlib import Path
+
+import pytest
+import yaml
+from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
+
+from axolotl.utils.dict import DictDefault
+
+from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
+
+
+class TestTensorParallel:
+    """Test class for Tensor Parallel functionality."""
+
+    @pytest.mark.skip(
+        reason="TP doesn't work with models with tied weights (embeddings)"
+    )
+    @require_torch_2_7_0
+    def test_fft_sft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "Qwen/Qwen2.5-0.5B",
+                "sequence_len": 2048,
+                "val_set_size": 0.01,
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                        "split": "train[:10%]",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "tensor_parallel_size": 2,
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "use_tensorboard": True,
+                "bf16": True,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
+                "--num-processes",
+                "2",
+                "--main-process-port",
+                f"{get_torch_dist_unique_port()}",
+            ]
+        )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
+        )
diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py
deleted file mode 100644
index 4a2c69d45..000000000
--- a/tests/e2e/patched/test_sp.py
+++ /dev/null
@@ -1,481 +0,0 @@
-"""Tests for sequence parallelism functionality."""
-
-# pylint: disable=redefined-outer-name,unused-argument
-
-import functools
-import sys
-from unittest.mock import MagicMock, patch
-
-import pytest
-import torch
-from accelerate.state import PartialState
-
-from axolotl.monkeypatch.ring_attn import (
-    get_ring_attn_group,
-    register_ring_attn,
-    set_ring_attn_group,
-)
-from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.schemas.enums import RingAttnFunc
-from axolotl.utils.schemas.trl import TRLConfig
-
-
-@pytest.fixture
-def partial_state():
-    """Create a real PartialState instance for testing."""
-    state = PartialState()
-    return state
-
-
-@pytest.fixture(name="cfg")
-def fixture_cfg():
-    cfg = DictDefault(
-        {
-            "base_model": "HuggingFaceTB/SmolLM2-135M",
-            "datasets": [
-                {
-                    "path": "mhenrichsen/alpaca_2k_test",
-                    "type": "alpaca",
-                },
-            ],
-            "micro_batch_size": 1,
-            "gradient_accumulation_steps": 1,
-            "learning_rate": 1e-3,
-            "output_dir": "./model-out",
-            "sequence_len": 512,
-            "special_tokens": {
-                "pad_token": "<|endoftext|>",
-            },
-            "save_first_step": False,
-        }
-    )
-
-    return cfg
-
-
-@pytest.fixture
-def sequence_parallel_batch():
-    """Create a test batch for sequence parallelism tests."""
-    batch_size = 1
-    seq_len = 8
-
-    # Create test tensors
-    input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
-    attention_mask = torch.ones(batch_size, seq_len)
-    position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
-    labels = input_ids.clone()
-
-    # Create test batch
-    batch = {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "position_ids": position_ids,
-        "labels": labels,
-    }
-
-    return batch
-
-
-class TestRingAttention:
-    """Tests for the ring attention functionality."""
-
-    @patch("torch.distributed.get_rank")
-    @patch("torch.distributed.get_world_size")
-    def test_get_ring_attn_group_no_registration(
-        self, mock_world_size, mock_rank, partial_state
-    ):
-        """Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
-        # Setup mocks
-        mock_world_size.return_value = 4
-        mock_rank.return_value = 0
-
-        # Verify that RuntimeError is raised when no group is registered
-        with pytest.raises(
-            RuntimeError, match="register_ring_attn\\(\\) not yet called"
-        ):
-            get_ring_attn_group()
-
-    @patch("torch.distributed.new_group")
-    @patch("torch.distributed.get_rank")
-    @patch("torch.distributed.get_world_size")
-    def test_register_ring_attn(
-        self, mock_world_size, mock_rank, mock_new_group, partial_state
-    ):
-        """Test that ring attention groups are created correctly."""
-        # Setup mocks
-        mock_world_size.return_value = 8  # 8 GPUs total
-        mock_rank.return_value = 3  # GPU #3
-        mock_group = MagicMock()
-        mock_new_group.return_value = mock_group
-
-        # Call register_ring_attn with size 4
-        register_ring_attn(
-            sequence_parallel_degree=4,
-            heads_k_stride=1,
-            ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
-        )
-
-        # Verify the number of calls without examining the arguments
-        assert mock_new_group.call_count == 2
-
-        # Verify that new_group was called
-        mock_new_group.assert_called()
-
-        # Clean up
-        set_ring_attn_group(None)
-
-
-class TestConfigValidation:
-    """Tests for validating sequence parallelism configurations."""
-
-    @pytest.fixture(autouse=True)
-    def setup_mocks(self, monkeypatch):
-        """Set up mocks for all tests in this class."""
-        # Mock the ring_flash_attn module
-        monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
-
-    @pytest.fixture
-    def base_cfg(self):
-        """Create a base configuration for testing."""
-        return DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
-                "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
-                "micro_batch_size": 1,
-                "gradient_accumulation_steps": 1,
-                "learning_rate": 1e-3,
-                "output_dir": "./model-out",
-                "sequence_len": 512,
-                "special_tokens": {"pad_token": "<|endoftext|>"},
-            }
-        )
-
-    @pytest.mark.parametrize(
-        "config_updates, expected_values, should_pass, error_msg",
-        [
-            # Valid configuration
-            (
-                {"sequence_parallel_degree": 2, "flash_attention": True},
-                {"sequence_parallel_degree": 2, "flash_attention": True},
-                True,
-                None,
-            ),
-            # Default sequence_parallel_degree
-            ({}, {"sequence_parallel_degree": 1}, True, None),
-            # Invalid: sequence_parallel_degree > 1 without flash_attention
-            (
-                {"sequence_parallel_degree": 2, "flash_attention": False},
-                None,
-                False,
-                "flash_attention: true must be set",
-            ),
-            # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
-            (
-                {
-                    "sequence_parallel_degree": 2,
-                    "flash_attention": True,
-                    "sample_packing": True,
-                    "micro_batch_size": 2,
-                    "pad_to_sequence_len": True,
-                },
-                None,
-                False,
-                "micro_batch_size must be set to 1",
-            ),
-            # Valid: Basic GRPO config
-            (
-                {
-                    "sequence_parallel_degree": 2,
-                    "flash_attention": True,
-                    "micro_batch_size": 2,
-                    "trl": {"use_liger_loss": True},
-                },
-                {
-                    "sequence_parallel_degree": 2,
-                    "flash_attention": True,
-                    "micro_batch_size": 2,
-                    "trl": TRLConfig(use_liger_loss=True),
-                },
-                True,
-                "GRPO + SP + Liger not currently supported",
-            ),
-            # Invalid: GRPO config with Liger loss
-            (
-                {
-                    "rl": "grpo",
-                    "sequence_parallel_degree": 2,
-                    "flash_attention": True,
-                    "micro_batch_size": 2,
-                    "trl": {"use_liger_loss": True},
-                },
-                None,
-                False,
-                "GRPO + SP + Liger not currently supported",
-            ),
-        ],
-        ids=[
-            "valid_config",
-            "default_sp_degree",
-            "without_flash_attention",
-            "sample_packing_with_large_batch",
-            "valid_grpo",
-            "grpo_with_liger_loss",
-        ],
-    )
-    def test_sequence_parallel_config_validation(
-        self, base_cfg, config_updates, expected_values, should_pass, error_msg
-    ):
-        """Test various sequence parallelism configuration scenarios."""
-        from axolotl.utils.schemas.config import AxolotlInputConfig
-
-        # Apply updates to base config
-        cfg = base_cfg
-        cfg.update(config_updates)
-
-        if should_pass:
-            # Should validate without errors
-            config = AxolotlInputConfig(**cfg)
-
-            # Check expected values
-            for key, value in expected_values.items():
-                assert getattr(config, key) == value
-        else:
-            # Should raise exception
-            with pytest.raises(ValueError) as excinfo:
-                AxolotlInputConfig(**cfg)
-            assert error_msg in str(excinfo.value)
-
-    @pytest.mark.parametrize(
-        "ring_attn_func, sample_packing, expected_func",
-        [
-            (None, True, RingAttnFunc.VARLEN_LLAMA3),
-            (None, False, RingAttnFunc.BATCH_RING),
-        ],
-        ids=["default_with_sample_packing", "default_without_sample_packing"],
-    )
-    def test_ring_attn_func_validation(
-        self, base_cfg, ring_attn_func, sample_packing, expected_func
-    ):
-        """Test ring_attn_func validation and defaults."""
-        from axolotl.utils.schemas.config import AxolotlInputConfig
-
-        # Apply updates to base config
-        cfg = base_cfg | {
-            "sequence_parallel_degree": 2,
-            "flash_attention": True,
-            "sample_packing": sample_packing,
-        }
-
-        if ring_attn_func is not None:
-            cfg["ring_attn_func"] = ring_attn_func
-
-        # Should validate without errors
-        config = AxolotlInputConfig(**cfg)
-
-        # Check ring_attn_func value
-        assert config.ring_attn_func.value == expected_func
-
-    def test_invalid_ring_attn_func(self, base_cfg):
-        """Test that an invalid ring_attn_func is rejected."""
-        from axolotl.utils.schemas.config import AxolotlInputConfig
-
-        # Invalid configuration with invalid ring_attn_func
-        cfg = base_cfg | {
-            "sequence_parallel_degree": 2,
-            "flash_attention": True,
-            "ring_attn_func": "INVALID_FUNC",
-        }
-
-        # Should raise ValidationError
-        with pytest.raises(ValueError) as excinfo:
-            AxolotlInputConfig(**cfg)
-
-        # Verify error message
-        assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
-
-
-class TestApplySequenceParallelism:
-    """Tests for the apply_sequence_parallelism function."""
-
-    @pytest.fixture(autouse=True)
-    def mock_distributed(self, monkeypatch):
-        """Mock torch.distributed functions for testing."""
-        # Mock is_initialized to return True
-        monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
-
-        # Mock get_rank to return 0 by default
-        monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
-
-        # Mock get_world_size to return 2 by default
-        monkeypatch.setattr(
-            torch.distributed, "get_world_size", lambda *args, **kwargs: 2
-        )
-
-        # Mock the process group
-        monkeypatch.setattr(
-            "axolotl.monkeypatch.ring_attn.get_ring_attn_group",
-            MagicMock,
-        )
-
-        # Mock update_ring_attn_params
-        monkeypatch.setattr(
-            "axolotl.monkeypatch.ring_attn.update_ring_attn_params",
-            lambda **kwargs: None,
-        )
-
-    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
-    def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
-        """Test that function returns original batch when world size is 1."""
-        mock_get_ring_attn_group.return_value = 0
-
-        result, _, _ = apply_sequence_parallelism(
-            batch=sequence_parallel_batch,
-            local_rank=0,
-            local_world_size=1,
-            gradient_accumulation_steps=1,
-            ring_attn_func=RingAttnFunc.BATCH_RING,
-        )
-
-        # Should return the original batch unchanged
-        assert result == sequence_parallel_batch
-
-    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
-    def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
-        """Test BATCH_RING sharding for rank 0 in a 2-process group."""
-        mock_get_ring_attn_group.return_value = 0
-
-        batch = sequence_parallel_batch
-        seq_len = batch["input_ids"].size(1)
-
-        result, _, _ = apply_sequence_parallelism(
-            batch=batch,
-            local_rank=0,
-            local_world_size=2,
-            gradient_accumulation_steps=1,
-            ring_attn_func=RingAttnFunc.BATCH_RING,
-        )
-
-        # Check that sequence dimension was sharded correctly
-        assert result["input_ids"].shape[1] == seq_len // 2
-        assert result["attention_mask"].shape[1] == seq_len // 2
-
-        # Verify content: rank 0 should get the first half of the sequence
-        assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
-        assert torch.equal(
-            result["position_ids"], batch["position_ids"][:, : seq_len // 2]
-        )
-
-    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
-    def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
-        """Test BATCH_RING sharding for rank 1 in a 2-process group."""
-        mock_get_ring_attn_group.return_value = 0
-
-        batch = sequence_parallel_batch
-        seq_len = batch["input_ids"].size(1)
-        original_input_ids = batch["input_ids"].clone()
-
-        result, _, _ = apply_sequence_parallelism(
-            batch=batch,
-            local_rank=1,
-            local_world_size=2,
-            gradient_accumulation_steps=1,
-            ring_attn_func=RingAttnFunc.BATCH_RING,
-        )
-
-        # Verify content: rank 1 should get the second half of the sequence
-        assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
-
-    # TODO(djsaunde): add back once implemented.
-    # def test_batch_zigzag(self, sequence_parallel_batch):
-    #     """Test BATCH_ZIGZAG sharding pattern."""
-    #     batch = sequence_parallel_batch
-    #     original_input_ids = batch["input_ids"].clone()
-    #     seq_len = batch["input_ids"].size(1)
-
-    #     # Test rank 0
-    #     result_rank0 = apply_sequence_parallelism(
-    #         batch={k: v.clone() for k, v in batch.items()},
-    #         local_rank=0,
-    #         local_world_size=2,
-    #         ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
-    #     )
-
-    #     # Test rank 1
-    #     result_rank1 = apply_sequence_parallelism(
-    #         batch={k: v.clone() for k, v in batch.items()},
-    #         local_rank=1,
-    #         local_world_size=2,
-    #         ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
-    #     )
-
-    #     # Checks for both ranks
-    #     assert result_rank0["input_ids"].shape[1] == seq_len // 2
-    #     assert result_rank1["input_ids"].shape[1] == seq_len // 2
-
-    #     # For a 2-rank system with 8 tokens, check specific zigzag pattern
-    #     # Rank 0 should get chunks [0, 1] and [6, 7]
-    #     # Rank 1 should get chunks [2, 3] and [4, 5]
-    #     if seq_len == 8:
-    #         # Create expected tensors for comparison
-    #         rank0_expected = torch.cat(
-    #             [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
-    #         )
-
-    #         rank1_expected = torch.cat(
-    #             [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
-    #         )
-
-    #         assert torch.equal(result_rank0["input_ids"], rank0_expected)
-    #         assert torch.equal(result_rank1["input_ids"], rank1_expected)
-
-    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
-    def test_partial_application(
-        self, mock_get_ring_attn_group, sequence_parallel_batch
-    ):
-        """Test that we can create a partially applied version of the function."""
-        mock_get_ring_attn_group.return_value = 0
-
-        batch = sequence_parallel_batch
-        original_input_ids = batch["input_ids"].clone()
-
-        # Create a partially applied function
-        rank0_ring_parallel = functools.partial(
-            apply_sequence_parallelism,
-            local_rank=0,
-            local_world_size=2,
-            gradient_accumulation_steps=1,
-            ring_attn_func=RingAttnFunc.BATCH_RING,
-        )
-
-        # Use the partially applied function
-        result, _, _ = rank0_ring_parallel(batch=batch)
-
-        # Verify it works as expected
-        assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
-        assert torch.equal(
-            result["input_ids"],
-            original_input_ids[:, : original_input_ids.shape[1] // 2],
-        )
-
-    def test_missing_position_ids(self, sequence_parallel_batch):
-        """Test handling of batch without position_ids."""
-        # Create a batch without position_ids
-        batch = {
-            k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
-        }
-        original_input_ids = batch["input_ids"].clone()
-
-        # This should run without error even though position_ids is missing
-        result, _, _ = apply_sequence_parallelism(
-            batch=batch,
-            local_rank=0,
-            local_world_size=2,
-            gradient_accumulation_steps=1,
-            ring_attn_func=RingAttnFunc.BATCH_RING,
-        )
-
-        # Verification should pass
-        assert "position_ids" in result
-        assert result["input_ids"].shape[1] == result["position_ids"].shape[1]
-        assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
diff --git a/tests/e2e/test_load_model.py b/tests/e2e/test_load_model.py
index 5061945b4..8fcffeb11 100644
--- a/tests/e2e/test_load_model.py
+++ b/tests/e2e/test_load_model.py
@@ -52,6 +52,8 @@ class TestLoadModelUtils:
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch_fused",
                 "lr_scheduler": "cosine",
+                "tensor_parallel_size": 1,
+                "context_parallel_size": 1,
             }
         )
         self.model_loader = (  # pylint: disable=attribute-defined-outside-init
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index 696e3b03c..5931fe148 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -142,6 +142,10 @@ def is_hopper():
     return compute_capability == (9, 0)
 
 
+def require_hopper(test_case):
+    return unittest.skipUnless(is_hopper(), "test requires h100/hopper GPU")(test_case)
+
+
 def check_tensorboard(
     temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
 ) -> None:
diff --git a/tests/test_loaders.py b/tests/test_loaders.py
index 7313a8267..def7672b9 100644
--- a/tests/test_loaders.py
+++ b/tests/test_loaders.py
@@ -171,3 +171,44 @@ class TestModelsUtils:
                 message_property_mappings={"content": "different_content"},
             )
         assert "Conflicting message content fields" in str(exc_info.value)
+
+    @pytest.mark.parametrize(
+        "world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected",
+        [
+            (16, 2, 2, 2, 2, True, (2, 2, 2, 2)),
+            (16, 1, 1, None, None, True, (0, 0, 16, 1)),
+            (16, 2, 2, 2, None, True, (2, 2, 2, 2)),
+            (16, 2, 2, None, 2, True, (2, 2, 2, 2)),
+            (16, 1, 1, None, 2, True, (0, 0, 8, 2)),
+            (2, 1, 1, None, None, True, (0, 0, 2, 1)),
+        ],
+    )
+    def test_get_parallel_config_kwargs(
+        self,
+        world_size,
+        tensor_parallel_size,
+        context_parallel_size,
+        dp_shard_size,
+        dp_replicate_size,
+        is_fsdp,
+        expected,
+    ):
+        res = (
+            ModelLoader._get_parallel_config_kwargs(  # pylint: disable=protected-access
+                world_size,
+                tensor_parallel_size,
+                context_parallel_size,
+                dp_shard_size,
+                dp_replicate_size,
+                is_fsdp,
+            )
+        )
+
+        if expected[0] > 1:
+            assert res["tp_size"] == expected[0]
+        if expected[1] > 1:
+            assert res["cp_size"] == expected[1]
+        if expected[2] > 1:
+            assert res["dp_shard_size"] == expected[2]
+        if expected[3] > 1:
+            assert res["dp_replicate_size"] == expected[3]
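The expected tuples in the parametrization above follow the usual device-mesh constraint: the parallelism degrees must multiply out to the world size, with any unset data-parallel dimension inferred. A sketch of that arithmetic (not the helper's actual implementation):

```python
world_size, tp, cp = 16, 2, 2
dp_replicate = 2  # given
# Infer the sharded data-parallel dim so the mesh covers all ranks:
dp_shard = world_size // (tp * cp * dp_replicate)

assert dp_shard == 2  # matches the expected (2, 2, 2, 2) row
assert tp * cp * dp_shard * dp_replicate == world_size
```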
diff --git a/tests/utils/schemas/validation/test_fsdp.py b/tests/utils/schemas/validation/test_fsdp.py
index 67f4a5cf9..5b461a113 100644
--- a/tests/utils/schemas/validation/test_fsdp.py
+++ b/tests/utils/schemas/validation/test_fsdp.py
@@ -26,32 +26,6 @@ class TestFSDPValidation:
         assert cfg.fsdp_version == 2
         assert cfg.fsdp_config.fsdp_version is None
 
-    def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
-        cfg = min_base_cfg | DictDefault(
-            fsdp_config={
-                "fsdp_state_dict_type": "SHARDED_STATE_DICT",
-            },
-            save_safetensors=True,
-        )
-        with pytest.raises(
-            ValueError,
-            match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
-        ):
-            validate_config(cfg)
-
-        # test w/o prefix too
-        cfg = min_base_cfg | DictDefault(
-            fsdp_config={
-                "state_dict_type": "SHARDED_STATE_DICT",
-            },
-            save_safetensors=True,
-        )
-        with pytest.raises(
-            ValueError,
-            match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
-        ):
-            validate_config(cfg)
-
     def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
         cfg = min_base_cfg | DictDefault(
             fsdp_config={