From 23f0c51d88aa1d791d8619b146d92892c27d7ade Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 21 Mar 2025 12:43:55 -0400 Subject: [PATCH] Sequence parallelism (#2412) * adding easy_context as integration for now * progress on ring attn impl * progress on ring attn impl * cleanup * remove errant file * fix req * removing unused code * updates * pytest * update * updates * fixes * precommit fixes * working multi-group SP * fixing sample packing * remove debug logs and simplify * eval dataloader and sampler changes * removing some obvious comments * update config.qmd and rename option * scoping down problematic import * another import scoping change * pernicious Fire CLI bugfix * isolate cli tests * actually isolate CLI tests * gracefully handle no ring-flash-attn * fix * fix * move ring flash attn to extras with flash-attn (#2414) * removing flash-attn from requirements.txt (in setup.py extras already) * rename file, delete another * using field validator instead of model validator * test fix * sampler / dataloader refactor * non-seq2se1 collator fix * removing print statement * bugfix * add SP doc, review comments * small changes * review comments, docstrings * refactors, SP mixin * small updates * fix tests * precommit * precommit --------- Co-authored-by: Wing Lian Co-authored-by: Dan Saunders --- .github/workflows/tests.yml | 6 +- cicd/Dockerfile.jinja | 4 +- cicd/cicd.sh | 5 +- docs/config.qmd | 11 + docs/sequence_parallelism.qmd | 90 ++ requirements.txt | 2 +- setup.py | 12 +- src/axolotl/cli/train.py | 11 +- src/axolotl/core/trainer_builder.py | 19 +- src/axolotl/core/trainers/__init__.py | 18 + src/axolotl/core/trainers/base.py | 821 ++++++------------ src/axolotl/core/trainers/dpo/trainer.py | 12 +- src/axolotl/core/trainers/mamba.py | 32 + src/axolotl/core/trainers/mixins/__init__.py | 8 + src/axolotl/core/trainers/mixins/optimizer.py | 201 +++++ src/axolotl/core/trainers/mixins/scheduler.py | 113 +++ .../core/trainers/mixins/sequence_parallel.py | 
131 +++ src/axolotl/core/trainers/relora.py | 43 + src/axolotl/core/trainers/trl.py | 65 +- src/axolotl/core/trainers/utils.py | 33 + src/axolotl/core/training_args.py | 9 +- .../monkeypatch/attention/ring_attn.py | 89 ++ src/axolotl/train.py | 5 +- src/axolotl/utils/collators/batching.py | 96 +- src/axolotl/utils/config/__init__.py | 3 + src/axolotl/utils/models.py | 82 +- src/axolotl/utils/samplers/multipack.py | 4 +- src/axolotl/utils/schemas/config.py | 27 +- src/axolotl/utils/trainer.py | 18 +- tests/e2e/patched/test_sp.py | 207 +++++ tests/test_exact_deduplication.py | 3 + 31 files changed, 1532 insertions(+), 648 deletions(-) create mode 100644 docs/sequence_parallelism.qmd create mode 100644 src/axolotl/core/trainers/mamba.py create mode 100644 src/axolotl/core/trainers/mixins/__init__.py create mode 100644 src/axolotl/core/trainers/mixins/optimizer.py create mode 100644 src/axolotl/core/trainers/mixins/scheduler.py create mode 100644 src/axolotl/core/trainers/mixins/sequence_parallel.py create mode 100644 src/axolotl/core/trainers/relora.py create mode 100644 src/axolotl/core/trainers/utils.py create mode 100644 src/axolotl/monkeypatch/attention/ring_attn.py create mode 100644 tests/e2e/patched/test_sp.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 32bb42821..66d95b3d4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -98,8 +98,9 @@ jobs: - name: Run tests run: | - pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/ + pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ pytest -v tests/patched/ + pytest -v tests/cli/ - name: cleanup pip cache run: | @@ -172,8 +173,9 @@ jobs: - name: Run tests run: | - pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/ + pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ pytest -v tests/patched/ + pytest -v 
tests/cli/ - name: cleanup pip cache run: | diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index b212a0065..6988e092b 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ RUN pip install packaging==23.2 setuptools==75.8.0 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ + pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ - pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \ + pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \ fi RUN python scripts/unsloth_install.py | sh diff --git a/cicd/cicd.sh b/cicd/cicd.sh index 8097e9c56..1d9ea7fbe 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -3,9 +3,10 @@ set -e python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" -pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ +pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ -pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ +pytest -v --durations=10 /workspace/axolotl/tests/cli +pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ 
--ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/ diff --git a/docs/config.qmd b/docs/config.qmd index ea7ea2293..9946b5865 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -32,6 +32,9 @@ tokenizer_legacy: resize_token_embeddings_to_32x: # Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink. shrink_embeddings: +# Whether to load the model with randomly initialized weights. Useful for +# pre-training a model from scratch or debugging purposes. +random_init_weights: # (Internal use only) # Used to identify which the model is based on @@ -617,6 +620,14 @@ ddp_timeout: ddp_bucket_cap_mb: ddp_broadcast_buffers: +# Sequence parallelism +# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. +# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. +# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized +# subsequences, or set to 4 to split into four equal-sized subsequences. +# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details. +sequence_parallel_degree: + # Path to torch distx for optim 'adamw_anyprecision' torchdistx_path: diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd new file mode 100644 index 000000000..cb297c0e0 --- /dev/null +++ b/docs/sequence_parallelism.qmd @@ -0,0 +1,90 @@ +--- +title: Sequence Parallelism +description: Train with long sequences split across multiple GPUs. +--- + +# Sequence Parallelism + +Sequence parallelism is a technique that splits sequences across multiple GPUs, +allowing you to train with very long sequences that wouldn't fit on a single GPU. Each +GPU processes a different portion of the sequence, and the results are aggregated +through a ring communication pattern. 
+ +## When to Use Sequence Parallelism + +Use sequence parallelism when: + +- You need to train with sequence lengths that don't fit into a single GPU's memory +- You have multiple GPUs available +- You're experiencing OOM (Out Of Memory) errors with long sequences + +## Configuration + +To enable sequence parallelism, add the following to your configuration file: + +```yaml +# Set to a divisor (> 1) of the number of GPUs available +sequence_parallel_degree: 4 # Split sequences across 4 GPUs +``` + +The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example: + +- With 8 GPUs, valid values would be 2, 4, or 8 +- With 4 GPUs, valid values would be 2 or 4 + +## Implementation Details + +When sequence parallelism is enabled: + +1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group +2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids +3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences +4. The trainer uses special ring communication patterns for attention operations + +## Requirements + +To use sequence parallelism, you need: + +- Multiple GPUs (at least 2) +- The `ring-flash-attn` package. Install with: + - `pip install axolotl[ring-flash-attn]` (preferred) + - `pip install ring-flash-attn>=0.1.4` + +## Limitations + +- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML) +- May have a small performance overhead due to communication between GPUs + +## Example + +```yaml +# Example config with sequence parallelism +base_model: meta-llama/Llama-3-8B-Instruct +sequence_len: 8192 +sequence_parallel_degree: 2 # Split each sequence into 4 parts +flash_attention: true # Required with sequence parallelism +... +``` + +This will train the Llama 3 8B model with 8K context length, with each sequence split +into 2 subsequences of length 4096 across 2 GPUs. 
+ +## Sample Packing with Sequence Parallelism + +Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together: + +1. Samples are first packed together +2. The packed sequences are then divided across GPUs in the sequence parallel group +3. Position IDs are automatically adjusted to maintain proper relative positions + +## Effect on Batch Size + +When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because: + +- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence) +- The number of batches processed per step decreases + +For example: +- With 8 GPUs and no sequence parallelism: 8 different batches processed per step +- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs) +- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4 diff --git a/requirements.txt b/requirements.txt index 495f43af6..c8465d23f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,6 @@ bitsandbytes==0.45.3 triton>=3.0.0 mamba-ssm==1.2.0.post1 -flash-attn==2.7.4.post1 xformers>=0.0.23.post1 autoawq==0.2.7.post3 liger-kernel==0.5.3 @@ -36,6 +35,7 @@ einops colorama numba numpy>=1.24.4,<=2.0.1 + # qlora things evaluate==0.4.1 scipy diff --git a/setup.py b/setup.py index c4ffcdaeb..8b2f1b2a5 100644 --- a/setup.py +++ b/setup.py @@ -17,11 +17,7 @@ def parse_requirements(): lines = [r.strip() for r in requirements_file.readlines()] for line in lines: is_extras = ( - "flash-attn" in line - or "flash-attention" in line - or "deepspeed" in line - or "mamba-ssm" in line - or "lion-pytorch" in line + "deepspeed" in line or "mamba-ssm" in line or "lion-pytorch" in line ) if line.startswith("--extra-index-url"): # Handle custom index URLs @@ -39,7 +35,6 @@ def parse_requirements(): "bitsandbytes", "triton", "mamba-ssm", 
- "flash-attn", "xformers", "autoawq", "liger-kernel", @@ -124,9 +119,8 @@ setup( ], }, extras_require={ - "flash-attn": [ - "flash-attn==2.7.4.post1", - ], + "flash-attn": ["flash-attn==2.7.4.post1"], + "ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"], "deepspeed": [ "deepspeed==0.16.4", "deepspeed-kernels", diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index e991105e6..6cc7c7701 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -23,7 +23,7 @@ from axolotl.utils.dict import DictDefault LOG = logging.getLogger(__name__) -def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: +def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): """ Trains a `transformers` model by first loading the dataset(s) specified in the `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin @@ -44,16 +44,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta) + del model, tokenizer, trainer + plugin_manager = PluginManager.get_instance() - - del model - del tokenizer - del trainer - plugin_manager.post_train_unload(cfg) -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): """ Parses `axolotl` config, CLI args, and calls `do_train`. 
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 83bfb1c83..b151be8fa 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -36,7 +36,7 @@ from transformers import ( from transformers.training_args import OptimizerNames from trl.trainer.utils import RewardDataCollatorWithPadding -from axolotl.core.trainers.base import ( +from axolotl.core.trainers import ( AxolotlCPOTrainer, AxolotlKTOTrainer, AxolotlMambaTrainer, @@ -762,6 +762,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.kd_top_k_before_softmax ) + training_arguments_kwargs["sequence_parallel_degree"] = ( + self.cfg.sequence_parallel_degree + ) + if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig elif self.cfg.process_reward_model: @@ -845,9 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs ): if training_args.pretraining: - if self.cfg.pretraining_sample_concatenation is False: - return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) - if self.cfg.micro_batch_size > 1: + if ( + self.cfg.pretraining_sample_concatenation is False + or self.cfg.micro_batch_size > 1 + ): return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) return None @@ -875,9 +880,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if "max_length" in kwargs: kwargs.pop("max_length") elif use_batch_sampler_collator: - if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: - collator = V2BatchSamplerDataCollatorForSeq2Seq - elif ( + if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or ( self.cfg.model_config_type in ["llama"] and self.cfg.flash_attention is not True ): @@ -908,6 +911,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator = DataCollatorForSeq2Seq kwargs["return_tensors"] = "pt" + if issubclass(collator, DataCollatorForSeq2Seq): + kwargs["sequence_parallel_degree"] = 
training_args.sequence_parallel_degree return collator( *collator_args, diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py index e69de29bb..32a889af9 100644 --- a/src/axolotl/core/trainers/__init__.py +++ b/src/axolotl/core/trainers/__init__.py @@ -0,0 +1,18 @@ +"""Init for axolotl.core.trainers""" + +# pylint: disable=unused-import +# flake8: noqa + +from .base import AxolotlTrainer +from .dpo.trainer import AxolotlDPOTrainer +from .grpo.trainer import AxolotlGRPOTrainer +from .mamba import AxolotlMambaTrainer +from .relora import ReLoRATrainer +from .trl import ( + AxolotlCPOTrainer, + AxolotlKTOTrainer, + AxolotlORPOTrainer, + AxolotlPRMTrainer, + AxolotlRewardTrainer, + TRLPPOTrainer, +) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 6570db967..9267dd040 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -1,365 +1,47 @@ -""" -module for customized trainers -""" +"""Module for customized trainers""" + +# pylint: disable=too-many-lines from __future__ import annotations -# pylint: disable=too-many-lines import logging import os from collections import defaultdict from functools import wraps -from typing import Dict, Literal, Optional +from typing import Any, Literal +import datasets import torch from datasets import Dataset -from peft.optimizers import create_loraplus_optimizer from torch import nn -from torch.optim.lr_scheduler import OneCycleLR -from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler +from torch.utils.data import ( + BatchSampler, + DataLoader, + RandomSampler, + Sampler, + SequentialSampler, +) from transformers import Trainer from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker -from transformers.utils import is_sagemaker_mp_enabled -from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer from trl.trainer.utils import pad_to_length +from 
typing_extensions import override -from axolotl.integrations.base import BaseOptimizerFactory -from axolotl.monkeypatch.relora import ReLoRAScheduler -from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths -from axolotl.utils.schedulers import ( - RexLR, - get_cosine_schedule_with_min_lr, - get_cosine_schedule_with_quadratic_warmup, - get_cosine_schedule_with_warmup_decay_constant, +from axolotl.core.trainers.mixins import ( + OptimizerMixin, + SchedulerMixin, + SequenceParallelMixin, ) +from axolotl.core.trainers.utils import ( + sanitize_kwargs_for_ds_tagging, + sanitize_kwargs_for_tagging, +) +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths -if is_sagemaker_mp_enabled(): - import smdistributed.modelparallel.torch as smp - -LOG = logging.getLogger("axolotl.core.trainer_builder") +LOG = logging.getLogger(__name__) -def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): - if isinstance(tag_names, str): - tag_names = [tag_names] - - if kwargs is not None: - if "tags" not in kwargs: - kwargs["tags"] = tag_names - elif "tags" in kwargs and isinstance(kwargs["tags"], list): - kwargs["tags"].extend(tag_names) - elif "tags" in kwargs and isinstance(kwargs["tags"], str): - tag_names.append(kwargs["tags"]) - kwargs["tags"] = tag_names - - return kwargs - - -def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None): - if isinstance(dataset_tags, str): - dataset_tags = [dataset_tags] - - if (dataset_tags is not None) and (kwargs is not None): - if "dataset_tags" not in kwargs: - kwargs["dataset_tags"] = dataset_tags - elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list): - kwargs["dataset_tags"].extend(dataset_tags) - elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str): - dataset_tags.append(kwargs["dataset_tags"]) - kwargs["dataset_tags"] = dataset_tags - - return kwargs - - -class SchedulerMixin(Trainer): - """ - Mixin class for scheduler setup in CausalTrainer. 
- """ - - args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] - - def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None - ): - """ - Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or - passed as an argument. - - Args: - num_training_steps (int): The number of training steps to do. - optimizer (torch.optim.Optimizer): The training optimizer - """ - use_cosine_quadratic = ( - self.args.lr_scheduler_type == "cosine" - and self.args.lr_quadratic_warmup is True - ) - - use_cosine_min_lr = ( - self.args.lr_scheduler_type == "cosine" - and self.args.cosine_min_lr_ratio is not None - ) - - # fmt: off - if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition - # fmt: on - if self.args.alternate_lr_scheduler_type == "one_cycle": - num_warmup_steps = self.args.get_warmup_steps(num_training_steps) - pct_start = num_warmup_steps / num_training_steps - extra_lr_kwargs = {} - if "pct_start" not in self.args.lr_scheduler_kwargs: - extra_lr_kwargs["pct_start"] = pct_start - if "anneal_strategy" not in self.args.lr_scheduler_kwargs: - extra_lr_kwargs["anneal_strategy"] = "cos" - - self.lr_scheduler = OneCycleLR( - optimizer, - max_lr=self.args.learning_rate, - total_steps=num_training_steps, - **extra_lr_kwargs, - **self.args.lr_scheduler_kwargs, - ) - elif self.args.alternate_lr_scheduler_type == "rex": - if use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - - self.lr_scheduler = RexLR( - optimizer=optimizer, - max_lr=self.args.learning_rate, - min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio), - total_steps=num_training_steps, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - ) - elif use_cosine_quadratic: - if use_cosine_min_lr: - LOG.warning("Both cosine quadratic warmup and min lr 
detected. Using quadratic warmup.") - - self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - ) - elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - constant_lr_ratio=self.args.cosine_constant_lr_ratio, - ) - elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - ) - else: - return super().create_scheduler(num_training_steps, optimizer=optimizer) - else: - if use_cosine_quadratic: - LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") - - if use_cosine_min_lr: - LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") - - return self.lr_scheduler - - -class OptimizerMixin(Trainer): - """ - Mixin class for shared handling of building custom optimizers - """ - - args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] - - def create_optimizer_grouped_parameters( - self, 
opt_model, optimizer_kwargs - ) -> list[dict]: - decay_parameters = self.get_decay_parameter_names(opt_model) - params: dict = { - "to_weight_decay": {}, # LayerNorm and bias - "embeddings": {}, # lm_head, embed_tokens, - "no_weight_decay": {}, - } - lr_groups_lookup = {} - lr_groups_learning_rates = {} - if self.args.lr_groups: - for lr_group in self.args.lr_groups: - group_name = lr_group["name"] - group_modules = lr_group["modules"] - for module in group_modules: - lr_groups_lookup[module] = group_name - lr_groups_learning_rates[group_name] = lr_group["lr"] - params[f"to_weight_decay_{group_name}"] = {} - - for name, param in opt_model.named_parameters(): - if not param.requires_grad: - continue - if name.endswith("modules_to_save.default.weight") or any( - embed_name in name for embed_name in ["embed_tokens", "lm_head"] - ): - params["embeddings"][name] = param - elif name in decay_parameters: - lr_group_modules = [ - group_modules - for group_modules in lr_groups_lookup - if group_modules in name - ] - if lr_groups_lookup and any(lr_group_modules): - lr_group_module = lr_group_modules[0] - group_name = lr_groups_lookup[lr_group_module] - params[f"to_weight_decay_{group_name}"][name] = param - else: - params["to_weight_decay"][name] = param - else: - params["no_weight_decay"][name] = param - optimizer_grouped_parameters = [] - if params["to_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["to_weight_decay"].values()), - "weight_decay": self.args.weight_decay, - "lr": optimizer_kwargs["lr"], - } - ) - if params["embeddings"]: - lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name - if self.args.embedding_lr_scale: - lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name - elif self.args.embedding_lr: - lr = self.args.embedding_lr # pylint: disable=invalid-name - optimizer_grouped_parameters.append( - { - "params": list(params["embeddings"].values()), - "weight_decay": 0.0, - "lr": lr, - } - ) - if 
params["no_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["no_weight_decay"].values()), - "weight_decay": 0.0, - "lr": optimizer_kwargs["lr"], - } - ) - for group_name, group_lr in lr_groups_learning_rates.items(): - if params[f"to_weight_decay_{group_name}"]: - optimizer_grouped_parameters.append( - { - "params": list( - params[f"to_weight_decay_{group_name}"].values() - ), - "weight_decay": self.args.weight_decay, - "lr": group_lr, - } - ) - - return optimizer_grouped_parameters - - def create_optimizer(self): - if ( - self.args.loraplus_lr_ratio is None - and self.args.embedding_lr_scale is None - and self.args.embedding_lr is None - and self.args.lr_groups is None - and self.optimizer_cls_and_kwargs is None - ): - return super().create_optimizer() - - opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - - if ( - not self.optimizer - and self.optimizer_cls_and_kwargs is not None - and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory) - ): - optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs - self.optimizer = optimizer_factory_cls()( - opt_model, self.args, **optimizer_kwargs - ) - - if not self.optimizer: - if self.optimizer_cls_and_kwargs is not None: - optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs - else: - optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs( - self.args, opt_model - ) - - optimizer_grouped_parameters = self.create_optimizer_grouped_parameters( - opt_model, optimizer_kwargs - ) - - if self.args.loraplus_lr_ratio is not None: - loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) - loraplus_lr_embedding = getattr( - self.args, "loraplus_lr_embedding", 1e-6 - ) - self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init - opt_model, - optimizer_cls, - loraplus_lr_ratio=loraplus_lr_ratio, - loraplus_lr_embedding=loraplus_lr_embedding, - **optimizer_kwargs, - ) - else: 
- # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` - # e.g. for GaLore optimizer. - if "params" in optimizer_kwargs: - optimizer_grouped_parameters = optimizer_kwargs.pop("params") - - # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` - # e.g. for LOMO optimizer. - if "model" in optimizer_kwargs: - optimizer_grouped_parameters = optimizer_kwargs.pop("model") - - # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` - # to avoid arguments conflicts. - if "optimizer_dict" in optimizer_kwargs: - optimizer_grouped_parameters = optimizer_kwargs.pop( - "optimizer_dict" - ) - - self.optimizer = optimizer_cls( - optimizer_grouped_parameters, **optimizer_kwargs - ) - - if optimizer_cls.__name__ == "Adam8bit": - import bitsandbytes - - manager = bitsandbytes.optim.GlobalOptimManager.get_instance() - - skipped = 0 - for module in opt_model.modules(): - if isinstance(module, nn.Embedding): - skipped += sum( - { - p.data_ptr(): p.numel() for p in module.parameters() - }.values() - ) - LOG.info(f"skipped {module}: {skipped/2**20}M params") - manager.register_module_override( - module, "weight", {"optim_bits": 32} - ) - LOG.debug(f"bitsandbytes: will optimize {module} in fp32") - LOG.info(f"skipped: {skipped/2**20}M params") - - if is_sagemaker_mp_enabled(): - self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init - self.optimizer - ) - - return self.optimizer - - -class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): - """ - Extend the base Trainer for axolotl helpers - """ +class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer): + """Extend the base Trainer for axolotl helpers""" args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] tag_names = ["axolotl"] @@ -376,12 +58,18 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): self.eval_data_collator = eval_data_collator 
self.dataset_tags = dataset_tags self._signature_columns = None # workaround for pylint + super().__init__(*_args, **kwargs) + self.train_data_collator = self.data_collator self._stored_metrics = defaultdict(lambda: defaultdict(list)) if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + # Initialize sequence parallelism if enabled + if self.args.sequence_parallel_degree > 1: + self._setup_sequence_parallel() + def _wrap_model(self, model, training=True, dataloader=None): if self.args.torch_compile: torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access @@ -394,142 +82,247 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): ) return super()._wrap_model(model, training=training, dataloader=dataloader) - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing and not self.args.pretraining: - if self.args.multipack_real_batches: - batch_size = self.args.per_device_train_batch_size - batch_max_len = self.args.max_seq_length - else: - batch_size = 1 - train_batch_size = ( - self.state.train_batch_size or self.args.per_device_train_batch_size - ) - batch_max_len = train_batch_size * self.args.max_seq_length + def _create_multipack_sampler( + self, base_sampler: Sampler, dataset: Dataset + ) -> MultipackBatchSampler: + """ + Helper method to create a `MultipackBatchSampler` for multipacking sequences + for training. - if self.args.curriculum_sampling: - sampler = SequentialSampler(self.train_dataset) - else: - sampler = RandomSampler(self.train_dataset) + Args: + base_sampler: Sampler to wrap with `MultipackBatchSampler`. + dataset: Dataset to sample from. 
- return MultipackBatchSampler( - sampler, - lengths=get_dataset_lengths(self.train_dataset), - packing_efficiency_estimate=self.args.sample_packing_efficiency, - batch_max_len=batch_max_len, - batch_size=batch_size, - group_size=self.args.sample_packing_group_size, - bin_size=self.args.sample_packing_bin_size, - drop_last=True, + Returns: + Multipack (sample packing) batch sampler. + """ + if self.args.multipack_real_batches: + batch_size = self.args.per_device_train_batch_size + batch_max_len = self.args.max_seq_length + else: + batch_size = 1 + train_batch_size = ( + self.state.train_batch_size or self.args.per_device_train_batch_size ) - if self.args.curriculum_sampling: - return SequentialSampler(self.train_dataset) - return super()._get_train_sampler() + batch_max_len = train_batch_size * self.args.max_seq_length - def _get_eval_sampler( - self, eval_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing and self.args.eval_sample_packing is not False: - if self.args.multipack_real_batches: - batch_size = self.args.per_device_eval_batch_size - batch_max_len = self.args.max_seq_length - else: - batch_size = 1 - batch_max_len = ( - self.args.per_device_eval_batch_size * self.args.max_seq_length - ) - return MultipackBatchSampler( - SequentialSampler(eval_dataset), - lengths=get_dataset_lengths(self.eval_dataset), - packing_efficiency_estimate=self.args.sample_packing_efficiency, - batch_max_len=batch_max_len, - batch_size=batch_size, - group_size=self.args.sample_packing_group_size, - bin_size=self.args.sample_packing_bin_size, - drop_last=True, + return MultipackBatchSampler( + base_sampler, + lengths=get_dataset_lengths(dataset), + packing_efficiency_estimate=self.args.sample_packing_efficiency, + batch_max_len=batch_max_len, + batch_size=batch_size, + drop_last=True, + ) + + def _get_train_sampler(self) -> Sampler | None: + """ + Helper method to get the sampler for training. 
Handles cases for sequence + parallelism, sample packing, and curriculum sampling (sequential). + + Returns: + If the dataset is non-empty, a sampler is returned, the type of which + depends on the passed training args. + """ + use_sample_packing = self.args.sample_packing and not self.args.pretraining + + # Determine the base sampler first + if self.args.sequence_parallel_degree > 1: + base_sampler = self._sp_get_train_sampler(self.train_dataset) + elif self.args.curriculum_sampling: + base_sampler = SequentialSampler(self.train_dataset) + elif use_sample_packing: + base_sampler = RandomSampler(self.train_dataset) + else: + # Default to parent class implementation for standard random sampling + return super()._get_train_sampler() + + # Apply multipack wrapper if needed + if use_sample_packing: + return self._create_multipack_sampler( + base_sampler=base_sampler, + dataset=self.train_dataset, ) - return super()._get_eval_sampler(eval_dataset) - def get_train_dataloader(self) -> DataLoader: - if self.args.sample_packing and not self.args.pretraining: - train_dataset = self.train_dataset - if "length" in train_dataset.features.keys(): - train_dataset = train_dataset.remove_columns(["length"]) - data_collator = self.data_collator - dataloader_params = { - "batch_size": self._train_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params["prefetch_factor"] = ( - self.args.dataloader_prefetch_factor - ) + return base_sampler - sampler = self._get_train_sampler() + def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None: + """ + Helper method to get the sampler for evaluation. Handles sequence parallelism + and sample packing cases. + + Returns: + If the dataset is non-empty, a sampler is returned, the type of which + depends on the passed training args. 
+ """ + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + # Multipacking enabled if training is enabled and eval is not explicitly disabled + use_multipack = ( + self.args.sample_packing and self.args.eval_sample_packing is not False + ) + + # Determine the base sampler + if self.args.sequence_parallel_degree > 1: + base_sampler = self._sp_get_eval_sampler(eval_dataset) + elif use_multipack: + base_sampler = SequentialSampler(eval_dataset) + else: + return super()._get_eval_sampler(eval_dataset) + + # Apply multipack wrapper if needed + if use_multipack: + return self._create_multipack_sampler( + base_sampler=base_sampler, + dataset=eval_dataset, + ) + + return base_sampler + + def _create_dataloader_params(self, is_eval=False, custom_batch_size=None): + """Create common dataloader parameters for train or eval.""" + batch_size = custom_batch_size or ( + self.args.eval_batch_size if is_eval else self._train_batch_size + ) + + params = { + "batch_size": batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + + # Add persistent workers only for training + if not is_eval and hasattr(self.args, "dataloader_persistent_workers"): + params["persistent_workers"] = self.args.dataloader_persistent_workers + + # Add prefetch factor if specified + if self.args.dataloader_prefetch_factor: + params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return params + + def _prepare_dataloader( + self, dataset, sampler, is_eval=False, custom_batch_size=None + ): + """Prepare a dataloader with the given dataset and sampler.""" + # Get base parameters + dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size) + + # Add sampler configuration + if not isinstance(dataset, torch.utils.data.IterableDataset): if isinstance(sampler, BatchSampler): + # batch_size and batch_sampler are mutually exclusive 
dataloader_params["batch_sampler"] = sampler del dataloader_params["batch_size"] else: dataloader_params["sampler"] = sampler dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = seed_worker + if not is_eval: + dataloader_params["worker_init_fn"] = seed_worker + + # Create the dataloader + dataloader = DataLoader(dataset, **dataloader_params) + + if self.args.sample_packing and ( + (not is_eval and not self.args.pretraining) + or (is_eval and self.args.eval_sample_packing is not False) + ): self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader( - DataLoader(train_dataset, **dataloader_params) - ) - return super().get_train_dataloader() - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + # Return unprepared dataloader if using sequence parallelism + if self.args.sequence_parallel_degree > 1: + return dataloader + + # Otherwise prepare with accelerator + return self.accelerator.prepare_data_loader(dataloader) + + def get_train_dataloader(self) -> DataLoader: + """Get dataloader for training""" + train_dataset = self.train_dataset + data_collator = self.data_collator # type: ignore + + # Handle dataset preprocessing + if isinstance(train_dataset, datasets.Dataset): + if self.args.sample_packing and not self.args.pretraining: + train_dataset = train_dataset.remove_columns(["length"]) + if not self.args.sample_packing or self.args.pretraining: + train_dataset = self._remove_unused_columns( + train_dataset, description="training" + ) + else: + self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init + data_collator, + description="training", + ) + + # Get sampler and create dataloader + sampler = self._get_train_sampler() + return self._prepare_dataloader(train_dataset, sampler, is_eval=False) + + def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader: + """Get dataloader for 
evaluation""" + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + # Handle special case: sample packing is enabled but eval_sample_packing is False if self.args.sample_packing and self.args.eval_sample_packing is False: self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.eval_data_collator ) - if eval_dataset: + if "length" in eval_dataset.column_names: eval_dataset = eval_dataset.remove_columns(["length"]) dataloader = super().get_eval_dataloader(eval_dataset) self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.train_data_collator ) + return dataloader - if self.args.sample_packing and self.args.eval_sample_packing is not False: - eval_dataset = ( - eval_dataset if eval_dataset is not None else self.eval_dataset + # Handle sample packing or sequence parallelism + if ( + self.args.sample_packing + and self.args.eval_sample_packing is not False + or self.args.sequence_parallel_degree > 1 + ): + # Get appropriate data collator + self.data_collator = ( # pylint: disable=attribute-defined-outside-init + self.eval_data_collator + if hasattr(self, "eval_data_collator") and self.eval_data_collator + else self.data_collator + ) + if "length" in eval_dataset.column_names: + eval_dataset = eval_dataset.remove_columns(["length"]) + + # Handle dataset preprocessing for SP + if self.args.sequence_parallel_degree > 1: + if isinstance(eval_dataset, datasets.Dataset): + eval_dataset = self._remove_unused_columns( + eval_dataset, description="evaluation" + ) + else: + self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init + self.data_collator, description="evaluation" + ) + + # Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise + batch_size = ( + self.args.eval_batch_size + if self.args.sample_packing + else self.args.per_device_eval_batch_size + ) + sampler = self._get_eval_sampler(eval_dataset) + dataloader = 
self._prepare_dataloader( + eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size ) - eval_sampler = self._get_eval_sampler(eval_dataset) - eval_dataset = eval_dataset.remove_columns(["length"]) - data_collator = self.data_collator - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params["prefetch_factor"] = ( - self.args.dataloader_prefetch_factor - ) - - if isinstance(eval_sampler, BatchSampler): - dataloader_params["batch_sampler"] = eval_sampler - del dataloader_params["batch_size"] - else: - dataloader_params["sampler"] = eval_sampler - dataloader_params["drop_last"] = self.args.dataloader_drop_last - - self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader( - DataLoader(eval_dataset, **dataloader_params) - ) + return dataloader return super().get_eval_dataloader(eval_dataset) def _get_bench_sampler( self, bench_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: + ) -> torch.utils.data.Sampler | None: if self.args.world_size <= 1: return SequentialSampler(bench_dataset) return None @@ -554,6 +347,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): return DataLoader(bench_dataset, **dataloader_params) # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) + @override def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): @@ -570,6 +364,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): return_outputs=return_outputs, num_items_in_batch=num_items_in_batch, ) + return super().compute_loss( model, inputs, @@ -744,10 +539,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): Overwrite the `push_to_hub` method in order to force-add the tags when pushing the model on the Hub. 
Please refer to `~transformers.Trainer.push_to_hub` for more details. """ - kwargs = _sanitize_kwargs_for_ds_tagging( + kwargs = sanitize_kwargs_for_ds_tagging( dataset_tags=self.dataset_tags, kwargs=kwargs ) - kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) + kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) return super().push_to_hub(*args, **kwargs) @@ -764,15 +559,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): return res - def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: """ Log `logs` on the various objects watching training, including stored metrics. Args: - logs (`Dict[str, float]`): - The values to log. - start_time (`Optional[float]`): - The start of training. + logs: The values to log. + start_time: The start of training. """ # logs either has 'loss' or 'eval_loss' train_eval = "train" if "loss" in logs else "eval" @@ -784,7 +577,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): return super().log(logs, start_time) def store_metrics( - self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" + self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train" ) -> None: for key, value in metrics.items(): self._stored_metrics[train_eval][key].append(value) @@ -797,110 +590,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): os.makedirs(output_dir, exist_ok=True) return super()._save_checkpoint(model, trial, **kwargs) - -class AxolotlMambaTrainer(AxolotlTrainer): - """ - Mamba specific trainer to handle loss calculation - """ - - tag_names = ["axolotl", "mamba"] - - def compute_loss( + def training_step( self, - model, - inputs, - return_outputs=False, # pylint: disable=unused-argument - num_items_in_batch=None, # pylint: disable=unused-argument - ): - input_ids = inputs.pop("input_ids") - 
lm_logits = model(input_ids).logits + model: nn.Module, + inputs: dict[str, torch.Tensor | Any], + num_items_in_batch: int | None = None, + ) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. Overrides the + `transformers.trainer.Trainer` method to handle sequence parallelism if + enabled. - labels = input_ids.to(lm_logits.device) - shift_logits = lm_logits[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() + Args: + model: Model to perform training step for. + inputs: Dictionary mapping. + """ + # Set up sequence parallelism for this step if enabled + if self.args.sequence_parallel_degree > 1: + self._update_ring_flash_attn_params(inputs) - loss_fct = torch.nn.CrossEntropyLoss() - lm_loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) - ) + # Proceed with normal training step + loss = super().training_step(model, inputs, num_items_in_batch) - return lm_loss - - -class ReLoRATrainer(AxolotlTrainer): - """ - Trainer subclass that uses the OneCycleLR scheduler - """ - - tag_names = ["axolotl", "relora"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.lr_scheduler = None - - def create_scheduler( - self, - num_training_steps: int, - optimizer: Optional[torch.optim.Optimizer] = None, - ): - optimizer = self.optimizer if optimizer is None else optimizer - lr_scheduler = super().create_scheduler(num_training_steps, optimizer) - - if self.args.relora_steps: - warmup_steps = ( - self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10 - ) - anneal_steps = ( - self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1 - ) - self.lr_scheduler = ReLoRAScheduler( - optimizer, - lr_scheduler, - self.args.relora_steps, - anneal_steps, - warmup_steps, - ) - else: - self.lr_scheduler = lr_scheduler - - return self.lr_scheduler - - -class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): - """ - Extend the base ORPOTrainer for axolotl helpers - """ - - tag_names 
= ["axolotl", "orpo"] - - -class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): - """ - Extend the base KTOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "kto"] - - -class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): - """ - Extend the base CPOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "cpo"] - - -class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): - """ - Extend the base RewardTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "reward"] - - -class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer): - """ - Extend the base trl.PRMTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "prm"] + return loss diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index 38b657260..9eb870a3a 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -13,10 +13,10 @@ from transformers import Trainer from transformers.utils import is_sagemaker_mp_enabled from trl import DPOTrainer -from axolotl.core.trainers.base import ( - SchedulerMixin, - _sanitize_kwargs_for_ds_tagging, - _sanitize_kwargs_for_tagging, +from axolotl.core.trainers.mixins import SchedulerMixin +from axolotl.core.trainers.utils import ( + sanitize_kwargs_for_ds_tagging, + sanitize_kwargs_for_tagging, ) if is_sagemaker_mp_enabled(): @@ -74,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): Overwrite the `push_to_hub` method in order to force-add the tags when pushing the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. 
""" - kwargs = _sanitize_kwargs_for_ds_tagging( + kwargs = sanitize_kwargs_for_ds_tagging( dataset_tags=self.dataset_tags, kwargs=kwargs ) - kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) + kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) return super().push_to_hub(*args, **kwargs) diff --git a/src/axolotl/core/trainers/mamba.py b/src/axolotl/core/trainers/mamba.py new file mode 100644 index 000000000..38792e389 --- /dev/null +++ b/src/axolotl/core/trainers/mamba.py @@ -0,0 +1,32 @@ +"""Module for mamba trainer""" + +import torch + +from axolotl.core.trainers.base import AxolotlTrainer + + +class AxolotlMambaTrainer(AxolotlTrainer): + """Mamba specific trainer to handle loss calculation""" + + tag_names = ["axolotl", "mamba"] + + def compute_loss( + self, + model, + inputs, + return_outputs=False, # pylint: disable=unused-argument + num_items_in_batch=None, # pylint: disable=unused-argument + ): + input_ids = inputs.pop("input_ids") + lm_logits = model(input_ids).logits + + labels = input_ids.to(lm_logits.device) + shift_logits = lm_logits[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + + loss_fct = torch.nn.CrossEntropyLoss() + lm_loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) + ) + + return lm_loss diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py new file mode 100644 index 000000000..12c8277fc --- /dev/null +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -0,0 +1,8 @@ +"""Init for axolotl.core.trainers.mixins""" + +# pylint: disable=unused-import +# flake8: noqa + +from .optimizer import OptimizerMixin +from .scheduler import SchedulerMixin +from .sequence_parallel import SequenceParallelMixin diff --git a/src/axolotl/core/trainers/mixins/optimizer.py b/src/axolotl/core/trainers/mixins/optimizer.py new file mode 100644 index 000000000..bde58aa1d --- /dev/null +++ 
b/src/axolotl/core/trainers/mixins/optimizer.py @@ -0,0 +1,201 @@ +"""Module for Axolotl trainer optimizer mixin""" + +import logging + +from peft.optimizers import create_loraplus_optimizer +from torch import nn +from transformers.trainer import Trainer +from transformers.utils import is_sagemaker_mp_enabled + +from axolotl.integrations.base import BaseOptimizerFactory + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + +LOG = logging.getLogger(__name__) + + +class OptimizerMixin(Trainer): + """Mixin class for shared handling of building custom optimizers""" + + args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] + + def create_optimizer_grouped_parameters( + self, opt_model, optimizer_kwargs + ) -> list[dict]: + decay_parameters = self.get_decay_parameter_names(opt_model) + params: dict = { + "to_weight_decay": {}, # LayerNorm and bias + "embeddings": {}, # lm_head, embed_tokens, + "no_weight_decay": {}, + } + lr_groups_lookup = {} + lr_groups_learning_rates = {} + if self.args.lr_groups: + for lr_group in self.args.lr_groups: + group_name = lr_group["name"] + group_modules = lr_group["modules"] + for module in group_modules: + lr_groups_lookup[module] = group_name + lr_groups_learning_rates[group_name] = lr_group["lr"] + params[f"to_weight_decay_{group_name}"] = {} + + for name, param in opt_model.named_parameters(): + if not param.requires_grad: + continue + if name.endswith("modules_to_save.default.weight") or any( + embed_name in name for embed_name in ["embed_tokens", "lm_head"] + ): + params["embeddings"][name] = param + elif name in decay_parameters: + lr_group_modules = [ + group_modules + for group_modules in lr_groups_lookup + if group_modules in name + ] + if lr_groups_lookup and any(lr_group_modules): + lr_group_module = lr_group_modules[0] + group_name = lr_groups_lookup[lr_group_module] + params[f"to_weight_decay_{group_name}"][name] = param + else: + params["to_weight_decay"][name] = param + 
else: + params["no_weight_decay"][name] = param + optimizer_grouped_parameters = [] + if params["to_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["to_weight_decay"].values()), + "weight_decay": self.args.weight_decay, + "lr": optimizer_kwargs["lr"], + } + ) + if params["embeddings"]: + lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name + if self.args.embedding_lr_scale: + lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name + elif self.args.embedding_lr: + lr = self.args.embedding_lr # pylint: disable=invalid-name + optimizer_grouped_parameters.append( + { + "params": list(params["embeddings"].values()), + "weight_decay": 0.0, + "lr": lr, + } + ) + if params["no_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["no_weight_decay"].values()), + "weight_decay": 0.0, + "lr": optimizer_kwargs["lr"], + } + ) + for group_name, group_lr in lr_groups_learning_rates.items(): + if params[f"to_weight_decay_{group_name}"]: + optimizer_grouped_parameters.append( + { + "params": list( + params[f"to_weight_decay_{group_name}"].values() + ), + "weight_decay": self.args.weight_decay, + "lr": group_lr, + } + ) + + return optimizer_grouped_parameters + + def create_optimizer(self): + if ( + self.args.loraplus_lr_ratio is None + and self.args.embedding_lr_scale is None + and self.args.embedding_lr is None + and self.args.lr_groups is None + and self.optimizer_cls_and_kwargs is None + ): + return super().create_optimizer() + + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + + if ( + not self.optimizer + and self.optimizer_cls_and_kwargs is not None + and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory) + ): + optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs + self.optimizer = optimizer_factory_cls()( + opt_model, self.args, **optimizer_kwargs + ) + + if not self.optimizer: + if self.optimizer_cls_and_kwargs is not None: + 
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs + else: + optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs( + self.args, opt_model + ) + + optimizer_grouped_parameters = self.create_optimizer_grouped_parameters( + opt_model, optimizer_kwargs + ) + + if self.args.loraplus_lr_ratio is not None: + loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) + loraplus_lr_embedding = getattr( + self.args, "loraplus_lr_embedding", 1e-6 + ) + self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + opt_model, + optimizer_cls, + loraplus_lr_ratio=loraplus_lr_ratio, + loraplus_lr_embedding=loraplus_lr_embedding, + **optimizer_kwargs, + ) + else: + # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for GaLore optimizer. + if "params" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("params") + + # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for LOMO optimizer. + if "model" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("model") + + # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` + # to avoid arguments conflicts. 
+ if "optimizer_dict" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop( + "optimizer_dict" + ) + + self.optimizer = optimizer_cls( + optimizer_grouped_parameters, **optimizer_kwargs + ) + + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum( + { + p.data_ptr(): p.numel() for p in module.parameters() + }.values() + ) + LOG.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override( + module, "weight", {"optim_bits": 32} + ) + LOG.debug(f"bitsandbytes: will optimize {module} in fp32") + LOG.info(f"skipped: {skipped/2**20}M params") + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init + self.optimizer + ) + + return self.optimizer diff --git a/src/axolotl/core/trainers/mixins/scheduler.py b/src/axolotl/core/trainers/mixins/scheduler.py new file mode 100644 index 000000000..b0a5ee895 --- /dev/null +++ b/src/axolotl/core/trainers/mixins/scheduler.py @@ -0,0 +1,113 @@ +"""Module for Axolotl trainer scheduler mixin""" + +import logging + +import torch +from torch.optim.lr_scheduler import OneCycleLR +from transformers.trainer import Trainer + +from axolotl.utils.schedulers import ( + RexLR, + get_cosine_schedule_with_min_lr, + get_cosine_schedule_with_quadratic_warmup, + get_cosine_schedule_with_warmup_decay_constant, +) + +LOG = logging.getLogger(__name__) + + +class SchedulerMixin(Trainer): + """ + Mixin class for scheduler setup in CausalTrainer. + """ + + args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] + + def create_scheduler( + self, num_training_steps: int, optimizer: torch.optim.Optimizer = None + ): + """ + Setup the scheduler. 
The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + optimizer (torch.optim.Optimizer): The training optimizer + """ + use_cosine_quadratic = ( + self.args.lr_scheduler_type == "cosine" + and self.args.lr_quadratic_warmup is True + ) + + use_cosine_min_lr = ( + self.args.lr_scheduler_type == "cosine" + and self.args.cosine_min_lr_ratio is not None + ) + + # fmt: off + if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition + # fmt: on + if self.args.alternate_lr_scheduler_type == "one_cycle": + num_warmup_steps = self.args.get_warmup_steps(num_training_steps) + pct_start = num_warmup_steps / num_training_steps + extra_lr_kwargs = {} + if "pct_start" not in self.args.lr_scheduler_kwargs: + extra_lr_kwargs["pct_start"] = pct_start + if "anneal_strategy" not in self.args.lr_scheduler_kwargs: + extra_lr_kwargs["anneal_strategy"] = "cos" + + self.lr_scheduler = OneCycleLR( + optimizer, + max_lr=self.args.learning_rate, + total_steps=num_training_steps, + **extra_lr_kwargs, + **self.args.lr_scheduler_kwargs, + ) + elif self.args.alternate_lr_scheduler_type == "rex": + if use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + + self.lr_scheduler = RexLR( + optimizer=optimizer, + max_lr=self.args.learning_rate, + min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio), + total_steps=num_training_steps, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + ) + elif use_cosine_quadratic: + if use_cosine_min_lr: + LOG.warning("Both cosine quadratic warmup and min lr detected. 
Using quadratic warmup.") + + self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + constant_lr_ratio=self.args.cosine_constant_lr_ratio, + ) + elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + ) + else: + return super().create_scheduler(num_training_steps, optimizer=optimizer) + else: + if use_cosine_quadratic: + LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") + + if use_cosine_min_lr: + LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") + + return self.lr_scheduler diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py new file mode 100644 index 000000000..f52c044b6 --- /dev/null +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -0,0 +1,131 @@ 
+"""Module for Axolotl trainer sequence parallelism mixin""" + +import logging +from typing import Any + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from datasets import Dataset +from torch.utils.data import DistributedSampler, Sampler + +from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group + +LOG = logging.getLogger(__name__) + +try: + from ring_flash_attn import update_ring_flash_attn_params +except ImportError: + # We pass silently here, but raise an ImportError in our Axolotl config validation + # if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed. + pass + + +class SequenceParallelMixin: + """ + Mixin class for sequence parallelism support in trainers. + + This mixin provides functionality for handling sequence parallelism, + including creating appropriate samplers, managing data partitioning, + and updating ring flash attention parameters during training. + """ + + args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] + + def _setup_sequence_parallel(self): + """Set up sequence parallelism environment.""" + self.ring_attn_group = get_ring_attn_group() + + def _create_sequence_parallel_sampler( + self, + dataset: Dataset, + shuffle: bool = True, + is_eval: bool = False, + ) -> DistributedSampler: + """ + Helper method to create sampler for sequence parallelism (SP). + + We create a distributed sampler with rank equal to the SP group ID, which + means that all ranks in the SP group receive the same sample / set of samples + per training step. We also set the number of replicas equal to the number of + SP groups, which is a bit of a hack / unintended use, but works! + + Args: + dataset: Dataset to sample from. + shuffle: Whether to shuffle the dataset. + is_eval: Whether we are creating a sampler for evaluation or training. + + Returns: + Distributed sampler. 
+ """ + num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree + sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree + + return DistributedSampler( + dataset, + num_replicas=num_sp_groups, + rank=sp_group_id, + seed=self.args.seed if shuffle else None, + shuffle=shuffle, + drop_last=not is_eval, + ) + + def _sp_get_train_sampler(self, dataset) -> Sampler | None: + """ + Get a training sampler configured for sequence parallelism. + + Args: + dataset: The training dataset + + Returns: + Configured sequence parallel sampler. + """ + return self._create_sequence_parallel_sampler( + dataset, + shuffle=not self.args.curriculum_sampling, + ) + + def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None: + """ + Get an evaluation sampler configured for sequence parallelism. + + Args: + eval_dataset: The evaluation dataset. + + Returns: + Configured sequence parallel sampler. + """ + return self._create_sequence_parallel_sampler( + eval_dataset, shuffle=False, is_eval=True + ) + + def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]): + """ + Calculate the cu_seqlens for the current forward pass and pass the value to + the substituted ring_flash_attn. This is accomplished by using the passed + `input_ids`. + + Args: + inputs: Current batch of inputs. 
+ """ + # At this point, inputs should already be partitioned by the sequence + # parallel data collator + batch_size = inputs["input_ids"].shape[0] + seq_len = inputs["input_ids"].shape[1] + packed_seq_lens = [seq_len] * batch_size + + # Calculate the full sequence length across all GPUs in this SP group + total_seq_len = seq_len * self.args.sequence_parallel_degree + + cu_seqlens = torch.cumsum( + torch.tensor( + packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32 + ), + dim=-1, + dtype=torch.int32, + ) + cu_seqlens = F.pad( + F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len + ) + + update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group) diff --git a/src/axolotl/core/trainers/relora.py b/src/axolotl/core/trainers/relora.py new file mode 100644 index 000000000..3bcd4a9b8 --- /dev/null +++ b/src/axolotl/core/trainers/relora.py @@ -0,0 +1,43 @@ +"""Module for ReLoRA trainer""" + +import torch + +from axolotl.core.trainers.base import AxolotlTrainer +from axolotl.monkeypatch.relora import ReLoRAScheduler + + +class ReLoRATrainer(AxolotlTrainer): + """Trainer subclass that uses the `OneCycleLR` scheduler""" + + tag_names = ["axolotl", "relora"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lr_scheduler = None + + def create_scheduler( + self, + num_training_steps: int, + optimizer: torch.optim.Optimizer | None = None, + ): + optimizer = self.optimizer if optimizer is None else optimizer + lr_scheduler = super().create_scheduler(num_training_steps, optimizer) + + if self.args.relora_steps: + warmup_steps = ( + self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10 + ) + anneal_steps = ( + self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1 + ) + self.lr_scheduler = ReLoRAScheduler( + optimizer, + lr_scheduler, + self.args.relora_steps, + anneal_steps, + warmup_steps, + ) + else: + self.lr_scheduler = lr_scheduler + + return self.lr_scheduler diff --git 
a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py index 7237e792e..1199313e8 100644 --- a/src/axolotl/core/trainers/trl.py +++ b/src/axolotl/core/trainers/trl.py @@ -1,16 +1,23 @@ -""" -module for TRL PPO training -""" +"""Module for TRL PPO trainer""" import torch from tqdm import tqdm -from trl import PPOTrainer +from trl import ( + CPOTrainer, + KTOTrainer, + ORPOTrainer, + PPOTrainer, + PRMTrainer, + RewardTrainer, +) + +from axolotl.core.trainers.mixins.scheduler import SchedulerMixin class TRLPPOTrainer(PPOTrainer): - """ - wrapper for ppo trainer to handle customizations - """ + """Wrapper for TRL PPO trainer to handle customizations""" + + tag_names = ["axolotl", "ppo"] def train( self, @@ -31,9 +38,7 @@ class TRLPPOTrainer(PPOTrainer): "batch_size": 16, } - for epoch, batch in tqdm( # pylint: disable=unused-variable - enumerate(self.dataloader) - ): + for _, batch in tqdm(enumerate(self.dataloader)): query_tensors = batch["input_ids"] # generate model response @@ -65,3 +70,43 @@ class TRLPPOTrainer(PPOTrainer): rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"], ) + + +class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): + """ + Extend the base ORPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "orpo"] + + +class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): + """ + Extend the base KTOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "kto"] + + +class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): + """ + Extend the base CPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "cpo"] + + +class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): + """ + Extend the base RewardTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "reward"] + + +class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer): + """ + Extend the base trl.PRMTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "prm"] diff --git a/src/axolotl/core/trainers/utils.py 
b/src/axolotl/core/trainers/utils.py new file mode 100644 index 000000000..c6d40cb61 --- /dev/null +++ b/src/axolotl/core/trainers/utils.py @@ -0,0 +1,33 @@ +"""Utils for Axolotl trainers""" + + +def sanitize_kwargs_for_tagging(tag_names, kwargs=None): + if isinstance(tag_names, str): + tag_names = [tag_names] + + if kwargs is not None: + if "tags" not in kwargs: + kwargs["tags"] = tag_names + elif "tags" in kwargs and isinstance(kwargs["tags"], list): + kwargs["tags"].extend(tag_names) + elif "tags" in kwargs and isinstance(kwargs["tags"], str): + tag_names.append(kwargs["tags"]) + kwargs["tags"] = tag_names + + return kwargs + + +def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None): + if isinstance(dataset_tags, str): + dataset_tags = [dataset_tags] + + if (dataset_tags is not None) and (kwargs is not None): + if "dataset_tags" not in kwargs: + kwargs["dataset_tags"] = dataset_tags + elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list): + kwargs["dataset_tags"].extend(dataset_tags) + elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str): + dataset_tags.append(kwargs["dataset_tags"]) + kwargs["dataset_tags"] = dataset_tags + + return kwargs diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 34a79e646..82a62c049 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -207,14 +207,19 @@ class AxolotlTrainingMixins: }, ) + sequence_parallel_degree: Optional[int] = field( + default=1, + metadata={"help": "The number of workers to use in sequence parallelism"}, + ) + @dataclass class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): """ Training arguments for Causal trainer - This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value - so it can't be used as a mixin. 
+ This code is duplicated due to HF TrainingArguments not setting output_dir with a + default value so it can't be used as a mixin. """ diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py new file mode 100644 index 000000000..95c44a820 --- /dev/null +++ b/src/axolotl/monkeypatch/attention/ring_attn.py @@ -0,0 +1,89 @@ +""" +Ring attention group registration and flash attention patching. + +Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention) +package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in +their sequence parallel version of Flash Attention 2. +""" + +import torch.distributed as dist +from accelerate.logging import get_logger + +from axolotl.logging_config import configure_logging + +configure_logging() +LOG = get_logger(__name__) + +RING_ATTN_GROUP = None + + +def get_ring_attn_group() -> dist.ProcessGroup: + """ + Getter for ring attention group on this rank. + + Returns: + The process group for ring attention for this rank. + """ + return RING_ATTN_GROUP + + +def set_ring_attn_group(ring_attn_group: dist.ProcessGroup): + """ + Setter for ring attention group on this rank. + + Args: + ring_attn_group: Process group for ring attention. + """ + global RING_ATTN_GROUP # pylint: disable=global-statement + RING_ATTN_GROUP = ring_attn_group + + +def register_ring_attn(sequence_parallel_degree: int): + """ + Create ring attention group and substitute flash attn with ring flash attn. + + Args: + sequence_parallel_degree: Sequence parallelism factor.
+ """ + LOG.info( + "Enabling ring attention sequence parallelism: " + f"each sequence will be processed across {sequence_parallel_degree} GPUs" + ) + + world_size = dist.get_world_size() + assert sequence_parallel_degree <= world_size, ( + f"sequence_parallel_degree ({sequence_parallel_degree}) " + f"must be less than or equal to world_size ({world_size})" + ) + assert world_size % sequence_parallel_degree == 0, ( + f"sequence_parallel_degree ({sequence_parallel_degree}) " + f"must evenly divide world_size ({world_size})" + ) + + # Detailed logging of group formation + rank = dist.get_rank() + group_assignments = {} + + for i in range(world_size // sequence_parallel_degree): + ring_attn_ranks = list( + range( + i * sequence_parallel_degree, + (i + 1) * sequence_parallel_degree, + ) + ) + group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") + + # Track which GPUs are in which groups + for r in ring_attn_ranks: + group_assignments[r] = i + + if rank in ring_attn_ranks: + set_ring_attn_group(group) + + # Log the GPU group assignments + if rank == 0: + LOG.info(f"Sequence parallel group assignments: {group_assignments}") + + from ring_flash_attn import substitute_hf_flash_attn + + substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index ff486db29..9ccd2ca0c 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -169,7 +169,7 @@ def execute_training( cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None ): """ - Execute the training process with appropriate backend configurations. + Execute the training process with appropriate SDP kernel configurations. Args: cfg: Dictionary mapping `axolotl` config keys to values. @@ -177,9 +177,6 @@ def execute_training( resume_from_checkpoint: Path to checkpoint to resume from, if applicable. """ LOG.info("Starting trainer...") - if cfg.group_by_length: - LOG.info("hang tight... 
sorting dataset for group_by_length") - if cfg.flash_optimum: with torch.backends.cuda.sdp_kernel( # TODO configure these from the YAML w/ sdp_kernel_kwargs: ... diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 7cf771421..12c8b31d5 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -1,14 +1,59 @@ """ -DataCollator for axolotl to pad labels and position_ids for packed sequences +Data collators for axolotl to pad labels and position_ids for packed sequences. Also +includes logic for handling sequence parallelism collation. """ +import logging from dataclasses import dataclass from typing import Any, Optional, Union import numpy as np +import torch +import torch.distributed as dist from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy +logger = logging.getLogger(__name__) + + +def adjust_position_ids_for_slice( + position_ids: torch.Tensor, start_idx: int +) -> torch.Tensor: + """ + Adjust position IDs for a sliced sequence to maintain proper relative positions. + This handles the case where position IDs might not be contiguous due to sample + packing. 
+ """ + # Convert to tensor if not already + # Find the boundaries between samples (where position_ids reset) + adjusted_pos_ids = position_ids.clone() + + # Process each sequence in the batch + for i in range(position_ids.shape[0]): + seq = position_ids[i] + + # Find sample boundaries + boundaries = [] + for j in range(1, len(seq)): + if seq[j] < seq[j - 1]: + boundaries.append(j) + + # No need to adjust if there are no boundaries or this is a single sample + if not boundaries: + adjusted_pos_ids[i] = seq - start_idx + continue + + # Adjust each segment separately + prev_boundary = 0 + for boundary in boundaries: + adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx + prev_boundary = boundary + + # Last segment + adjusted_pos_ids[i, prev_boundary:] -= start_idx + + return adjusted_pos_ids + @dataclass class DataCollatorForSeq2Seq: @@ -43,6 +88,8 @@ class DataCollatorForSeq2Seq: The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). return_tensors (`str`): The type of Tensor to return. Allowable values are "np", "pt" and "tf". + sequence_parallel_degree (`int`): + The degree of sequence parallelism. Default to 1 for no sequence parallelism. 
""" tokenizer: PreTrainedTokenizerBase @@ -53,6 +100,16 @@ class DataCollatorForSeq2Seq: label_pad_token_id: int = -100 position_pad_token_id: int = 0 return_tensors: str = "pt" + sequence_parallel_degree: int = 1 + + def __post_init__(self): + if self.sequence_parallel_degree > 1: + from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group + + # Get information about our position in the SP group + sp_group = get_ring_attn_group() + self.local_rank = dist.get_rank(group=sp_group) + self.local_world_size = dist.get_world_size(group=sp_group) def __call__(self, features, return_tensors=None): labels = None @@ -119,8 +176,43 @@ class DataCollatorForSeq2Seq: ) features["decoder_input_ids"] = decoder_input_ids + if self.sequence_parallel_degree > 1: + features = self.apply_sequence_parallelism(features) + return features + def apply_sequence_parallelism( + self, batch: dict[str, torch.Tensor] + ) -> torch.Tensor: + """ + Apply sequence parallelism slicing to a batch. + + Args: + batch: Batch dictionary from parent collator. + + Returns: + Sliced batch dictionary. 
+ """ + keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"] + + for key in keys_to_slice: + if key in batch: + seq_len = batch[key].shape[1] + slice_size = seq_len // self.local_world_size + start_idx = self.local_rank * slice_size + end_idx = ( + start_idx + slice_size + if self.local_rank < self.local_world_size - 1 + else seq_len + ) + batch[key] = batch[key][:, start_idx:end_idx] + + # Special handling for position_ids + if key == "position_ids" and self.local_rank > 0: + batch[key] = adjust_position_ids_for_slice(batch[key], start_idx) + + return batch + @dataclass class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): @@ -148,6 +240,7 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): np.array(item[feature]) for item in features_ if feature in item ] out_features[i][feature] = np.concatenate(arrays) + return super().__call__(out_features, return_tensors=return_tensors) @@ -177,6 +270,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): np.array(item[feature]) for item in features_ if feature in item ] out_features[i][feature] = np.concatenate(arrays) + return super().__call__(out_features, return_tensors=return_tensors) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 136acc4a0..4e956140d 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -125,6 +125,9 @@ def normalize_config(cfg): with open(ds_config_path, encoding="utf-8") as f: cfg.deepspeed = json.load(f) + if cfg.sequence_parallel_degree is None: + cfg.sequence_parallel_degree = 1 + if cfg.saves_per_epoch: save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs) if save_steps < 1.0: # prevent saves on every step diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 44f570b88..83f70a022 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -67,7 +67,12 @@ from axolotl.utils.gradient_checkpointing 
import hf_grad_checkpoint_offload_wrap from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant -LOG = logging.getLogger("axolotl") +LOG = logging.getLogger(__name__) + +MULTIMODEL_AUTO_MODEL_MAPPING = { + "llava": LlavaForConditionalGeneration, + "mllama": MllamaForConditionalGeneration, +} # copied from accelerator.FullyShardedDataParallelPlugin @@ -476,7 +481,7 @@ class ModelLoader: else: self.text_model_config = self.model_config - self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name + self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name def apply_patches(self) -> None: # load any patches from plugins @@ -547,6 +552,14 @@ class ModelLoader: patch_self_attn_lora(self.cfg) + if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1: + from axolotl.monkeypatch.attention.ring_attn import register_ring_attn + + # Initialize ring attn for sequence parallelism. This must be done after + # model init but before the first forward pass, since it modifies flash + # attn to use ring comm for SP training across multiple GPUs. + register_ring_attn(self.cfg.sequence_parallel_degree) + def patch_attention(self) -> None: if hasattr(self.model_config, "model_type"): if self.model_config.model_type == "mllama" and self.cfg.flash_attention: @@ -603,7 +616,7 @@ class ModelLoader: patch_self_attn_lora() - def patch_llama_derived_model(self) -> None: + def patch_llama_derived_model(self): """Modify all llama derived models in one block""" self.patch_loss_llama() @@ -653,25 +666,16 @@ class ModelLoader: "Shifted-sparse attention not currently implemented without flash attention." 
) - def set_auto_model_loader(self) -> None: - """set self.AutoModelLoader - - default value: AutoModelForCausalLM (set at __init__) - - when using a multi modality model, self.AutoModelLoader should - be set according to model type of the model + def set_auto_model_loader(self): + """ + Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM` + (set at `__init__`). When using a multimodal model, `self.auto_model_loader` + should be set according to the type of the model. """ if self.cfg.is_multimodal: - if self.model_config.model_type == "llava": - self.AutoModelLoader = ( # pylint: disable=invalid-name - LlavaForConditionalGeneration - ) - elif self.model_config.model_type == "mllama": - self.AutoModelLoader = ( # pylint: disable=invalid-name - MllamaForConditionalGeneration - ) - else: - self.AutoModelLoader = ( - AutoModelForVision2Seq # pylint: disable=invalid-name - ) + self.auto_model_loader = MULTIMODEL_AUTO_MODEL_MAPPING.get( + self.model_config.model_type, AutoModelForVision2Seq + ) def set_device_map_config(self) -> None: device_map = self.cfg.device_map @@ -695,7 +699,7 @@ class ModelLoader: from accelerate import infer_auto_device_map with init_empty_weights(): - model_canvas = self.AutoModelLoader.from_config( + model_canvas = self.auto_model_loader.from_config( self.model_config, trust_remote_code=self.cfg.trust_remote_code or False, ) @@ -916,11 +920,27 @@ class ModelLoader: if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config - self.model = self.AutoModelLoader.from_pretrained( - self.base_model, - config=self.model_config, - **self.model_kwargs, - ) + + # Load model with random initialization if specified + if self.cfg.random_init_weights: + # AutoModel classes support the from_config method + if self.auto_model_loader in [ + AutoModelForCausalLM, + AutoModelForVision2Seq, + ]: + self.model = self.auto_model_loader.from_config( + config=self.model_config, + ) + else: + self.model = 
self.auto_model_loader( + config=self.model_config, + ) + else: + self.model = self.auto_model_loader.from_pretrained( + self.base_model, + config=self.model_config, + **self.model_kwargs, + ) # TODO (MengqingCao) split these patches seperately if self.cfg.flash_attention and not self.inference: @@ -958,7 +978,7 @@ class ModelLoader: if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config if self.cfg.gptq: - self.model = self.AutoModelLoader.from_pretrained( + self.model = self.auto_model_loader.from_pretrained( self.base_model, config=self.model_config, trust_remote_code=self.cfg.trust_remote_code or False, @@ -991,7 +1011,7 @@ class ModelLoader: if self.cfg.gptq: if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config - self.model = self.AutoModelLoader.from_pretrained( + self.model = self.auto_model_loader.from_pretrained( self.base_model, config=self.model_config, trust_remote_code=self.cfg.trust_remote_code or False, @@ -1011,7 +1031,7 @@ class ModelLoader: if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config - self.model = self.AutoModelLoader.from_pretrained( + self.model = self.auto_model_loader.from_pretrained( self.base_model, config=self.model_config, trust_remote_code=self.cfg.trust_remote_code or False, @@ -1307,7 +1327,7 @@ def load_model( """ Load a model for a given configuration and tokenizer. 
""" - loader = ModelLoader( + model_loader = ModelLoader( cfg, tokenizer, processor=processor, @@ -1315,7 +1335,7 @@ def load_model( reference_model=reference_model, **kwargs, ) - return loader.load_model() + return model_loader.load_model() def load_adapter(model, cfg, adapter, inference=False): diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 6119dff30..41095152e 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -104,9 +104,7 @@ def allocate( class MultipackBatchSampler(BatchSampler): - """ - Batch Sampler class for multipack - """ + """Batch sampler class for multipack""" def __init__( self, diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 7676a50a8..7992e6559 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1,4 +1,4 @@ -"""Main Axolotl input configuration Pydantic models""" +"""Module with Pydantic models for configuration.""" # pylint: disable=too-many-lines @@ -245,6 +245,8 @@ class AxolotlInputConfig( val_set_size: float | None = Field(default=0.0) + sequence_parallel_degree: int | None = None + special_tokens: SpecialTokensConfig | None = None tokens: list[str] | None = None added_tokens_overrides: dict[int, str] | None = None @@ -1102,6 +1104,29 @@ class AxolotlInputConfig( return data + @field_validator("sequence_parallel_degree", mode="before") + @classmethod + def check_sequence_parallel_config(cls, value, info): + if not value: + value = 1 + + if value > 1: + if not info.data.get("flash_attention"): + raise ValueError( + "flash_attention: true must be set with sequence_parallel_degree > 1" + ) + + try: + import ring_flash_attn # noqa: F401 # pylint:disable=unused-import + except ImportError as exception: + raise ImportError( + "sequence_parallel_degree > 1 but ring_flash_attn is not installed. 
" + "Please install it with `pip install axolotl[ring-flash-attn] " + "or `pip install ring-flash-attn>=0.1.4`." + ) from exception + + return value + class AxolotlConfigWCapabilities(AxolotlInputConfig): """wrapper to valdiate gpu capabilities with the configured options""" diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 090e677a6..d2b211bbc 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): load_from_cache_file=not cfg.is_preprocess, desc="Add position_id column (PoSE)", ) - elif cfg.sample_packing: + elif cfg.sample_packing or cfg.sequence_parallel_degree > 1: drop_long_kwargs = {} if filter_map_kwargs: drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)" @@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): **filter_map_kwargs, **drop_long_kwargs, ) - if cfg.eval_sample_packing is not False: + if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1: if eval_dataset: eval_dataset = eval_dataset.map( add_position_ids, @@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): - 1 ) * cfg.num_epochs + * cfg.sequence_parallel_degree ) LOG.debug( f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}", @@ -473,7 +474,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True) # FIXME: is there a bug here somewhere? 
the total num steps depends # on the agreed on value for sample_packing_eff_est - total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs)) + total_num_steps = int( + math.floor( + data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree + ) + ) def calc_sample_packing_eff_est(estimates: List[float]): LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}") @@ -494,7 +499,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): ) else: total_num_steps = int( - math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + math.ceil( + len(train_dataset) + * cfg.num_epochs + * cfg.sequence_parallel_degree + / cfg.batch_size + ) ) LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True) return total_num_steps diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py new file mode 100644 index 000000000..a20ad9ff2 --- /dev/null +++ b/tests/e2e/patched/test_sp.py @@ -0,0 +1,207 @@ +"""Tests for sequence parallelism functionality.""" + +# pylint: disable=redefined-outer-name,unused-argument + +from unittest.mock import MagicMock, patch + +import pytest +import torch +from accelerate.state import PartialState + +from axolotl.utils.dict import DictDefault + +# Use a single patch for ring_flash_attn if it's not available +ring_flash_attn_mock = MagicMock() +with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}): + from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group + from axolotl.utils.collators.batching import adjust_position_ids_for_slice + + +@pytest.fixture +def partial_state(): + """Create a real PartialState instance for testing.""" + state = PartialState() + return state + + +@pytest.fixture(name="cfg") +def fixture_cfg(): + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + 
"learning_rate": 1e-3, + "output_dir": "./model-out", + "sequence_len": 512, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + } + ) + + return cfg + + +class TestSequenceParallelHelpers: + """Test helper functions used in sequence parallelism.""" + + def test_adjust_position_ids_for_slice(self, partial_state): + """Test position_ids adjustment for sequence slices.""" + # Create sample position_ids with multiple sequences + position_ids = torch.tensor( + [ + # First sequence with 2 samples + [0, 1, 2, 3, 4, 0, 1, 2, 3], + # Second sequence with 3 samples + [0, 1, 2, 0, 1, 2, 3, 0, 1], + ] + ) + + # Adjust as if this was the second slice (start_idx = 4) + adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4) + + # For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1] + # For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3] + expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4 + expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4 + + assert torch.all(adjusted[0] == expected_first_seq) + assert torch.all(adjusted[1] == expected_second_seq) + + +class TestRingAttention: + """Tests for the ring attention functionality.""" + + @patch("torch.distributed.new_group") + @patch("torch.distributed.get_rank") + @patch("torch.distributed.get_world_size") + def test_register_ring_attn( + self, mock_world_size, mock_rank, mock_new_group, partial_state + ): + """Test that ring attention groups are created correctly.""" + from axolotl.monkeypatch.attention.ring_attn import register_ring_attn + + # Setup mocks + mock_world_size.return_value = 8 # 8 GPUs total + mock_rank.return_value = 3 # GPU #3 + mock_group = MagicMock() + mock_new_group.return_value = mock_group + + # Call register_ring_attn with size 4 + register_ring_attn(sequence_parallel_degree=4) + + # Verify the number of calls without examining the arguments + assert mock_new_group.call_count == 2 + + # Just verify that new_group 
was called + mock_new_group.assert_called() + + @patch("torch.distributed.get_rank") + @patch("torch.distributed.get_world_size") + def test_get_ring_attn_group_no_registration( + self, mock_world_size, mock_rank, partial_state + ): + """Test that get_ring_attn_group returns None when no group has been registered.""" + # Setup mocks + mock_world_size.return_value = 4 + mock_rank.return_value = 0 + + # Get the group without registration + group = get_ring_attn_group() + + # Verify that None was returned + assert group is None + + +# Mock a simplified DataCollator test +@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group") +@patch("torch.distributed.get_rank") +@patch("torch.distributed.get_world_size") +def test_sequence_parallel_slicing( + mock_world_size, mock_rank, mock_get_group, partial_state +): + """Test the basic sequence slicing logic without full collator instantiation.""" + # Setup mocks + mock_get_group.return_value = MagicMock() + mock_rank.return_value = 1 # Second GPU + mock_world_size.return_value = 4 # 4 GPUs total + + # Create a sample batch + batch = { + "input_ids": torch.tensor( + [ + [101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112], + [201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212], + ] + ), + "attention_mask": torch.ones(2, 12), + } + + # Simplified slicing logic from SequenceParallelDataCollator + def slice_batch(batch, rank, world_size): + result = {} + for key in batch: + seq_len = batch[key].shape[1] + slice_size = seq_len // world_size + start_idx = rank * slice_size + end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len + result[key] = batch[key][:, start_idx:end_idx] + return result + + # Slice the batch + result = slice_batch( + batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value + ) + + # Check slicing + assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU + expected_input_ids = torch.tensor( + [ + [104, 105, 106], # Second 
slice of first sequence + [204, 205, 206], # Second slice of second sequence + ] + ) + assert torch.all(result["input_ids"] == expected_input_ids) + + +@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()}) +def test_config_validation_with_valid_inputs(cfg): + """Test that valid sequence parallelism configurations pass validation.""" + # Import the actual model class with appropriate mocks + from axolotl.utils.schemas.config import AxolotlInputConfig + + # Valid configuration: sequence_parallel_degree > 1 and flash_attention is True + cfg = cfg | { + "sequence_parallel_degree": 2, + "flash_attention": True, + } + + # Should validate without errors + config = AxolotlInputConfig(**cfg) + assert config.sequence_parallel_degree == 2 + assert config.flash_attention is True + + +def test_config_validation_with_invalid_inputs(cfg): + """Test that invalid sequence parallelism configurations fail validation.""" + from axolotl.utils.schemas.config import AxolotlInputConfig + + # Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False + cfg = cfg | { + "sequence_parallel_degree": 2, + "flash_attention": False, + } + + # Should raise ValidationError + with pytest.raises(ValueError) as excinfo: + AxolotlInputConfig(**cfg) + + # Verify error message + assert "flash_attention: true must be set" in str(excinfo.value) diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py index 3fc315b2e..d32eb3953 100644 --- a/tests/test_exact_deduplication.py +++ b/tests/test_exact_deduplication.py @@ -12,6 +12,7 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS from datasets import Dataset from transformers import AutoTokenizer +from axolotl.utils.config import normalize_config from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets @@ -262,6 +263,7 @@ class 
TestDeduplicateNonRL(unittest.TestCase): self.tokenizer.add_special_tokens(SPECIAL_TOKENS) self.cfg_1 = DictDefault( { + "base_model": "huggyllama/llama-7b", "tokenizer_config": "huggyllama/llama-7b", "sequence_len": 1024, "dataset_exact_deduplication": True, @@ -282,6 +284,7 @@ class TestDeduplicateNonRL(unittest.TestCase): "num_epochs": 1, } ) + normalize_config(self.cfg_1) def test_prepare_dataset_with_deduplication_train(self): """Verify that prepare_dataset function processes the dataset correctly with deduplication."""