Compare commits: no-zero-ds...sequence-p
44 commits
| Author | SHA1 | Date |
|---|---|---|
| | 4ac65462f0 | |
| | ce35b2a95f | |
| | ab3b36339a | |
| | 22cfa42961 | |
| | 0b2c2ed68c | |
| | 2f0b4626b9 | |
| | a26985c53c | |
| | c1a58339e8 | |
| | 411df76a97 | |
| | a09d1ccbf2 | |
| | 2727d86544 | |
| | 64c203cdef | |
| | 7d7042f602 | |
| | d187f1f8e2 | |
| | 1cced52719 | |
| | 11321b17e7 | |
| | 7a1a211c99 | |
| | e1a02a32b5 | |
| | a6ef6c7764 | |
| | cb3a9e99a3 | |
| | 3ae47ec7de | |
| | e36dc763ab | |
| | 03027cf6bf | |
| | 0ade60d455 | |
| | 02e1a42f04 | |
| | 919b88f11b | |
| | 345a9dd831 | |
| | 4ff97bc9d4 | |
| | d0e178d52f | |
| | 5731cdc0cf | |
| | b7738d57c4 | |
| | 698e599bf7 | |
| | 1d339e4007 | |
| | 4190ad0647 | |
| | b44a207248 | |
| | 51c326150b | |
| | 14baaf6e0a | |
| | f487910444 | |
| | c5071dfd8a | |
| | e323145ba9 | |
| | 7efc787ac8 | |
| | dce61cdab1 | |
| | bd952de9d2 | |
| | 3f8a43cab6 | |
.github/workflows/tests.yml (vendored, 6 changes)

@@ -98,8 +98,9 @@ jobs:
- name: Run tests
  run: |
    pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
    pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
    pytest -v tests/patched/
    pytest -v tests/cli/

- name: cleanup pip cache
  run: |

@@ -172,8 +173,9 @@ jobs:

- name: Run tests
  run: |
    pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
    pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
    pytest -v tests/patched/
    pytest -v tests/cli/

- name: cleanup pip cache
  run: |
@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \

RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
    pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
    pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
    else \
    pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
    pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
    fi

RUN python scripts/unsloth_install.py | sh
@@ -3,9 +3,10 @@ set -e

python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"

pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest -v --durations=10 /workspace/axolotl/tests/cli
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
@@ -32,6 +32,9 @@ tokenizer_legacy:

resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes.
random_init_weights:

# (Internal use only)
# Used to identify which the model is based on
@@ -617,6 +620,14 @@ ddp_timeout:

ddp_bucket_cap_mb:
ddp_broadcast_buffers:

# Sequence parallelism
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:

# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:
docs/sequence_parallelism.qmd (new file, 90 lines)

@@ -0,0 +1,90 @@
---
title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---

# Sequence Parallelism

Sequence parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.

## When to Use Sequence Parallelism

Use sequence parallelism when:

- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're experiencing OOM (Out Of Memory) errors with long sequences

## Configuration

To enable sequence parallelism, add the following to your configuration file:

```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
```

The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:

- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
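As an editorial aside (not part of this PR), the divisor constraint from the examples above can be sketched in a few lines of Python; the helper name and its arguments are hypothetical and only mirror the config options:

```python
# Hypothetical sketch of the constraint described above: the sequence parallel
# degree must evenly divide the number of GPUs, and each GPU then holds
# sequence_len // sequence_parallel_degree tokens of every sequence.
def check_sp_config(num_gpus: int, sequence_parallel_degree: int, sequence_len: int) -> int:
    if sequence_parallel_degree <= 1:
        raise ValueError("sequence_parallel_degree must be > 1 to enable sequence parallelism")
    if num_gpus % sequence_parallel_degree != 0:
        raise ValueError("sequence_parallel_degree must divide the number of GPUs")
    if sequence_len % sequence_parallel_degree != 0:
        raise ValueError("sequence_len should split into equal-sized chunks")
    return sequence_len // sequence_parallel_degree  # tokens held per GPU

print(check_sp_config(num_gpus=8, sequence_parallel_degree=4, sequence_len=8192))  # 2048
```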
## Implementation Details

When sequence parallelism is enabled:

1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
4. The trainer uses special ring communication patterns for attention operations
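As a rough editorial illustration of steps 1 through 3 (not code from this PR; the helper and variable names below are hypothetical), slicing a collated batch for one rank in a sequence parallel group might look like this, with `position_ids` keeping their global values so relative positions survive the split:

```python
import torch

# Hypothetical sketch: each rank keeps a contiguous chunk of the sequence
# dimension; position_ids are carried through unchanged so relative positions
# are preserved, which matters for packed sequences.
def slice_batch_for_rank(batch: dict, rank: int, sp_degree: int) -> dict:
    seq_len = batch["input_ids"].size(1)
    chunk = seq_len // sp_degree
    start, end = rank * chunk, (rank + 1) * chunk
    return {key: tensor[:, start:end] for key, tensor in batch.items()}

batch = {
    "input_ids": torch.arange(16).view(1, 16),
    "position_ids": torch.arange(16).view(1, 16),
}
print(slice_batch_for_rank(batch, rank=1, sp_degree=4)["position_ids"])  # tensor([[4, 5, 6, 7]])
```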
## Requirements

To use sequence parallelism, you need:

- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with:
  - `pip install axolotl[ring-flash-attn]` (preferred)
  - `pip install ring-flash-attn>=0.1.4`

## Limitations

- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- May have a small performance overhead due to communication between GPUs

## Example

```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2 # Split each sequence into 2 parts
flash_attention: true # Required with sequence parallelism
...
```

This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.
## Sample Packing with Sequence Parallelism

Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:

1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions

## Effect on Batch Size

When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:

- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases

For example:

- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
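As an editorial aside (not part of this diff), the batch-size arithmetic above can be written out as a small Python sketch; the function below is hypothetical and only mirrors the config option names:

```python
# Editorial sketch of the effective global batch size under sequence parallelism.
def effective_global_batch_size(num_gpus: int, micro_batch_size: int,
                                sequence_parallel_degree: int = 1,
                                gradient_accumulation_steps: int = 1) -> int:
    # GPUs in the same sequence parallel group share one batch, so the number
    # of distinct batches per step shrinks by sequence_parallel_degree.
    data_parallel_ranks = num_gpus // sequence_parallel_degree
    return data_parallel_ranks * micro_batch_size * gradient_accumulation_steps

print(effective_global_batch_size(8, 2))                              # 16
print(effective_global_batch_size(8, 2, sequence_parallel_degree=4))  # 4
```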
@@ -4,7 +4,6 @@

bitsandbytes==0.45.3
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.4.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.3

@@ -36,6 +35,7 @@ einops

colorama
numba
numpy>=1.24.4,<=2.0.1

# qlora things
evaluate==0.4.1
scipy
setup.py (12 changes)
@@ -17,11 +17,7 @@ def parse_requirements():

        lines = [r.strip() for r in requirements_file.readlines()]
        for line in lines:
            is_extras = (
                "flash-attn" in line
                or "flash-attention" in line
                or "deepspeed" in line
                or "mamba-ssm" in line
                or "lion-pytorch" in line
                "deepspeed" in line or "mamba-ssm" in line or "lion-pytorch" in line
            )
            if line.startswith("--extra-index-url"):
                # Handle custom index URLs

@@ -39,7 +35,6 @@ def parse_requirements():

            "bitsandbytes",
            "triton",
            "mamba-ssm",
            "flash-attn",
            "xformers",
            "autoawq",
            "liger-kernel",

@@ -124,9 +119,8 @@ setup(

        ],
    },
    extras_require={
        "flash-attn": [
            "flash-attn==2.7.4.post1",
        ],
        "flash-attn": ["flash-attn==2.7.4.post1"],
        "ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
        "deepspeed": [
            "deepspeed==0.16.4",
            "deepspeed-kernels",
@@ -23,7 +23,7 @@ from axolotl.utils.dict import DictDefault

LOG = logging.getLogger(__name__)


def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
    """
    Trains a `transformers` model by first loading the dataset(s) specified in the
    `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin

@@ -44,16 +44,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:

    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
    del model, tokenizer, trainer

    plugin_manager = PluginManager.get_instance()

    del model
    del tokenizer
    del trainer

    plugin_manager.post_train_unload(cfg)


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    """
    Parses `axolotl` config, CLI args, and calls `do_train`.
@@ -36,7 +36,7 @@ from transformers import (

from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding

from axolotl.core.trainers.base import (
from axolotl.core.trainers import (
    AxolotlCPOTrainer,
    AxolotlKTOTrainer,
    AxolotlMambaTrainer,

@@ -762,6 +762,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

                self.cfg.kd_top_k_before_softmax
            )

        training_arguments_kwargs["sequence_parallel_degree"] = (
            self.cfg.sequence_parallel_degree
        )

        if self.cfg.reward_model:
            training_args_cls = AxolotlRewardConfig
        elif self.cfg.process_reward_model:

@@ -845,9 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

        self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
    ):
        if training_args.pretraining:
            if self.cfg.pretraining_sample_concatenation is False:
                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
            if self.cfg.micro_batch_size > 1:
            if (
                self.cfg.pretraining_sample_concatenation is False
                or self.cfg.micro_batch_size > 1
            ):
                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
            return None

@@ -875,9 +880,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

        if "max_length" in kwargs:
            kwargs.pop("max_length")
        elif use_batch_sampler_collator:
            if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                collator = V2BatchSamplerDataCollatorForSeq2Seq
            elif (
            if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
                self.cfg.model_config_type in ["llama"]
                and self.cfg.flash_attention is not True
            ):

@@ -908,6 +911,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

            collator = DataCollatorForSeq2Seq

        kwargs["return_tensors"] = "pt"
        if issubclass(collator, DataCollatorForSeq2Seq):
            kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree

        return collator(
            *collator_args,
@@ -0,0 +1,18 @@

"""Init for axolotl.core.trainers"""

# pylint: disable=unused-import
# flake8: noqa

from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
    AxolotlCPOTrainer,
    AxolotlKTOTrainer,
    AxolotlORPOTrainer,
    AxolotlPRMTrainer,
    AxolotlRewardTrainer,
    TRLPPOTrainer,
)
@@ -1,365 +1,47 @@
|
||||
"""
|
||||
module for customized trainers
|
||||
"""
|
||||
"""Module for customized trainers"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
from typing import Dict, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
DataLoader,
|
||||
RandomSampler,
|
||||
Sampler,
|
||||
SequentialSampler,
|
||||
)
|
||||
from transformers import Trainer
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.integrations.base import BaseOptimizerFactory
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.schedulers import (
|
||||
RexLR,
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
get_cosine_schedule_with_warmup_decay_constant,
|
||||
from axolotl.core.trainers.mixins import (
|
||||
OptimizerMixin,
|
||||
SchedulerMixin,
|
||||
SequenceParallelMixin,
|
||||
)
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
)
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||
if isinstance(tag_names, str):
|
||||
tag_names = [tag_names]
|
||||
|
||||
if kwargs is not None:
|
||||
if "tags" not in kwargs:
|
||||
kwargs["tags"] = tag_names
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||
kwargs["tags"].extend(tag_names)
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||
tag_names.append(kwargs["tags"])
|
||||
kwargs["tags"] = tag_names
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||
if isinstance(dataset_tags, str):
|
||||
dataset_tags = [dataset_tags]
|
||||
|
||||
if (dataset_tags is not None) and (kwargs is not None):
|
||||
if "dataset_tags" not in kwargs:
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||
kwargs["dataset_tags"].extend(dataset_tags)
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||
dataset_tags.append(kwargs["dataset_tags"])
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
class SchedulerMixin(Trainer):
|
||||
"""
|
||||
Mixin class for scheduler setup in CausalTrainer.
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
use_cosine_quadratic = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
)
|
||||
|
||||
use_cosine_min_lr = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.cosine_min_lr_ratio is not None
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
extra_lr_kwargs = {}
|
||||
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["pct_start"] = pct_start
|
||||
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
**extra_lr_kwargs,
|
||||
**self.args.lr_scheduler_kwargs,
|
||||
)
|
||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||
if use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
|
||||
self.lr_scheduler = RexLR(
|
||||
optimizer=optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||
total_steps=num_training_steps,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
)
|
||||
elif use_cosine_quadratic:
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class OptimizerMixin(Trainer):
|
||||
"""
|
||||
Mixin class for shared handling of building custom optimizers
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_optimizer_grouped_parameters(
|
||||
self, opt_model, optimizer_kwargs
|
||||
) -> list[dict]:
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
params: dict = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
}
|
||||
lr_groups_lookup = {}
|
||||
lr_groups_learning_rates = {}
|
||||
if self.args.lr_groups:
|
||||
for lr_group in self.args.lr_groups:
|
||||
group_name = lr_group["name"]
|
||||
group_modules = lr_group["modules"]
|
||||
for module in group_modules:
|
||||
lr_groups_lookup[module] = group_name
|
||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||
params[f"to_weight_decay_{group_name}"] = {}
|
||||
|
||||
for name, param in opt_model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
if name.endswith("modules_to_save.default.weight") or any(
|
||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||
):
|
||||
params["embeddings"][name] = param
|
||||
elif name in decay_parameters:
|
||||
lr_group_modules = [
|
||||
group_modules
|
||||
for group_modules in lr_groups_lookup
|
||||
if group_modules in name
|
||||
]
|
||||
if lr_groups_lookup and any(lr_group_modules):
|
||||
lr_group_module = lr_group_modules[0]
|
||||
group_name = lr_groups_lookup[lr_group_module]
|
||||
params[f"to_weight_decay_{group_name}"][name] = param
|
||||
else:
|
||||
params["to_weight_decay"][name] = param
|
||||
else:
|
||||
params["no_weight_decay"][name] = param
|
||||
optimizer_grouped_parameters = []
|
||||
if params["to_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["to_weight_decay"].values()),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
if params["embeddings"]:
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
if self.args.embedding_lr_scale:
|
||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||
elif self.args.embedding_lr:
|
||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["embeddings"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": lr,
|
||||
}
|
||||
)
|
||||
if params["no_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["no_weight_decay"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||
if params[f"to_weight_decay_{group_name}"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(
|
||||
params[f"to_weight_decay_{group_name}"].values()
|
||||
),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": group_lr,
|
||||
}
|
||||
)
|
||||
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def create_optimizer(self):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.embedding_lr_scale is None
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.lr_groups is None
|
||||
and self.optimizer_cls_and_kwargs is None
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
|
||||
if (
|
||||
not self.optimizer
|
||||
and self.optimizer_cls_and_kwargs is not None
|
||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||
):
|
||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
self.optimizer = optimizer_factory_cls()(
|
||||
opt_model, self.args, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if not self.optimizer:
|
||||
if self.optimizer_cls_and_kwargs is not None:
|
||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
else:
|
||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||
self.args, opt_model
|
||||
)
|
||||
|
||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||
opt_model, optimizer_kwargs
|
||||
)
|
||||
|
||||
if self.args.loraplus_lr_ratio is not None:
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
loraplus_lr_embedding = getattr(
|
||||
self.args, "loraplus_lr_embedding", 1e-6
|
||||
)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
else:
|
||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for GaLore optimizer.
|
||||
if "params" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||
|
||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for LOMO optimizer.
|
||||
if "model" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||
|
||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||
# to avoid arguments conflicts.
|
||||
if "optimizer_dict" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||
"optimizer_dict"
|
||||
)
|
||||
|
||||
self.optimizer = optimizer_cls(
|
||||
optimizer_grouped_parameters, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if optimizer_cls.__name__ == "Adam8bit":
|
||||
import bitsandbytes
|
||||
|
||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||
|
||||
skipped = 0
|
||||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum(
|
||||
{
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
tag_names = ["axolotl"]
|
||||
@@ -376,12 +58,18 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
self.eval_data_collator = eval_data_collator
|
||||
self.dataset_tags = dataset_tags
|
||||
self._signature_columns = None # workaround for pylint
|
||||
|
||||
super().__init__(*_args, **kwargs)
|
||||
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
# Initialize sequence parallelism if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._setup_sequence_parallel()
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
@@ -394,142 +82,247 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_train_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
train_batch_size = (
|
||||
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||
)
|
||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||
def _create_multipack_sampler(
|
||||
self, base_sampler: Sampler, dataset: Dataset
|
||||
) -> MultipackBatchSampler:
|
||||
"""
|
||||
Helper method to create a `MultipackBatchSampler` for multipacking sequences
|
||||
for training.
|
||||
|
||||
if self.args.curriculum_sampling:
|
||||
sampler = SequentialSampler(self.train_dataset)
|
||||
else:
|
||||
sampler = RandomSampler(self.train_dataset)
|
||||
Args:
|
||||
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
|
||||
dataset: Dataset to sample from.
|
||||
|
||||
return MultipackBatchSampler(
|
||||
sampler,
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_size=batch_size,
|
||||
group_size=self.args.sample_packing_group_size,
|
||||
bin_size=self.args.sample_packing_bin_size,
|
||||
drop_last=True,
|
||||
Returns:
|
||||
Multipack (sample packing) batch sampler.
|
||||
"""
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_train_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
train_batch_size = (
|
||||
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||
)
|
||||
if self.args.curriculum_sampling:
|
||||
return SequentialSampler(self.train_dataset)
|
||||
return super()._get_train_sampler()
|
||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_eval_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
||||
)
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
lengths=get_dataset_lengths(self.eval_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_size=batch_size,
|
||||
group_size=self.args.sample_packing_group_size,
|
||||
bin_size=self.args.sample_packing_bin_size,
|
||||
drop_last=True,
|
||||
return MultipackBatchSampler(
|
||||
base_sampler,
|
||||
lengths=get_dataset_lengths(dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
def _get_train_sampler(self) -> Sampler | None:
|
||||
"""
|
||||
Helper method to get the sampler for training. Handles cases for sequence
|
||||
parallelism, sample packing, and curriculum sampling (sequential).
|
||||
|
||||
Returns:
|
||||
If the dataset is non-empty, a sampler is returned, the type of which
|
||||
depends on the passed training args.
|
||||
"""
|
||||
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||
|
||||
# Determine the base sampler first
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
||||
elif self.args.curriculum_sampling:
|
||||
base_sampler = SequentialSampler(self.train_dataset)
|
||||
elif use_sample_packing:
|
||||
base_sampler = RandomSampler(self.train_dataset)
|
||||
else:
|
||||
# Default to parent class implementation for standard random sampling
|
||||
return super()._get_train_sampler()
|
||||
|
||||
# Apply multipack wrapper if needed
|
||||
if use_sample_packing:
|
||||
return self._create_multipack_sampler(
|
||||
base_sampler=base_sampler,
|
||||
dataset=self.train_dataset,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
train_dataset = self.train_dataset
|
||||
if "length" in train_dataset.features.keys():
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = (
|
||||
self.args.dataloader_prefetch_factor
|
||||
)
|
||||
return base_sampler
|
||||
|
||||
sampler = self._get_train_sampler()
|
||||
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
||||
"""
|
||||
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
||||
and sample packing cases.
|
||||
|
||||
Returns:
|
||||
If the dataset is non-empty, a sampler is returned, the type of which
|
||||
depends on the passed training args.
|
||||
"""
|
||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
|
||||
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
||||
use_multipack = (
|
||||
self.args.sample_packing and self.args.eval_sample_packing is not False
|
||||
)
|
||||
|
||||
# Determine the base sampler
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
||||
elif use_multipack:
|
||||
base_sampler = SequentialSampler(eval_dataset)
|
||||
else:
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
# Apply multipack wrapper if needed
|
||||
if use_multipack:
|
||||
return self._create_multipack_sampler(
|
||||
base_sampler=base_sampler,
|
||||
dataset=eval_dataset,
|
||||
)
|
||||
|
||||
return base_sampler
|
||||
|
||||
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
|
||||
"""Create common dataloader parameters for train or eval."""
|
||||
batch_size = custom_batch_size or (
|
||||
self.args.eval_batch_size if is_eval else self._train_batch_size
|
||||
)
|
||||
|
||||
params = {
|
||||
"batch_size": batch_size,
|
||||
"collate_fn": self.data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
|
||||
# Add persistent workers only for training
|
||||
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
|
||||
params["persistent_workers"] = self.args.dataloader_persistent_workers
|
||||
|
||||
# Add prefetch factor if specified
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
return params
|
||||
|
||||
def _prepare_dataloader(
|
||||
self, dataset, sampler, is_eval=False, custom_batch_size=None
|
||||
):
|
||||
"""Prepare a dataloader with the given dataset and sampler."""
|
||||
# Get base parameters
|
||||
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
|
||||
|
||||
# Add sampler configuration
|
||||
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||
if isinstance(sampler, BatchSampler):
|
||||
# batch_size and batch_sampler are mutually exclusive
|
||||
dataloader_params["batch_sampler"] = sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
if not is_eval:
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
# Create the dataloader
|
||||
dataloader = DataLoader(dataset, **dataloader_params)
|
||||
|
||||
if self.args.sample_packing and (
|
||||
(not is_eval and not self.args.pretraining)
|
||||
or (is_eval and self.args.eval_sample_packing is not False)
|
||||
):
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
# Return unprepared dataloader if using sequence parallelism
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
return self.accelerator.prepare_data_loader(dataloader)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
"""Get dataloader for training"""
|
||||
train_dataset = self.train_dataset
|
||||
data_collator = self.data_collator # type: ignore
|
||||
|
||||
# Handle dataset preprocessing
|
||||
if isinstance(train_dataset, datasets.Dataset):
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
if not self.args.sample_packing or self.args.pretraining:
|
||||
train_dataset = self._remove_unused_columns(
|
||||
train_dataset, description="training"
|
||||
)
|
||||
else:
|
||||
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||
data_collator,
|
||||
description="training",
|
||||
)
|
||||
|
||||
# Get sampler and create dataloader
|
||||
sampler = self._get_train_sampler()
|
||||
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
||||
"""Get dataloader for evaluation"""
|
||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
|
||||
# Handle special case: sample packing is enabled but eval_sample_packing is False
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.eval_data_collator
|
||||
)
|
||||
if eval_dataset:
|
||||
if "length" in eval_dataset.column_names:
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.train_data_collator
|
||||
)
|
||||
|
||||
return dataloader
|
||||
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
# Handle sample packing or sequence parallelism
|
||||
if (
|
||||
self.args.sample_packing
|
||||
and self.args.eval_sample_packing is not False
|
||||
or self.args.sequence_parallel_degree > 1
|
||||
):
|
||||
# Get appropriate data collator
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.eval_data_collator
|
||||
if hasattr(self, "eval_data_collator") and self.eval_data_collator
|
||||
else self.data_collator
|
||||
)
|
||||
if "length" in eval_dataset.column_names:
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
|
||||
# Handle dataset preprocessing for SP
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if isinstance(eval_dataset, datasets.Dataset):
|
||||
eval_dataset = self._remove_unused_columns(
|
||||
eval_dataset, description="evaluation"
|
||||
)
|
||||
else:
|
||||
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||
self.data_collator, description="evaluation"
|
||||
)
|
||||
|
||||
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
|
||||
batch_size = (
|
||||
self.args.eval_batch_size
|
||||
if self.args.sample_packing
|
||||
else self.args.per_device_eval_batch_size
|
||||
)
|
||||
sampler = self._get_eval_sampler(eval_dataset)
|
||||
dataloader = self._prepare_dataloader(
|
||||
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = (
|
||||
self.args.dataloader_prefetch_factor
|
||||
)
|
||||
|
||||
if isinstance(eval_sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = eval_sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = eval_sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(eval_dataset, **dataloader_params)
|
||||
)
|
||||
return dataloader
|
||||
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def _get_bench_sampler(
|
||||
self, bench_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
) -> torch.utils.data.Sampler | None:
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(bench_dataset)
|
||||
return None
|
||||
@@ -554,6 +347,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
@@ -570,6 +364,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
return_outputs=return_outputs,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
return super().compute_loss(
|
||||
model,
|
||||
inputs,
|
||||
@@ -744,10 +539,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@@ -764,15 +559,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
|
||||
return res
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
||||
"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
|
||||
Args:
|
||||
logs (`Dict[str, float]`):
|
||||
The values to log.
|
||||
start_time (`Optional[float]`):
|
||||
The start of training.
|
||||
logs: The values to log.
|
||||
start_time: The start of training.
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
@@ -784,7 +577,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
return super().log(logs, start_time)
|
||||
|
||||
def store_metrics(
|
||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
) -> None:
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
@@ -797,110 +590,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Mamba specific trainer to handle loss calculation
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "mamba"]
|
||||
|
||||
def compute_loss(
|
||||
def training_step(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
num_items_in_batch: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
labels = input_ids.to(lm_logits.device)
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
Args:
|
||||
model: Model to perform training step for.
|
||||
inputs: Dictionary mapping.
|
||||
"""
|
||||
# Set up sequence parallelism for this step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._update_ring_flash_attn_params(inputs)
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
lm_loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||
)
|
||||
# Proceed with normal training step
|
||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||
|
||||
return lm_loss
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "relora"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
anneal_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "prm"]
|
||||
return loss
|
||||
|
||||
@@ -13,10 +13,10 @@ from transformers import Trainer

from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer

from axolotl.core.trainers.base import (
    SchedulerMixin,
    _sanitize_kwargs_for_ds_tagging,
    _sanitize_kwargs_for_tagging,
from axolotl.core.trainers.mixins import SchedulerMixin
from axolotl.core.trainers.utils import (
    sanitize_kwargs_for_ds_tagging,
    sanitize_kwargs_for_tagging,
)

if is_sagemaker_mp_enabled():

@@ -74,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):

        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = _sanitize_kwargs_for_ds_tagging(
        kwargs = sanitize_kwargs_for_ds_tagging(
            dataset_tags=self.dataset_tags, kwargs=kwargs
        )
        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
        kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

        return super().push_to_hub(*args, **kwargs)
src/axolotl/core/trainers/mamba.py (new file, 32 lines)

@@ -0,0 +1,32 @@
"""Module for mamba trainer"""

import torch

from axolotl.core.trainers.base import AxolotlTrainer


class AxolotlMambaTrainer(AxolotlTrainer):
    """Mamba specific trainer to handle loss calculation"""

    tag_names = ["axolotl", "mamba"]

    def compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,  # pylint: disable=unused-argument
        num_items_in_batch=None,  # pylint: disable=unused-argument
    ):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
        )

        return lm_loss
src/axolotl/core/trainers/mixins/__init__.py (new file, 8 lines)

@@ -0,0 +1,8 @@
"""Init for axolotl.core.trainers.mixins"""

# pylint: disable=unused-import
# flake8: noqa

from .optimizer import OptimizerMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin
src/axolotl/core/trainers/mixins/optimizer.py (new file, 201 lines)

@@ -0,0 +1,201 @@
|
||||
"""Module for Axolotl trainer optimizer mixin"""
|
||||
|
||||
import logging
|
||||
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from transformers.trainer import Trainer
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
|
||||
from axolotl.integrations.base import BaseOptimizerFactory
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OptimizerMixin(Trainer):
|
||||
"""Mixin class for shared handling of building custom optimizers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_optimizer_grouped_parameters(
|
||||
self, opt_model, optimizer_kwargs
|
||||
) -> list[dict]:
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
params: dict = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
}
|
||||
lr_groups_lookup = {}
|
||||
lr_groups_learning_rates = {}
|
||||
if self.args.lr_groups:
|
||||
for lr_group in self.args.lr_groups:
|
||||
group_name = lr_group["name"]
|
||||
group_modules = lr_group["modules"]
|
||||
for module in group_modules:
|
||||
lr_groups_lookup[module] = group_name
|
||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||
params[f"to_weight_decay_{group_name}"] = {}
|
||||
|
||||
for name, param in opt_model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
if name.endswith("modules_to_save.default.weight") or any(
|
||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||
):
|
||||
params["embeddings"][name] = param
|
||||
elif name in decay_parameters:
|
||||
lr_group_modules = [
|
||||
group_modules
|
||||
for group_modules in lr_groups_lookup
|
||||
if group_modules in name
|
||||
]
|
||||
if lr_groups_lookup and any(lr_group_modules):
|
||||
lr_group_module = lr_group_modules[0]
|
||||
group_name = lr_groups_lookup[lr_group_module]
|
||||
params[f"to_weight_decay_{group_name}"][name] = param
|
||||
else:
|
||||
params["to_weight_decay"][name] = param
|
||||
else:
|
||||
params["no_weight_decay"][name] = param
|
||||
optimizer_grouped_parameters = []
|
||||
if params["to_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["to_weight_decay"].values()),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
if params["embeddings"]:
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
if self.args.embedding_lr_scale:
|
||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||
elif self.args.embedding_lr:
|
||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["embeddings"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": lr,
|
||||
}
|
||||
)
|
||||
if params["no_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["no_weight_decay"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||
if params[f"to_weight_decay_{group_name}"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(
|
||||
params[f"to_weight_decay_{group_name}"].values()
|
||||
),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": group_lr,
|
||||
}
|
||||
)
|
||||
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def create_optimizer(self):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.embedding_lr_scale is None
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.lr_groups is None
|
||||
and self.optimizer_cls_and_kwargs is None
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
|
||||
if (
|
||||
not self.optimizer
|
||||
and self.optimizer_cls_and_kwargs is not None
|
||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||
):
|
||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
self.optimizer = optimizer_factory_cls()(
|
||||
opt_model, self.args, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if not self.optimizer:
|
||||
if self.optimizer_cls_and_kwargs is not None:
|
||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
else:
|
||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||
self.args, opt_model
|
||||
)
|
||||
|
||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||
opt_model, optimizer_kwargs
|
||||
)
|
||||
|
||||
if self.args.loraplus_lr_ratio is not None:
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
loraplus_lr_embedding = getattr(
|
||||
self.args, "loraplus_lr_embedding", 1e-6
|
||||
)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
else:
|
||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for GaLore optimizer.
|
||||
if "params" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||
|
||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for LOMO optimizer.
|
||||
if "model" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||
|
||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||
# to avoid arguments conflicts.
|
||||
if "optimizer_dict" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||
"optimizer_dict"
|
||||
)
|
||||
|
||||
self.optimizer = optimizer_cls(
|
||||
optimizer_grouped_parameters, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if optimizer_cls.__name__ == "Adam8bit":
|
||||
import bitsandbytes
|
||||
|
||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||
|
||||
skipped = 0
|
||||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum(
|
||||
{
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
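
To make the grouping logic in create_optimizer_grouped_parameters concrete, here is a minimal sketch of the inputs it reads and the parameter groups it would emit. The group names, module substrings, and learning rates below are illustrative assumptions, not values taken from this changeset.

# Illustrative config fragment only (hypothetical values), written as a plain dict
# in the style the tests use for DictDefault configs.
cfg = {
    "learning_rate": 2e-4,
    "weight_decay": 0.01,
    "embedding_lr_scale": 0.5,  # embed_tokens / lm_head then train at 1e-4
    "lr_groups": [
        {"name": "o_proj", "modules": ["o_proj"], "lr": 1e-5},
        {"name": "mlp", "modules": ["gate_proj", "up_proj", "down_proj"], "lr": 5e-5},
    ],
}

# With that config, create_optimizer_grouped_parameters would emit roughly:
#   {"params": [decayed params outside any lr_group], "weight_decay": 0.01, "lr": 2e-4}
#   {"params": [embed_tokens / lm_head params],       "weight_decay": 0.0,  "lr": 1e-4}
#   {"params": [bias / norm params],                  "weight_decay": 0.0,  "lr": 2e-4}
#   {"params": [o_proj weights],                      "weight_decay": 0.01, "lr": 1e-5}
#   {"params": [gate/up/down_proj weights],           "weight_decay": 0.01, "lr": 5e-5}
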
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Module for Axolotl trainer scheduler mixin"""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.utils.schedulers import (
|
||||
RexLR,
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
get_cosine_schedule_with_warmup_decay_constant,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SchedulerMixin(Trainer):
|
||||
"""
|
||||
Mixin class for scheduler setup in CausalTrainer.
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
use_cosine_quadratic = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
)
|
||||
|
||||
use_cosine_min_lr = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.cosine_min_lr_ratio is not None
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
extra_lr_kwargs = {}
|
||||
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["pct_start"] = pct_start
|
||||
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
**extra_lr_kwargs,
|
||||
**self.args.lr_scheduler_kwargs,
|
||||
)
|
||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||
if use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
|
||||
self.lr_scheduler = RexLR(
|
||||
optimizer=optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||
total_steps=num_training_steps,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
)
|
||||
elif use_cosine_quadratic:
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||
|
||||
return self.lr_scheduler
|
||||
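
As a rough guide to which branch of create_scheduler fires, the fragments below pair hypothetical argument values (using the self.args field names read above) with the scheduler they would select.

# Hypothetical values; field names match the attributes read in the mixin above.
{"alternate_lr_scheduler_type": "one_cycle"}
#   -> OneCycleLR(max_lr=args.learning_rate, total_steps=N, pct_start=warmup/N)
{"alternate_lr_scheduler_type": "rex", "lr_scheduler_type": "cosine", "cosine_min_lr_ratio": 0.1}
#   -> RexLR(max_lr=lr, min_lr=0.1 * lr, total_steps=N, num_warmup_steps=warmup)
{"lr_scheduler_type": "cosine", "lr_quadratic_warmup": True}
#   -> get_cosine_schedule_with_quadratic_warmup(...)
{"lr_scheduler_type": "cosine", "cosine_min_lr_ratio": 0.1, "cosine_constant_lr_ratio": 0.8}
#   -> get_cosine_schedule_with_warmup_decay_constant(min_lr_ratio=0.1, constant_lr_ratio=0.8)
{"lr_scheduler_type": "cosine", "cosine_min_lr_ratio": 0.1}
#   -> get_cosine_schedule_with_min_lr(min_lr_ratio=0.1)
# Anything else falls through to transformers' Trainer.create_scheduler.
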
131
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
except ImportError:
|
||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||
pass
|
||||
|
||||
|
||||
class SequenceParallelMixin:
|
||||
"""
|
||||
Mixin class for sequence parallelism support in trainers.
|
||||
|
||||
This mixin provides functionality for handling sequence parallelism,
|
||||
including creating appropriate samplers, managing data partitioning,
|
||||
and updating ring flash attention parameters during training.
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def _setup_sequence_parallel(self):
|
||||
"""Set up sequence parallelism environment."""
|
||||
self.ring_attn_group = get_ring_attn_group()
|
||||
|
||||
def _create_sequence_parallel_sampler(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
shuffle: bool = True,
|
||||
is_eval: bool = False,
|
||||
) -> DistributedSampler:
|
||||
"""
|
||||
Helper method to create sampler for sequence parallelism (SP).
|
||||
|
||||
We create a distributed sampler with rank equal to the SP group ID, which
|
||||
means that all ranks in the SP group receive the same sample / set of samples
|
||||
per training step. We also set the number of replicas equal to the number of
|
||||
SP groups, which is a bit of a hack / unintended use, but works!
|
||||
|
||||
Args:
|
||||
dataset: Dataset to sample from.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||
|
||||
Returns:
|
||||
Distributed sampler.
|
||||
"""
|
||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||
|
||||
return DistributedSampler(
|
||||
dataset,
|
||||
num_replicas=num_sp_groups,
|
||||
rank=sp_group_id,
|
||||
seed=self.args.seed if shuffle else None,
|
||||
shuffle=shuffle,
|
||||
drop_last=not is_eval,
|
||||
)
|
||||
|
||||
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
||||
"""
|
||||
Get a training sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
dataset: The training dataset
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self._create_sequence_parallel_sampler(
|
||||
dataset,
|
||||
shuffle=not self.args.curriculum_sampling,
|
||||
)
|
||||
|
||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||
"""
|
||||
Get an evaluation sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
eval_dataset: The evaluation dataset.
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self._create_sequence_parallel_sampler(
|
||||
eval_dataset, shuffle=False, is_eval=True
|
||||
)
|
||||
|
||||
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
|
||||
"""
|
||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||
the substituted ring_flash_attn. This is accomplished by using the passed
|
||||
`input_ids`.
|
||||
|
||||
Args:
|
||||
inputs: Current batch of inputs.
|
||||
"""
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = inputs["input_ids"].shape[0]
|
||||
seq_len = inputs["input_ids"].shape[1]
|
||||
packed_seq_lens = [seq_len] * batch_size
|
||||
|
||||
# Calculate the full sequence length across all GPUs in this SP group
|
||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||
|
||||
cu_seqlens = torch.cumsum(
|
||||
torch.tensor(
|
||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||
),
|
||||
dim=-1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(
|
||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||
)
|
||||
|
||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||
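
A short worked example of the sampler arithmetic and the cu_seqlens bookkeeping above; the world size, degree, and sequence length are assumed values for illustration.

# Assumed sizes: 8 ranks with sequence_parallel_degree = 4 -> 2 SP groups.
world_size, sp_degree = 8, 4
num_sp_groups = world_size // sp_degree      # 2 groups: ranks 0-3 and ranks 4-7
sp_group_id_for_rank_5 = 5 // sp_degree      # rank 5 lands in group 1
assert (num_sp_groups, sp_group_id_for_rank_5) == (2, 1)
# DistributedSampler(num_replicas=2, rank=1): every rank in group 1 sees the same
# samples on each step, while the two groups split the dataset between them.

# cu_seqlens for _update_ring_flash_attn_params with batch_size = 1 and an
# already-sliced local seq_len of 1024 per GPU:
import torch
import torch.nn.functional as F

cu = torch.cumsum(torch.tensor([1024], dtype=torch.int32), dim=-1, dtype=torch.int32)
cu = F.pad(F.pad(cu, (1, 0), value=0), (0, 1), value=1024 * sp_degree)
assert cu.tolist() == [0, 1024, 4096]
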
43
src/axolotl/core/trainers/relora.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Module for ReLoRA trainer"""
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
|
||||
|
||||
tag_names = ["axolotl", "relora"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: torch.optim.Optimizer | None = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
anneal_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
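
For reference, a sketch of the configuration values the branch above reacts to; the step counts are illustrative assumptions, not defaults from this changeset (other than the fallbacks noted).

# Hypothetical values: with this fragment the base scheduler returned by
# SchedulerMixin is wrapped as ReLoRAScheduler(optimizer, base_scheduler, 150, 50, 25);
# leaving relora_steps unset keeps the base scheduler as-is.
relora_cfg = {
    "relora_steps": 150,        # restart cycle length
    "relora_warmup_steps": 25,  # falls back to 10 when unset
    "relora_anneal_steps": 50,  # falls back to 1 when unset
}
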
@@ -1,16 +1,23 @@
|
||||
"""
|
||||
module for TRL PPO training
|
||||
"""
|
||||
"""Module for TRL PPO trainer"""
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from trl import PPOTrainer
|
||||
from trl import (
|
||||
CPOTrainer,
|
||||
KTOTrainer,
|
||||
ORPOTrainer,
|
||||
PPOTrainer,
|
||||
PRMTrainer,
|
||||
RewardTrainer,
|
||||
)
|
||||
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
class TRLPPOTrainer(PPOTrainer):
|
||||
"""
|
||||
wrapper for ppo trainer to handle customizations
|
||||
"""
|
||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||
|
||||
tag_names = ["axolotl", "ppo"]
|
||||
|
||||
def train(
|
||||
self,
|
||||
@@ -31,9 +38,7 @@ class TRLPPOTrainer(PPOTrainer):
|
||||
"batch_size": 16,
|
||||
}
|
||||
|
||||
for epoch, batch in tqdm( # pylint: disable=unused-variable
|
||||
enumerate(self.dataloader)
|
||||
):
|
||||
for _, batch in tqdm(enumerate(self.dataloader)):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
# generate model response
|
||||
@@ -65,3 +70,43 @@ class TRLPPOTrainer(PPOTrainer):
|
||||
rewards,
|
||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||
)
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "prm"]
|
||||
|
||||
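
Each wrapper above follows the same pattern: SchedulerMixin is listed before the TRL base class so that its create_scheduler takes precedence in the MRO. A hypothetical additional wrapper would look like this (DPOTrainer is used purely as an illustration and is not part of this changeset):

from trl import DPOTrainer

from axolotl.core.trainers.mixins.scheduler import SchedulerMixin


class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
    """Extend the base DPOTrainer for axolotl helpers"""

    tag_names = ["axolotl", "dpo"]
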
33
src/axolotl/core/trainers/utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Utils for Axolotl trainers"""
|
||||
|
||||
|
||||
def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||
if isinstance(tag_names, str):
|
||||
tag_names = [tag_names]
|
||||
|
||||
if kwargs is not None:
|
||||
if "tags" not in kwargs:
|
||||
kwargs["tags"] = tag_names
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||
kwargs["tags"].extend(tag_names)
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||
tag_names.append(kwargs["tags"])
|
||||
kwargs["tags"] = tag_names
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||
if isinstance(dataset_tags, str):
|
||||
dataset_tags = [dataset_tags]
|
||||
|
||||
if (dataset_tags is not None) and (kwargs is not None):
|
||||
if "dataset_tags" not in kwargs:
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||
kwargs["dataset_tags"].extend(dataset_tags)
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||
dataset_tags.append(kwargs["dataset_tags"])
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
|
||||
return kwargs
|
||||
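
A brief usage sketch of the two helpers above; the tag values are made up for illustration.

from axolotl.core.trainers.utils import (
    sanitize_kwargs_for_ds_tagging,
    sanitize_kwargs_for_tagging,
)

# A string "tags" entry is folded into the list of trainer tags.
kwargs = sanitize_kwargs_for_tagging(["axolotl", "relora"], kwargs={"tags": "my-run"})
assert kwargs["tags"] == ["axolotl", "relora", "my-run"]

# Dataset tags are collected under a separate key.
kwargs = sanitize_kwargs_for_ds_tagging("alpaca", kwargs={})
assert kwargs["dataset_tags"] == ["alpaca"]
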
@@ -207,14 +207,19 @@ class AxolotlTrainingMixins:
|
||||
},
|
||||
)
|
||||
|
||||
sequence_parallel_degree: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||
"""
|
||||
Training arguments for Causal trainer
|
||||
|
||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
||||
so it can't be used as a mixin.
|
||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a
|
||||
default value so it can't be used as a mixin.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
89
src/axolotl/monkeypatch/attention/ring_attn.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Ring attention group registration and flash attention patching.
|
||||
|
||||
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
||||
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
"""
|
||||
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
configure_logging()
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
RING_ATTN_GROUP = None
|
||||
|
||||
|
||||
def get_ring_attn_group() -> dist.ProcessGroup:
|
||||
"""
|
||||
Getter for ring attention group on this rank.
|
||||
|
||||
Returns:
|
||||
The process group for ring attention for this rank.
|
||||
"""
|
||||
return RING_ATTN_GROUP
|
||||
|
||||
|
||||
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
|
||||
"""
|
||||
Setter for ring attention group on this rank.
|
||||
|
||||
Args:
|
||||
ring_attn_group: Process group for ring attention.
|
||||
"""
|
||||
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
def register_ring_attn(sequence_parallel_degree: int):
|
||||
"""
|
||||
Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
"""
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
)
|
||||
assert world_size % sequence_parallel_degree == 0, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
)
|
||||
|
||||
# Detailed logging of group formation
|
||||
rank = dist.get_rank()
|
||||
group_assignments = {}
|
||||
|
||||
for i in range(world_size // sequence_parallel_degree):
|
||||
ring_attn_ranks = list(
|
||||
range(
|
||||
i * sequence_parallel_degree,
|
||||
(i + 1) * sequence_parallel_degree,
|
||||
)
|
||||
)
|
||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||
|
||||
# Track which GPUs are in which groups
|
||||
for r in ring_attn_ranks:
|
||||
group_assignments[r] = i
|
||||
|
||||
if rank in ring_attn_ranks:
|
||||
set_ring_attn_group(group)
|
||||
|
||||
# Log the GPU group assignments
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
|
||||
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
||||
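
A worked example of the group formation above, assuming a world size of 8 and a sequence_parallel_degree of 4:

# Assumed sizes for illustration: 8 ranks, SP degree 4 -> 2 contiguous groups of 4.
world_size, sp_degree = 8, 4
groups = [
    list(range(i * sp_degree, (i + 1) * sp_degree))
    for i in range(world_size // sp_degree)
]
assert groups == [[0, 1, 2, 3], [4, 5, 6, 7]]
# The group_assignments mapping logged on rank 0 would then read:
# {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1}
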
@@ -169,7 +169,7 @@ def execute_training(
|
||||
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||
):
|
||||
"""
|
||||
Execute the training process with appropriate backend configurations.
|
||||
Execute the training process with appropriate SDP kernel configurations.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
@@ -177,9 +177,6 @@ def execute_training(
|
||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||
"""
|
||||
LOG.info("Starting trainer...")
|
||||
if cfg.group_by_length:
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
|
||||
@@ -1,14 +1,59 @@
|
||||
"""
|
||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
|
||||
includes logic for handling sequence parallelism collation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def adjust_position_ids_for_slice(
|
||||
position_ids: torch.Tensor, start_idx: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
||||
This handles the case where position IDs might not be contiguous due to sample
|
||||
packing.
|
||||
"""
|
||||
# Convert to tensor if not already
|
||||
# Find the boundaries between samples (where position_ids reset)
|
||||
adjusted_pos_ids = position_ids.clone()
|
||||
|
||||
# Process each sequence in the batch
|
||||
for i in range(position_ids.shape[0]):
|
||||
seq = position_ids[i]
|
||||
|
||||
# Find sample boundaries
|
||||
boundaries = []
|
||||
for j in range(1, len(seq)):
|
||||
if seq[j] < seq[j - 1]:
|
||||
boundaries.append(j)
|
||||
|
||||
# No need to adjust if there are no boundaries or this is a single sample
|
||||
if not boundaries:
|
||||
adjusted_pos_ids[i] = seq - start_idx
|
||||
continue
|
||||
|
||||
# Adjust each segment separately
|
||||
prev_boundary = 0
|
||||
for boundary in boundaries:
|
||||
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
||||
prev_boundary = boundary
|
||||
|
||||
# Last segment
|
||||
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
||||
|
||||
return adjusted_pos_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSeq2Seq:
|
||||
@@ -43,6 +88,8 @@ class DataCollatorForSeq2Seq:
|
||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||
return_tensors (`str`):
|
||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||
sequence_parallel_degree (`int`):
|
||||
The degree of sequence parallelism. Defaults to 1 for no sequence parallelism.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
@@ -53,6 +100,16 @@ class DataCollatorForSeq2Seq:
|
||||
label_pad_token_id: int = -100
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
sequence_parallel_degree: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
if self.sequence_parallel_degree > 1:
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
# Get information about our position in the SP group
|
||||
sp_group = get_ring_attn_group()
|
||||
self.local_rank = dist.get_rank(group=sp_group)
|
||||
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
labels = None
|
||||
@@ -119,8 +176,43 @@ class DataCollatorForSeq2Seq:
|
||||
)
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
if self.sequence_parallel_degree > 1:
|
||||
features = self.apply_sequence_parallelism(features)
|
||||
|
||||
return features
|
||||
|
||||
def apply_sequence_parallelism(
|
||||
self, batch: dict[str, torch.Tensor]
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Apply sequence parallelism slicing to a batch.
|
||||
|
||||
Args:
|
||||
batch: Batch dictionary from parent collator.
|
||||
|
||||
Returns:
|
||||
Sliced batch dictionary.
|
||||
"""
|
||||
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||
|
||||
for key in keys_to_slice:
|
||||
if key in batch:
|
||||
seq_len = batch[key].shape[1]
|
||||
slice_size = seq_len // self.local_world_size
|
||||
start_idx = self.local_rank * slice_size
|
||||
end_idx = (
|
||||
start_idx + slice_size
|
||||
if self.local_rank < self.local_world_size - 1
|
||||
else seq_len
|
||||
)
|
||||
batch[key] = batch[key][:, start_idx:end_idx]
|
||||
|
||||
# Special handling for position_ids
|
||||
if key == "position_ids" and self.local_rank > 0:
|
||||
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
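
To illustrate the slicing in apply_sequence_parallelism above with assumed sizes: with a local SP world size of 4 and a 10-token batch, ranks 0-2 each receive two tokens and the last rank keeps the remainder.

import torch

# Assumed values for illustration only.
batch = {"input_ids": torch.arange(10).unsqueeze(0)}
seq_len, local_world_size = 10, 4
slice_size = seq_len // local_world_size  # 2 tokens per rank, remainder to the last
for local_rank in range(local_world_size):
    start_idx = local_rank * slice_size
    end_idx = start_idx + slice_size if local_rank < local_world_size - 1 else seq_len
    print(local_rank, batch["input_ids"][:, start_idx:end_idx].tolist())
# 0 [[0, 1]]
# 1 [[2, 3]]
# 2 [[4, 5]]
# 3 [[6, 7, 8, 9]]
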
@@ -148,6 +240,7 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@@ -177,6 +270,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
|
||||
@@ -125,6 +125,9 @@ def normalize_config(cfg):
|
||||
with open(ds_config_path, encoding="utf-8") as f:
|
||||
cfg.deepspeed = json.load(f)
|
||||
|
||||
if cfg.sequence_parallel_degree is None:
|
||||
cfg.sequence_parallel_degree = 1
|
||||
|
||||
if cfg.saves_per_epoch:
|
||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||
if save_steps < 1.0: # prevent saves on every step
|
||||
|
||||
@@ -67,7 +67,12 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap
|
||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
MULTIMODEL_AUTO_MODEL_MAPPING = {
|
||||
"llava": LlavaForConditionalGeneration,
|
||||
"mllama": MllamaForConditionalGeneration,
|
||||
}
|
||||
|
||||
|
||||
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||
@@ -476,7 +481,7 @@ class ModelLoader:
|
||||
else:
|
||||
self.text_model_config = self.model_config
|
||||
|
||||
self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||
|
||||
def apply_patches(self) -> None:
|
||||
# load any patches from plugins
|
||||
@@ -547,6 +552,14 @@ class ModelLoader:
|
||||
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||
|
||||
# Initialize ring attn for sequence parallelism. This must be done after
|
||||
# model init but before the first forward pass, since it modifies flash
|
||||
# attn to use ring comm for SP training across multiple GPUs.
|
||||
register_ring_attn(self.cfg.sequence_parallel_degree)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
if hasattr(self.model_config, "model_type"):
|
||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||
@@ -603,7 +616,7 @@ class ModelLoader:
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
def patch_llama_derived_model(self) -> None:
|
||||
def patch_llama_derived_model(self):
|
||||
"""Modify all llama derived models in one block"""
|
||||
self.patch_loss_llama()
|
||||
|
||||
@@ -653,25 +666,16 @@ class ModelLoader:
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
)
|
||||
|
||||
def set_auto_model_loader(self) -> None:
|
||||
"""set self.AutoModelLoader
|
||||
- default value: AutoModelForCausalLM (set at __init__)
|
||||
- when using a multi modality model, self.AutoModelLoader should
|
||||
be set according to model type of the model
|
||||
def set_auto_model_loader(self):
|
||||
"""
|
||||
Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
|
||||
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
||||
should be set according to the type of the model.
|
||||
"""
|
||||
if self.cfg.is_multimodal:
|
||||
if self.model_config.model_type == "llava":
|
||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||
LlavaForConditionalGeneration
|
||||
)
|
||||
elif self.model_config.model_type == "mllama":
|
||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||
MllamaForConditionalGeneration
|
||||
)
|
||||
else:
|
||||
self.AutoModelLoader = (
|
||||
AutoModelForVision2Seq # pylint: disable=invalid-name
|
||||
)
|
||||
self.auto_model_loader = MULTIMODEL_AUTO_MODEL_MAPPING.get(
|
||||
self.model_config.model_type, AutoModelForVision2Seq
|
||||
)
|
||||
|
||||
def set_device_map_config(self) -> None:
|
||||
device_map = self.cfg.device_map
|
||||
@@ -695,7 +699,7 @@ class ModelLoader:
|
||||
from accelerate import infer_auto_device_map
|
||||
|
||||
with init_empty_weights():
|
||||
model_canvas = self.AutoModelLoader.from_config(
|
||||
model_canvas = self.auto_model_loader.from_config(
|
||||
self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
)
|
||||
@@ -916,11 +920,27 @@ class ModelLoader:
|
||||
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
self.model = self.AutoModelLoader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
|
||||
# Load model with random initialization if specified
|
||||
if self.cfg.random_init_weights:
|
||||
# AutoModel classes support the from_config method
|
||||
if self.auto_model_loader in [
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForVision2Seq,
|
||||
]:
|
||||
self.model = self.auto_model_loader.from_config(
|
||||
config=self.model_config,
|
||||
)
|
||||
else:
|
||||
self.model = self.auto_model_loader(
|
||||
config=self.model_config,
|
||||
)
|
||||
else:
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
|
||||
# TODO (MengqingCao) split these patches separately
|
||||
if self.cfg.flash_attention and not self.inference:
|
||||
@@ -958,7 +978,7 @@ class ModelLoader:
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
if self.cfg.gptq:
|
||||
self.model = self.AutoModelLoader.from_pretrained(
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
@@ -991,7 +1011,7 @@ class ModelLoader:
|
||||
if self.cfg.gptq:
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
self.model = self.AutoModelLoader.from_pretrained(
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
@@ -1011,7 +1031,7 @@ class ModelLoader:
|
||||
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
self.model = self.AutoModelLoader.from_pretrained(
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
@@ -1307,7 +1327,7 @@ def load_model(
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
loader = ModelLoader(
|
||||
model_loader = ModelLoader(
|
||||
cfg,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -1315,7 +1335,7 @@ def load_model(
|
||||
reference_model=reference_model,
|
||||
**kwargs,
|
||||
)
|
||||
return loader.load_model()
|
||||
return model_loader.load_model()
|
||||
|
||||
|
||||
def load_adapter(model, cfg, adapter, inference=False):
|
||||
|
||||
@@ -104,9 +104,7 @@ def allocate(
|
||||
|
||||
|
||||
class MultipackBatchSampler(BatchSampler):
|
||||
"""
|
||||
Batch Sampler class for multipack
|
||||
"""
|
||||
"""Batch sampler class for multipack"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Main Axolotl input configuration Pydantic models"""
|
||||
"""Module with Pydantic models for configuration."""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
@@ -245,6 +245,8 @@ class AxolotlInputConfig(
|
||||
|
||||
val_set_size: float | None = Field(default=0.0)
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
|
||||
special_tokens: SpecialTokensConfig | None = None
|
||||
tokens: list[str] | None = None
|
||||
added_tokens_overrides: dict[int, str] | None = None
|
||||
@@ -1102,6 +1104,29 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@field_validator("sequence_parallel_degree", mode="before")
|
||||
@classmethod
|
||||
def check_sequence_parallel_config(cls, value, info):
|
||||
if not value:
|
||||
value = 1
|
||||
|
||||
if value > 1:
|
||||
if not info.data.get("flash_attention"):
|
||||
raise ValueError(
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
)
|
||||
|
||||
try:
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
|
||||
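
A minimal configuration fragment that satisfies the validator above (assuming ring-flash-attn is installed); the model name and remaining values are placeholders only.

# Placeholder values; only sequence_parallel_degree and flash_attention are what
# check_sequence_parallel_config inspects.
sp_cfg = {
    "base_model": "HuggingFaceTB/SmolLM2-135M",
    "sequence_parallel_degree": 2,  # > 1 requires flash_attention and ring-flash-attn
    "flash_attention": True,
    "micro_batch_size": 1,
    "sequence_len": 2048,
}
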
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (PoSE)",
|
||||
)
|
||||
elif cfg.sample_packing:
|
||||
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if cfg.eval_sample_packing is not False:
|
||||
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.map(
|
||||
add_position_ids,
|
||||
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
)
|
||||
LOG.debug(
|
||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
||||
@@ -473,7 +474,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
# on the agreed on value for sample_packing_eff_est
|
||||
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||
)
|
||||
)
|
||||
|
||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||
@@ -494,7 +499,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
)
|
||||
else:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
/ cfg.batch_size
|
||||
)
|
||||
)
|
||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
||||
return total_num_steps
|
||||
|
||||
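
The extra sequence_parallel_degree factor in the step math above reflects that all ranks in an SP group consume the same samples each step, so fewer distinct samples are processed per step and more steps are needed. A worked example with assumed numbers:

import math

# Assumed values: 2,000 samples, 1 epoch, effective batch size 4.
num_samples, num_epochs, batch_size = 2_000, 1, 4

# Without sequence parallelism:
assert math.ceil(num_samples * num_epochs / batch_size) == 500

# With sequence_parallel_degree = 2, each pair of ranks shares the same samples,
# so the step count doubles:
assert math.ceil(num_samples * num_epochs * 2 / batch_size) == 1000
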
207
tests/e2e/patched/test_sp.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Tests for sequence parallelism functionality."""
|
||||
|
||||
# pylint: disable=redefined-outer-name,unused-argument
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.state import PartialState
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
# Use a single patch for ring_flash_attn if it's not available
|
||||
ring_flash_attn_mock = MagicMock()
|
||||
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def partial_state():
|
||||
"""Create a real PartialState instance for testing."""
|
||||
state = PartialState()
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture(name="cfg")
|
||||
def fixture_cfg():
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-3,
|
||||
"output_dir": "./model-out",
|
||||
"sequence_len": 512,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
class TestSequenceParallelHelpers:
|
||||
"""Test helper functions used in sequence parallelism."""
|
||||
|
||||
def test_adjust_position_ids_for_slice(self, partial_state):
|
||||
"""Test position_ids adjustment for sequence slices."""
|
||||
# Create sample position_ids with multiple sequences
|
||||
position_ids = torch.tensor(
|
||||
[
|
||||
# First sequence with 2 samples
|
||||
[0, 1, 2, 3, 4, 0, 1, 2, 3],
|
||||
# Second sequence with 3 samples
|
||||
[0, 1, 2, 0, 1, 2, 3, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Adjust as if this was the second slice (start_idx = 4)
|
||||
adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
|
||||
|
||||
# For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
|
||||
# For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
|
||||
expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
|
||||
expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
|
||||
|
||||
assert torch.all(adjusted[0] == expected_first_seq)
|
||||
assert torch.all(adjusted[1] == expected_second_seq)
|
||||
|
||||
|
||||
class TestRingAttention:
|
||||
"""Tests for the ring attention functionality."""
|
||||
|
||||
@patch("torch.distributed.new_group")
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_register_ring_attn(
|
||||
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
||||
):
|
||||
"""Test that ring attention groups are created correctly."""
|
||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 8 # 8 GPUs total
|
||||
mock_rank.return_value = 3 # GPU #3
|
||||
mock_group = MagicMock()
|
||||
mock_new_group.return_value = mock_group
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(sequence_parallel_degree=4)
|
||||
|
||||
# Verify the number of calls without examining the arguments
|
||||
assert mock_new_group.call_count == 2
|
||||
|
||||
# Just verify that new_group was called
|
||||
mock_new_group.assert_called()
|
||||
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_get_ring_attn_group_no_registration(
|
||||
self, mock_world_size, mock_rank, partial_state
|
||||
):
|
||||
"""Test that get_ring_attn_group returns None when no group has been registered."""
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 4
|
||||
mock_rank.return_value = 0
|
||||
|
||||
# Get the group without registration
|
||||
group = get_ring_attn_group()
|
||||
|
||||
# Verify that None was returned
|
||||
assert group is None
|
||||
|
||||
|
||||
# Mock a simplified DataCollator test
|
||||
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_sequence_parallel_slicing(
|
||||
mock_world_size, mock_rank, mock_get_group, partial_state
|
||||
):
|
||||
"""Test the basic sequence slicing logic without full collator instantiation."""
|
||||
# Setup mocks
|
||||
mock_get_group.return_value = MagicMock()
|
||||
mock_rank.return_value = 1 # Second GPU
|
||||
mock_world_size.return_value = 4 # 4 GPUs total
|
||||
|
||||
# Create a sample batch
|
||||
batch = {
|
||||
"input_ids": torch.tensor(
|
||||
[
|
||||
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
|
||||
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
|
||||
]
|
||||
),
|
||||
"attention_mask": torch.ones(2, 12),
|
||||
}
|
||||
|
||||
# Simplified slicing logic from SequenceParallelDataCollator
|
||||
def slice_batch(batch, rank, world_size):
|
||||
result = {}
|
||||
for key in batch:
|
||||
seq_len = batch[key].shape[1]
|
||||
slice_size = seq_len // world_size
|
||||
start_idx = rank * slice_size
|
||||
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
|
||||
result[key] = batch[key][:, start_idx:end_idx]
|
||||
return result
|
||||
|
||||
# Slice the batch
|
||||
result = slice_batch(
|
||||
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
|
||||
)
|
||||
|
||||
# Check slicing
|
||||
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
|
||||
expected_input_ids = torch.tensor(
|
||||
[
|
||||
[104, 105, 106], # Second slice of first sequence
|
||||
[204, 205, 206], # Second slice of second sequence
|
||||
]
|
||||
)
|
||||
assert torch.all(result["input_ids"] == expected_input_ids)
|
||||
|
||||
|
||||
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
|
||||
def test_config_validation_with_valid_inputs(cfg):
|
||||
"""Test that valid sequence parallelism configurations pass validation."""
|
||||
# Import the actual model class with appropriate mocks
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
|
||||
cfg = cfg | {
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
}
|
||||
|
||||
# Should validate without errors
|
||||
config = AxolotlInputConfig(**cfg)
|
||||
assert config.sequence_parallel_degree == 2
|
||||
assert config.flash_attention is True
|
||||
|
||||
|
||||
def test_config_validation_with_invalid_inputs(cfg):
|
||||
"""Test that invalid sequence parallelism configurations fail validation."""
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
|
||||
cfg = cfg | {
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": False,
|
||||
}
|
||||
|
||||
# Should raise ValidationError
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
AxolotlInputConfig(**cfg)
|
||||
|
||||
# Verify error message
|
||||
assert "flash_attention: true must be set" in str(excinfo.value)
|
||||
@@ -12,6 +12,7 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||
@@ -262,6 +263,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||
self.cfg_1 = DictDefault(
|
||||
{
|
||||
"base_model": "huggyllama/llama-7b",
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"dataset_exact_deduplication": True,
|
||||
@@ -282,6 +284,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
}
|
||||
)
|
||||
normalize_config(self.cfg_1)
|
||||
|
||||
def test_prepare_dataset_with_deduplication_train(self):
|
||||
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
||||
|
||||