Compare commits
44 Commits
fix-previe
...
sequence-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ac65462f0 | ||
|
|
ce35b2a95f | ||
|
|
ab3b36339a | ||
|
|
22cfa42961 | ||
|
|
0b2c2ed68c | ||
|
|
2f0b4626b9 | ||
|
|
a26985c53c | ||
|
|
c1a58339e8 | ||
|
|
411df76a97 | ||
|
|
a09d1ccbf2 | ||
|
|
2727d86544 | ||
|
|
64c203cdef | ||
|
|
7d7042f602 | ||
|
|
d187f1f8e2 | ||
|
|
1cced52719 | ||
|
|
11321b17e7 | ||
|
|
7a1a211c99 | ||
|
|
e1a02a32b5 | ||
|
|
a6ef6c7764 | ||
|
|
cb3a9e99a3 | ||
|
|
3ae47ec7de | ||
|
|
e36dc763ab | ||
|
|
03027cf6bf | ||
|
|
0ade60d455 | ||
|
|
02e1a42f04 | ||
|
|
919b88f11b | ||
|
|
345a9dd831 | ||
|
|
4ff97bc9d4 | ||
|
|
d0e178d52f | ||
|
|
5731cdc0cf | ||
|
|
b7738d57c4 | ||
|
|
698e599bf7 | ||
|
|
1d339e4007 | ||
|
|
4190ad0647 | ||
|
|
b44a207248 | ||
|
|
51c326150b | ||
|
|
14baaf6e0a | ||
|
|
f487910444 | ||
|
|
c5071dfd8a | ||
|
|
e323145ba9 | ||
|
|
7efc787ac8 | ||
|
|
dce61cdab1 | ||
|
|
bd952de9d2 | ||
|
|
3f8a43cab6 |
6
.github/workflows/tests.yml
vendored
6
.github/workflows/tests.yml
vendored
@@ -98,8 +98,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v tests/patched/
|
||||||
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -172,8 +173,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v tests/patched/
|
||||||
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
|
|
||||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -3,9 +3,10 @@ set -e
|
|||||||
|
|
||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
||||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
pytest -v --durations=10 /workspace/axolotl/tests/cli
|
||||||
|
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ tokenizer_legacy:
|
|||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||||
shrink_embeddings:
|
shrink_embeddings:
|
||||||
|
# Whether to load the model with randomly initialized weights. Useful for
|
||||||
|
# pre-training a model from scratch or debugging purposes.
|
||||||
|
random_init_weights:
|
||||||
|
|
||||||
# (Internal use only)
|
# (Internal use only)
|
||||||
# Used to identify which the model is based on
|
# Used to identify which the model is based on
|
||||||
@@ -617,6 +620,14 @@ ddp_timeout:
|
|||||||
ddp_bucket_cap_mb:
|
ddp_bucket_cap_mb:
|
||||||
ddp_broadcast_buffers:
|
ddp_broadcast_buffers:
|
||||||
|
|
||||||
|
# Sequence parallelism
|
||||||
|
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
||||||
|
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
||||||
|
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||||
|
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||||
|
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
||||||
|
sequence_parallel_degree:
|
||||||
|
|
||||||
# Path to torch distx for optim 'adamw_anyprecision'
|
# Path to torch distx for optim 'adamw_anyprecision'
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
|
|
||||||
|
|||||||
90
docs/sequence_parallelism.qmd
Normal file
90
docs/sequence_parallelism.qmd
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
---
|
||||||
|
title: Sequence Parallelism
|
||||||
|
description: Train with long sequences split across multiple GPUs.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Sequence Parallelism
|
||||||
|
|
||||||
|
Sequence parallelism is a technique that splits sequences across multiple GPUs,
|
||||||
|
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
|
||||||
|
GPU processes a different portion of the sequence, and the results are aggregated
|
||||||
|
through a ring communication pattern.
|
||||||
|
|
||||||
|
## When to Use Sequence Parallelism
|
||||||
|
|
||||||
|
Use sequence parallelism when:
|
||||||
|
|
||||||
|
- You need to train with sequence lengths that don't fit into a single GPU's memory
|
||||||
|
- You have multiple GPUs available
|
||||||
|
- You're experiencing OOM (Out Of Memory) errors with long sequences
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
To enable sequence parallelism, add the following to your configuration file:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Set to a divisor (> 1) of the number of GPUs available
|
||||||
|
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||||
|
```
|
||||||
|
|
||||||
|
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||||
|
|
||||||
|
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||||
|
- With 4 GPUs, valid values would be 2 or 4
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
When sequence parallelism is enabled:
|
||||||
|
|
||||||
|
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
|
||||||
|
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
|
||||||
|
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
|
||||||
|
4. The trainer uses special ring communication patterns for attention operations
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
To use sequence parallelism, you need:
|
||||||
|
|
||||||
|
- Multiple GPUs (at least 2)
|
||||||
|
- The `ring-flash-attn` package. Install with:
|
||||||
|
- `pip install axolotl[ring-flash-attn]` (preferred)
|
||||||
|
- `pip install ring-flash-attn>=0.1.4`
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
|
||||||
|
- May have a small performance overhead due to communication between GPUs
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Example config with sequence parallelism
|
||||||
|
base_model: meta-llama/Llama-3-8B-Instruct
|
||||||
|
sequence_len: 8192
|
||||||
|
sequence_parallel_degree: 2 # Split each sequence into 4 parts
|
||||||
|
flash_attention: true # Required with sequence parallelism
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
This will train the Llama 3 8B model with 8K context length, with each sequence split
|
||||||
|
into 2 subsequences of length 4096 across 2 GPUs.
|
||||||
|
|
||||||
|
## Sample Packing with Sequence Parallelism
|
||||||
|
|
||||||
|
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
|
||||||
|
|
||||||
|
1. Samples are first packed together
|
||||||
|
2. The packed sequences are then divided across GPUs in the sequence parallel group
|
||||||
|
3. Position IDs are automatically adjusted to maintain proper relative positions
|
||||||
|
|
||||||
|
## Effect on Batch Size
|
||||||
|
|
||||||
|
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||||
|
|
||||||
|
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||||
|
- The number of batches processed per step decreases
|
||||||
|
|
||||||
|
For example:
|
||||||
|
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||||
|
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||||
|
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||||
@@ -4,7 +4,6 @@
|
|||||||
bitsandbytes==0.45.3
|
bitsandbytes==0.45.3
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
flash-attn==2.7.4.post1
|
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
autoawq==0.2.7.post3
|
autoawq==0.2.7.post3
|
||||||
liger-kernel==0.5.3
|
liger-kernel==0.5.3
|
||||||
@@ -36,6 +35,7 @@ einops
|
|||||||
colorama
|
colorama
|
||||||
numba
|
numba
|
||||||
numpy>=1.24.4,<=2.0.1
|
numpy>=1.24.4,<=2.0.1
|
||||||
|
|
||||||
# qlora things
|
# qlora things
|
||||||
evaluate==0.4.1
|
evaluate==0.4.1
|
||||||
scipy
|
scipy
|
||||||
|
|||||||
12
setup.py
12
setup.py
@@ -17,11 +17,7 @@ def parse_requirements():
|
|||||||
lines = [r.strip() for r in requirements_file.readlines()]
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
for line in lines:
|
for line in lines:
|
||||||
is_extras = (
|
is_extras = (
|
||||||
"flash-attn" in line
|
"deepspeed" in line or "mamba-ssm" in line or "lion-pytorch" in line
|
||||||
or "flash-attention" in line
|
|
||||||
or "deepspeed" in line
|
|
||||||
or "mamba-ssm" in line
|
|
||||||
or "lion-pytorch" in line
|
|
||||||
)
|
)
|
||||||
if line.startswith("--extra-index-url"):
|
if line.startswith("--extra-index-url"):
|
||||||
# Handle custom index URLs
|
# Handle custom index URLs
|
||||||
@@ -39,7 +35,6 @@ def parse_requirements():
|
|||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
"triton",
|
"triton",
|
||||||
"mamba-ssm",
|
"mamba-ssm",
|
||||||
"flash-attn",
|
|
||||||
"xformers",
|
"xformers",
|
||||||
"autoawq",
|
"autoawq",
|
||||||
"liger-kernel",
|
"liger-kernel",
|
||||||
@@ -124,9 +119,8 @@ setup(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||||
"flash-attn==2.7.4.post1",
|
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
|
||||||
],
|
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.16.4",
|
"deepspeed==0.16.4",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||||
"""
|
"""
|
||||||
Trains a `transformers` model by first loading the dataset(s) specified in the
|
Trains a `transformers` model by first loading the dataset(s) specified in the
|
||||||
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
||||||
@@ -44,16 +44,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
del model, tokenizer, trainer
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|
||||||
del model
|
|
||||||
del tokenizer
|
|
||||||
del trainer
|
|
||||||
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
"""
|
"""
|
||||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from transformers import (
|
|||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
@@ -762,6 +762,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kd_top_k_before_softmax
|
self.cfg.kd_top_k_before_softmax
|
||||||
)
|
)
|
||||||
|
|
||||||
|
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||||
|
self.cfg.sequence_parallel_degree
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
elif self.cfg.process_reward_model:
|
elif self.cfg.process_reward_model:
|
||||||
@@ -845,9 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if self.cfg.pretraining_sample_concatenation is False:
|
if (
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
self.cfg.pretraining_sample_concatenation is False
|
||||||
if self.cfg.micro_batch_size > 1:
|
or self.cfg.micro_batch_size > 1
|
||||||
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -875,9 +880,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if "max_length" in kwargs:
|
if "max_length" in kwargs:
|
||||||
kwargs.pop("max_length")
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
|
||||||
elif (
|
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
and self.cfg.flash_attention is not True
|
and self.cfg.flash_attention is not True
|
||||||
):
|
):
|
||||||
@@ -908,6 +911,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
kwargs["return_tensors"] = "pt"
|
kwargs["return_tensors"] = "pt"
|
||||||
|
if issubclass(collator, DataCollatorForSeq2Seq):
|
||||||
|
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
||||||
|
|
||||||
return collator(
|
return collator(
|
||||||
*collator_args,
|
*collator_args,
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
"""Init for axolotl.core.trainers"""
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from .base import AxolotlTrainer
|
||||||
|
from .dpo.trainer import AxolotlDPOTrainer
|
||||||
|
from .grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
from .mamba import AxolotlMambaTrainer
|
||||||
|
from .relora import ReLoRATrainer
|
||||||
|
from .trl import (
|
||||||
|
AxolotlCPOTrainer,
|
||||||
|
AxolotlKTOTrainer,
|
||||||
|
AxolotlORPOTrainer,
|
||||||
|
AxolotlPRMTrainer,
|
||||||
|
AxolotlRewardTrainer,
|
||||||
|
TRLPPOTrainer,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,365 +1,47 @@
|
|||||||
"""
|
"""Module for customized trainers"""
|
||||||
module for customized trainers
|
|
||||||
"""
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Dict, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.utils.data import (
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
BatchSampler,
|
||||||
|
DataLoader,
|
||||||
|
RandomSampler,
|
||||||
|
Sampler,
|
||||||
|
SequentialSampler,
|
||||||
|
)
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
|
||||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.integrations.base import BaseOptimizerFactory
|
from axolotl.core.trainers.mixins import (
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
OptimizerMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
SequenceParallelMixin,
|
||||||
|
)
|
||||||
|
from axolotl.core.trainers.utils import (
|
||||||
|
sanitize_kwargs_for_ds_tagging,
|
||||||
|
sanitize_kwargs_for_tagging,
|
||||||
|
)
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.schedulers import (
|
|
||||||
RexLR,
|
|
||||||
get_cosine_schedule_with_min_lr,
|
|
||||||
get_cosine_schedule_with_quadratic_warmup,
|
|
||||||
get_cosine_schedule_with_warmup_decay_constant,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
LOG = logging.getLogger(__name__)
|
||||||
import smdistributed.modelparallel.torch as smp
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
|
||||||
if isinstance(tag_names, str):
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
tag_names = [tag_names]
|
|
||||||
|
|
||||||
if kwargs is not None:
|
|
||||||
if "tags" not in kwargs:
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
|
||||||
kwargs["tags"].extend(tag_names)
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
|
||||||
tag_names.append(kwargs["tags"])
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
|
||||||
if isinstance(dataset_tags, str):
|
|
||||||
dataset_tags = [dataset_tags]
|
|
||||||
|
|
||||||
if (dataset_tags is not None) and (kwargs is not None):
|
|
||||||
if "dataset_tags" not in kwargs:
|
|
||||||
kwargs["dataset_tags"] = dataset_tags
|
|
||||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
|
||||||
kwargs["dataset_tags"].extend(dataset_tags)
|
|
||||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
|
||||||
dataset_tags.append(kwargs["dataset_tags"])
|
|
||||||
kwargs["dataset_tags"] = dataset_tags
|
|
||||||
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerMixin(Trainer):
|
|
||||||
"""
|
|
||||||
Mixin class for scheduler setup in CausalTrainer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
|
||||||
passed as an argument.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_training_steps (int): The number of training steps to do.
|
|
||||||
optimizer (torch.optim.Optimizer): The training optimizer
|
|
||||||
"""
|
|
||||||
use_cosine_quadratic = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.lr_quadratic_warmup is True
|
|
||||||
)
|
|
||||||
|
|
||||||
use_cosine_min_lr = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.cosine_min_lr_ratio is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
|
||||||
# fmt: on
|
|
||||||
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
|
||||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
|
||||||
pct_start = num_warmup_steps / num_training_steps
|
|
||||||
extra_lr_kwargs = {}
|
|
||||||
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
|
||||||
extra_lr_kwargs["pct_start"] = pct_start
|
|
||||||
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
|
||||||
extra_lr_kwargs["anneal_strategy"] = "cos"
|
|
||||||
|
|
||||||
self.lr_scheduler = OneCycleLR(
|
|
||||||
optimizer,
|
|
||||||
max_lr=self.args.learning_rate,
|
|
||||||
total_steps=num_training_steps,
|
|
||||||
**extra_lr_kwargs,
|
|
||||||
**self.args.lr_scheduler_kwargs,
|
|
||||||
)
|
|
||||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
|
|
||||||
self.lr_scheduler = RexLR(
|
|
||||||
optimizer=optimizer,
|
|
||||||
max_lr=self.args.learning_rate,
|
|
||||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
|
||||||
total_steps=num_training_steps,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
)
|
|
||||||
elif use_cosine_quadratic:
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
|
||||||
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
|
||||||
else:
|
|
||||||
if use_cosine_quadratic:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizerMixin(Trainer):
|
|
||||||
"""
|
|
||||||
Mixin class for shared handling of building custom optimizers
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def create_optimizer_grouped_parameters(
|
|
||||||
self, opt_model, optimizer_kwargs
|
|
||||||
) -> list[dict]:
|
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
|
||||||
params: dict = {
|
|
||||||
"to_weight_decay": {}, # LayerNorm and bias
|
|
||||||
"embeddings": {}, # lm_head, embed_tokens,
|
|
||||||
"no_weight_decay": {},
|
|
||||||
}
|
|
||||||
lr_groups_lookup = {}
|
|
||||||
lr_groups_learning_rates = {}
|
|
||||||
if self.args.lr_groups:
|
|
||||||
for lr_group in self.args.lr_groups:
|
|
||||||
group_name = lr_group["name"]
|
|
||||||
group_modules = lr_group["modules"]
|
|
||||||
for module in group_modules:
|
|
||||||
lr_groups_lookup[module] = group_name
|
|
||||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
|
||||||
params[f"to_weight_decay_{group_name}"] = {}
|
|
||||||
|
|
||||||
for name, param in opt_model.named_parameters():
|
|
||||||
if not param.requires_grad:
|
|
||||||
continue
|
|
||||||
if name.endswith("modules_to_save.default.weight") or any(
|
|
||||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
|
||||||
):
|
|
||||||
params["embeddings"][name] = param
|
|
||||||
elif name in decay_parameters:
|
|
||||||
lr_group_modules = [
|
|
||||||
group_modules
|
|
||||||
for group_modules in lr_groups_lookup
|
|
||||||
if group_modules in name
|
|
||||||
]
|
|
||||||
if lr_groups_lookup and any(lr_group_modules):
|
|
||||||
lr_group_module = lr_group_modules[0]
|
|
||||||
group_name = lr_groups_lookup[lr_group_module]
|
|
||||||
params[f"to_weight_decay_{group_name}"][name] = param
|
|
||||||
else:
|
|
||||||
params["to_weight_decay"][name] = param
|
|
||||||
else:
|
|
||||||
params["no_weight_decay"][name] = param
|
|
||||||
optimizer_grouped_parameters = []
|
|
||||||
if params["to_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["to_weight_decay"].values()),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["embeddings"]:
|
|
||||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
|
||||||
if self.args.embedding_lr_scale:
|
|
||||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
|
||||||
elif self.args.embedding_lr:
|
|
||||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["embeddings"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["no_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["no_weight_decay"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
|
||||||
if params[f"to_weight_decay_{group_name}"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(
|
|
||||||
params[f"to_weight_decay_{group_name}"].values()
|
|
||||||
),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": group_lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
|
||||||
if (
|
|
||||||
self.args.loraplus_lr_ratio is None
|
|
||||||
and self.args.embedding_lr_scale is None
|
|
||||||
and self.args.embedding_lr is None
|
|
||||||
and self.args.lr_groups is None
|
|
||||||
and self.optimizer_cls_and_kwargs is None
|
|
||||||
):
|
|
||||||
return super().create_optimizer()
|
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
|
||||||
|
|
||||||
if (
|
|
||||||
not self.optimizer
|
|
||||||
and self.optimizer_cls_and_kwargs is not None
|
|
||||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
|
||||||
):
|
|
||||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
|
||||||
self.optimizer = optimizer_factory_cls()(
|
|
||||||
opt_model, self.args, **optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.optimizer:
|
|
||||||
if self.optimizer_cls_and_kwargs is not None:
|
|
||||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
|
||||||
else:
|
|
||||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
|
||||||
self.args, opt_model
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
|
||||||
opt_model, optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.loraplus_lr_ratio is not None:
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
|
||||||
loraplus_lr_embedding = getattr(
|
|
||||||
self.args, "loraplus_lr_embedding", 1e-6
|
|
||||||
)
|
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
|
||||||
# e.g. for GaLore optimizer.
|
|
||||||
if "params" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
|
||||||
|
|
||||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
|
||||||
# e.g. for LOMO optimizer.
|
|
||||||
if "model" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
|
||||||
|
|
||||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
|
||||||
# to avoid arguments conflicts.
|
|
||||||
if "optimizer_dict" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
|
||||||
"optimizer_dict"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimizer = optimizer_cls(
|
|
||||||
optimizer_grouped_parameters, **optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if optimizer_cls.__name__ == "Adam8bit":
|
|
||||||
import bitsandbytes
|
|
||||||
|
|
||||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
|
||||||
|
|
||||||
skipped = 0
|
|
||||||
for module in opt_model.modules():
|
|
||||||
if isinstance(module, nn.Embedding):
|
|
||||||
skipped += sum(
|
|
||||||
{
|
|
||||||
p.data_ptr(): p.numel() for p in module.parameters()
|
|
||||||
}.values()
|
|
||||||
)
|
|
||||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
|
||||||
manager.register_module_override(
|
|
||||||
module, "weight", {"optim_bits": 32}
|
|
||||||
)
|
|
||||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
|
||||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.optimizer
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|
||||||
"""
|
|
||||||
Extend the base Trainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
tag_names = ["axolotl"]
|
||||||
@@ -376,12 +58,18 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
self.eval_data_collator = eval_data_collator
|
self.eval_data_collator = eval_data_collator
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self._signature_columns = None # workaround for pylint
|
self._signature_columns = None # workaround for pylint
|
||||||
|
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
|
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
|
# Initialize sequence parallelism if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._setup_sequence_parallel()
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
if self.args.torch_compile:
|
if self.args.torch_compile:
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
@@ -394,8 +82,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
)
|
)
|
||||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _create_multipack_sampler(
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
self, base_sampler: Sampler, dataset: Dataset
|
||||||
|
) -> MultipackBatchSampler:
|
||||||
|
"""
|
||||||
|
Helper method to create a `MultipackBatchSampler` for multipacking sequences
|
||||||
|
for training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
|
||||||
|
dataset: Dataset to sample from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Multipack (sample packing) batch sampler.
|
||||||
|
"""
|
||||||
if self.args.multipack_real_batches:
|
if self.args.multipack_real_batches:
|
||||||
batch_size = self.args.per_device_train_batch_size
|
batch_size = self.args.per_device_train_batch_size
|
||||||
batch_max_len = self.args.max_seq_length
|
batch_max_len = self.args.max_seq_length
|
||||||
@@ -406,130 +106,223 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
)
|
)
|
||||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||||
|
|
||||||
if self.args.curriculum_sampling:
|
|
||||||
sampler = SequentialSampler(self.train_dataset)
|
|
||||||
else:
|
|
||||||
sampler = RandomSampler(self.train_dataset)
|
|
||||||
|
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
sampler,
|
base_sampler,
|
||||||
lengths=get_dataset_lengths(self.train_dataset),
|
lengths=get_dataset_lengths(dataset),
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
batch_max_len=batch_max_len,
|
batch_max_len=batch_max_len,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
group_size=self.args.sample_packing_group_size,
|
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
if self.args.curriculum_sampling:
|
|
||||||
return SequentialSampler(self.train_dataset)
|
def _get_train_sampler(self) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Helper method to get the sampler for training. Handles cases for sequence
|
||||||
|
parallelism, sample packing, and curriculum sampling (sequential).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
|
depends on the passed training args.
|
||||||
|
"""
|
||||||
|
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||||
|
|
||||||
|
# Determine the base sampler first
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
||||||
|
elif self.args.curriculum_sampling:
|
||||||
|
base_sampler = SequentialSampler(self.train_dataset)
|
||||||
|
elif use_sample_packing:
|
||||||
|
base_sampler = RandomSampler(self.train_dataset)
|
||||||
|
else:
|
||||||
|
# Default to parent class implementation for standard random sampling
|
||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
def _get_eval_sampler(
|
# Apply multipack wrapper if needed
|
||||||
self, eval_dataset: Dataset
|
if use_sample_packing:
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
return self._create_multipack_sampler(
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
base_sampler=base_sampler,
|
||||||
if self.args.multipack_real_batches:
|
dataset=self.train_dataset,
|
||||||
batch_size = self.args.per_device_eval_batch_size
|
)
|
||||||
batch_max_len = self.args.max_seq_length
|
|
||||||
|
return base_sampler
|
||||||
|
|
||||||
|
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
||||||
|
and sample packing cases.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
|
depends on the passed training args.
|
||||||
|
"""
|
||||||
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
|
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
||||||
|
use_multipack = (
|
||||||
|
self.args.sample_packing and self.args.eval_sample_packing is not False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine the base sampler
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
||||||
|
elif use_multipack:
|
||||||
|
base_sampler = SequentialSampler(eval_dataset)
|
||||||
else:
|
else:
|
||||||
batch_size = 1
|
|
||||||
batch_max_len = (
|
|
||||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
|
||||||
)
|
|
||||||
return MultipackBatchSampler(
|
|
||||||
SequentialSampler(eval_dataset),
|
|
||||||
lengths=get_dataset_lengths(self.eval_dataset),
|
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
|
||||||
batch_max_len=batch_max_len,
|
|
||||||
batch_size=batch_size,
|
|
||||||
group_size=self.args.sample_packing_group_size,
|
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
# Apply multipack wrapper if needed
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if use_multipack:
|
||||||
train_dataset = self.train_dataset
|
return self._create_multipack_sampler(
|
||||||
if "length" in train_dataset.features.keys():
|
base_sampler=base_sampler,
|
||||||
train_dataset = train_dataset.remove_columns(["length"])
|
dataset=eval_dataset,
|
||||||
data_collator = self.data_collator
|
)
|
||||||
dataloader_params = {
|
|
||||||
"batch_size": self._train_batch_size,
|
return base_sampler
|
||||||
"collate_fn": data_collator,
|
|
||||||
|
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
|
||||||
|
"""Create common dataloader parameters for train or eval."""
|
||||||
|
batch_size = custom_batch_size or (
|
||||||
|
self.args.eval_batch_size if is_eval else self._train_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"collate_fn": self.data_collator,
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
}
|
}
|
||||||
if self.args.dataloader_prefetch_factor:
|
|
||||||
dataloader_params["prefetch_factor"] = (
|
|
||||||
self.args.dataloader_prefetch_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
sampler = self._get_train_sampler()
|
# Add persistent workers only for training
|
||||||
|
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
|
||||||
|
params["persistent_workers"] = self.args.dataloader_persistent_workers
|
||||||
|
|
||||||
|
# Add prefetch factor if specified
|
||||||
|
if self.args.dataloader_prefetch_factor:
|
||||||
|
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def _prepare_dataloader(
|
||||||
|
self, dataset, sampler, is_eval=False, custom_batch_size=None
|
||||||
|
):
|
||||||
|
"""Prepare a dataloader with the given dataset and sampler."""
|
||||||
|
# Get base parameters
|
||||||
|
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
|
||||||
|
|
||||||
|
# Add sampler configuration
|
||||||
|
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
|
# batch_size and batch_sampler are mutually exclusive
|
||||||
dataloader_params["batch_sampler"] = sampler
|
dataloader_params["batch_sampler"] = sampler
|
||||||
del dataloader_params["batch_size"]
|
del dataloader_params["batch_size"]
|
||||||
else:
|
else:
|
||||||
dataloader_params["sampler"] = sampler
|
dataloader_params["sampler"] = sampler
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
|
if not is_eval:
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
self.accelerator.even_batches = False
|
# Create the dataloader
|
||||||
return self.accelerator.prepare_data_loader(
|
dataloader = DataLoader(dataset, **dataloader_params)
|
||||||
DataLoader(train_dataset, **dataloader_params)
|
|
||||||
)
|
|
||||||
return super().get_train_dataloader()
|
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
if self.args.sample_packing and (
|
||||||
|
(not is_eval and not self.args.pretraining)
|
||||||
|
or (is_eval and self.args.eval_sample_packing is not False)
|
||||||
|
):
|
||||||
|
self.accelerator.even_batches = False
|
||||||
|
|
||||||
|
# Return unprepared dataloader if using sequence parallelism
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
# Otherwise prepare with accelerator
|
||||||
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
|
|
||||||
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
|
"""Get dataloader for training"""
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
data_collator = self.data_collator # type: ignore
|
||||||
|
|
||||||
|
# Handle dataset preprocessing
|
||||||
|
if isinstance(train_dataset, datasets.Dataset):
|
||||||
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
|
train_dataset = train_dataset.remove_columns(["length"])
|
||||||
|
if not self.args.sample_packing or self.args.pretraining:
|
||||||
|
train_dataset = self._remove_unused_columns(
|
||||||
|
train_dataset, description="training"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||||
|
data_collator,
|
||||||
|
description="training",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get sampler and create dataloader
|
||||||
|
sampler = self._get_train_sampler()
|
||||||
|
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
|
||||||
|
|
||||||
|
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
||||||
|
"""Get dataloader for evaluation"""
|
||||||
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
|
# Handle special case: sample packing is enabled but eval_sample_packing is False
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.eval_data_collator
|
self.eval_data_collator
|
||||||
)
|
)
|
||||||
if eval_dataset:
|
if "length" in eval_dataset.column_names:
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.train_data_collator
|
self.train_data_collator
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
# Handle sample packing or sequence parallelism
|
||||||
eval_dataset = (
|
if (
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
self.args.sample_packing
|
||||||
|
and self.args.eval_sample_packing is not False
|
||||||
|
or self.args.sequence_parallel_degree > 1
|
||||||
|
):
|
||||||
|
# Get appropriate data collator
|
||||||
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.eval_data_collator
|
||||||
|
if hasattr(self, "eval_data_collator") and self.eval_data_collator
|
||||||
|
else self.data_collator
|
||||||
)
|
)
|
||||||
|
if "length" in eval_dataset.column_names:
|
||||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
data_collator = self.data_collator
|
|
||||||
dataloader_params = {
|
|
||||||
"batch_size": self.args.eval_batch_size,
|
|
||||||
"collate_fn": data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
|
||||||
}
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
|
||||||
dataloader_params["prefetch_factor"] = (
|
|
||||||
self.args.dataloader_prefetch_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(eval_sampler, BatchSampler):
|
# Handle dataset preprocessing for SP
|
||||||
dataloader_params["batch_sampler"] = eval_sampler
|
if self.args.sequence_parallel_degree > 1:
|
||||||
del dataloader_params["batch_size"]
|
if isinstance(eval_dataset, datasets.Dataset):
|
||||||
|
eval_dataset = self._remove_unused_columns(
|
||||||
|
eval_dataset, description="evaluation"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
dataloader_params["sampler"] = eval_sampler
|
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
self.data_collator, description="evaluation"
|
||||||
|
|
||||||
self.accelerator.even_batches = False
|
|
||||||
return self.accelerator.prepare_data_loader(
|
|
||||||
DataLoader(eval_dataset, **dataloader_params)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
|
||||||
|
batch_size = (
|
||||||
|
self.args.eval_batch_size
|
||||||
|
if self.args.sample_packing
|
||||||
|
else self.args.per_device_eval_batch_size
|
||||||
|
)
|
||||||
|
sampler = self._get_eval_sampler(eval_dataset)
|
||||||
|
dataloader = self._prepare_dataloader(
|
||||||
|
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataloader
|
||||||
|
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
def _get_bench_sampler(
|
def _get_bench_sampler(
|
||||||
self, bench_dataset: Dataset
|
self, bench_dataset: Dataset
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
) -> torch.utils.data.Sampler | None:
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
return SequentialSampler(bench_dataset)
|
return SequentialSampler(bench_dataset)
|
||||||
return None
|
return None
|
||||||
@@ -554,6 +347,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return DataLoader(bench_dataset, **dataloader_params)
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||||
|
|
||||||
|
@override
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||||
):
|
):
|
||||||
@@ -570,6 +364,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return_outputs=return_outputs,
|
return_outputs=return_outputs,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().compute_loss(
|
return super().compute_loss(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -744,10 +539,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
@@ -764,15 +559,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
Log `logs` on the various objects watching training, including stored metrics.
|
Log `logs` on the various objects watching training, including stored metrics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logs (`Dict[str, float]`):
|
logs: The values to log.
|
||||||
The values to log.
|
start_time: The start of training.
|
||||||
start_time (`Optional[float]`):
|
|
||||||
The start of training.
|
|
||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
@@ -784,7 +577,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return super().log(logs, start_time)
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||||
) -> None:
|
) -> None:
|
||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
self._stored_metrics[train_eval][key].append(value)
|
||||||
@@ -797,110 +590,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
|
def training_step(
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
|
||||||
"""
|
|
||||||
Mamba specific trainer to handle loss calculation
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "mamba"]
|
|
||||||
|
|
||||||
def compute_loss(
|
|
||||||
self,
|
self,
|
||||||
model,
|
model: nn.Module,
|
||||||
inputs,
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
return_outputs=False, # pylint: disable=unused-argument
|
num_items_in_batch: int | None = None,
|
||||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
) -> torch.Tensor:
|
||||||
):
|
|
||||||
input_ids = inputs.pop("input_ids")
|
|
||||||
lm_logits = model(input_ids).logits
|
|
||||||
|
|
||||||
labels = input_ids.to(lm_logits.device)
|
|
||||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
|
||||||
labels = labels[:, 1:].contiguous()
|
|
||||||
|
|
||||||
loss_fct = torch.nn.CrossEntropyLoss()
|
|
||||||
lm_loss = loss_fct(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
return lm_loss
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRATrainer(AxolotlTrainer):
|
|
||||||
"""
|
"""
|
||||||
Trainer subclass that uses the OneCycleLR scheduler
|
Perform a training step on a batch of inputs. Overrides the
|
||||||
|
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||||
|
enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model to perform training step for.
|
||||||
|
inputs: Dictionary mapping.
|
||||||
"""
|
"""
|
||||||
|
# Set up sequence parallelism for this step if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._update_ring_flash_attn_params(inputs)
|
||||||
|
|
||||||
tag_names = ["axolotl", "relora"]
|
# Proceed with normal training step
|
||||||
|
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
return loss
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.lr_scheduler = None
|
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self,
|
|
||||||
num_training_steps: int,
|
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
|
||||||
):
|
|
||||||
optimizer = self.optimizer if optimizer is None else optimizer
|
|
||||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
|
||||||
|
|
||||||
if self.args.relora_steps:
|
|
||||||
warmup_steps = (
|
|
||||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
|
||||||
)
|
|
||||||
anneal_steps = (
|
|
||||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
|
||||||
)
|
|
||||||
self.lr_scheduler = ReLoRAScheduler(
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler,
|
|
||||||
self.args.relora_steps,
|
|
||||||
anneal_steps,
|
|
||||||
warmup_steps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.lr_scheduler = lr_scheduler
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base RewardTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base trl.PRMTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "prm"]
|
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ from transformers import Trainer
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.mixins import SchedulerMixin
|
||||||
SchedulerMixin,
|
from axolotl.core.trainers.utils import (
|
||||||
_sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
_sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
@@ -74,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
32
src/axolotl/core/trainers/mamba.py
Normal file
32
src/axolotl/core/trainers/mamba.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Module for mamba trainer"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
|
"""Mamba specific trainer to handle loss calculation"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "mamba"]
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
inputs,
|
||||||
|
return_outputs=False, # pylint: disable=unused-argument
|
||||||
|
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
input_ids = inputs.pop("input_ids")
|
||||||
|
lm_logits = model(input_ids).logits
|
||||||
|
|
||||||
|
labels = input_ids.to(lm_logits.device)
|
||||||
|
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||||
|
labels = labels[:, 1:].contiguous()
|
||||||
|
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss()
|
||||||
|
lm_loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
return lm_loss
|
||||||
8
src/axolotl/core/trainers/mixins/__init__.py
Normal file
8
src/axolotl/core/trainers/mixins/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""Init for axolotl.core.trainers.mixins"""
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from .optimizer import OptimizerMixin
|
||||||
|
from .scheduler import SchedulerMixin
|
||||||
|
from .sequence_parallel import SequenceParallelMixin
|
||||||
201
src/axolotl/core/trainers/mixins/optimizer.py
Normal file
201
src/axolotl/core/trainers/mixins/optimizer.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""Module for Axolotl trainer optimizer mixin"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
|
from torch import nn
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerMixin(Trainer):
|
||||||
|
"""Mixin class for shared handling of building custom optimizers"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_optimizer_grouped_parameters(
|
||||||
|
self, opt_model, optimizer_kwargs
|
||||||
|
) -> list[dict]:
|
||||||
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
|
params: dict = {
|
||||||
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
|
"no_weight_decay": {},
|
||||||
|
}
|
||||||
|
lr_groups_lookup = {}
|
||||||
|
lr_groups_learning_rates = {}
|
||||||
|
if self.args.lr_groups:
|
||||||
|
for lr_group in self.args.lr_groups:
|
||||||
|
group_name = lr_group["name"]
|
||||||
|
group_modules = lr_group["modules"]
|
||||||
|
for module in group_modules:
|
||||||
|
lr_groups_lookup[module] = group_name
|
||||||
|
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||||
|
params[f"to_weight_decay_{group_name}"] = {}
|
||||||
|
|
||||||
|
for name, param in opt_model.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
if name.endswith("modules_to_save.default.weight") or any(
|
||||||
|
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||||
|
):
|
||||||
|
params["embeddings"][name] = param
|
||||||
|
elif name in decay_parameters:
|
||||||
|
lr_group_modules = [
|
||||||
|
group_modules
|
||||||
|
for group_modules in lr_groups_lookup
|
||||||
|
if group_modules in name
|
||||||
|
]
|
||||||
|
if lr_groups_lookup and any(lr_group_modules):
|
||||||
|
lr_group_module = lr_group_modules[0]
|
||||||
|
group_name = lr_groups_lookup[lr_group_module]
|
||||||
|
params[f"to_weight_decay_{group_name}"][name] = param
|
||||||
|
else:
|
||||||
|
params["to_weight_decay"][name] = param
|
||||||
|
else:
|
||||||
|
params["no_weight_decay"][name] = param
|
||||||
|
optimizer_grouped_parameters = []
|
||||||
|
if params["to_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["to_weight_decay"].values()),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["embeddings"]:
|
||||||
|
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||||
|
if self.args.embedding_lr_scale:
|
||||||
|
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||||
|
elif self.args.embedding_lr:
|
||||||
|
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["embeddings"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["no_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["no_weight_decay"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||||
|
if params[f"to_weight_decay_{group_name}"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(
|
||||||
|
params[f"to_weight_decay_{group_name}"].values()
|
||||||
|
),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": group_lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
if (
|
||||||
|
self.args.loraplus_lr_ratio is None
|
||||||
|
and self.args.embedding_lr_scale is None
|
||||||
|
and self.args.embedding_lr is None
|
||||||
|
and self.args.lr_groups is None
|
||||||
|
and self.optimizer_cls_and_kwargs is None
|
||||||
|
):
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
|
|
||||||
|
if (
|
||||||
|
not self.optimizer
|
||||||
|
and self.optimizer_cls_and_kwargs is not None
|
||||||
|
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||||
|
):
|
||||||
|
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
self.optimizer = optimizer_factory_cls()(
|
||||||
|
opt_model, self.args, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.optimizer:
|
||||||
|
if self.optimizer_cls_and_kwargs is not None:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
else:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args, opt_model
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||||
|
opt_model, optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
loraplus_lr_embedding = getattr(
|
||||||
|
self.args, "loraplus_lr_embedding", 1e-6
|
||||||
|
)
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
|
optimizer_cls,
|
||||||
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
|
**optimizer_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for GaLore optimizer.
|
||||||
|
if "params" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||||
|
|
||||||
|
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for LOMO optimizer.
|
||||||
|
if "model" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||||
|
|
||||||
|
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||||
|
# to avoid arguments conflicts.
|
||||||
|
if "optimizer_dict" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||||
|
"optimizer_dict"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optimizer = optimizer_cls(
|
||||||
|
optimizer_grouped_parameters, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if optimizer_cls.__name__ == "Adam8bit":
|
||||||
|
import bitsandbytes
|
||||||
|
|
||||||
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
|
|
||||||
|
skipped = 0
|
||||||
|
for module in opt_model.modules():
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
skipped += sum(
|
||||||
|
{
|
||||||
|
p.data_ptr(): p.numel() for p in module.parameters()
|
||||||
|
}.values()
|
||||||
|
)
|
||||||
|
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||||
|
manager.register_module_override(
|
||||||
|
module, "weight", {"optim_bits": 32}
|
||||||
|
)
|
||||||
|
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
|
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.optimizer
|
||||||
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""Module for Axolotl trainer scheduler mixin"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
|
from axolotl.utils.schedulers import (
|
||||||
|
RexLR,
|
||||||
|
get_cosine_schedule_with_min_lr,
|
||||||
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
|
get_cosine_schedule_with_warmup_decay_constant,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerMixin(Trainer):
|
||||||
|
"""
|
||||||
|
Mixin class for scheduler setup in CausalTrainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||||
|
passed as an argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_training_steps (int): The number of training steps to do.
|
||||||
|
optimizer (torch.optim.Optimizer): The training optimizer
|
||||||
|
"""
|
||||||
|
use_cosine_quadratic = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.lr_quadratic_warmup is True
|
||||||
|
)
|
||||||
|
|
||||||
|
use_cosine_min_lr = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.cosine_min_lr_ratio is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||||
|
# fmt: on
|
||||||
|
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||||
|
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||||
|
pct_start = num_warmup_steps / num_training_steps
|
||||||
|
extra_lr_kwargs = {}
|
||||||
|
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["pct_start"] = pct_start
|
||||||
|
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||||
|
|
||||||
|
self.lr_scheduler = OneCycleLR(
|
||||||
|
optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
**extra_lr_kwargs,
|
||||||
|
**self.args.lr_scheduler_kwargs,
|
||||||
|
)
|
||||||
|
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
|
||||||
|
self.lr_scheduler = RexLR(
|
||||||
|
optimizer=optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
)
|
||||||
|
elif use_cosine_quadratic:
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||||
|
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
|
else:
|
||||||
|
if use_cosine_quadratic:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
131
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
131
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from datasets import Dataset
|
||||||
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
except ImportError:
|
||||||
|
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||||
|
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceParallelMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for sequence parallelism support in trainers.
|
||||||
|
|
||||||
|
This mixin provides functionality for handling sequence parallelism,
|
||||||
|
including creating appropriate samplers, managing data partitioning,
|
||||||
|
and updating ring flash attention parameters during training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def _setup_sequence_parallel(self):
|
||||||
|
"""Set up sequence parallelism environment."""
|
||||||
|
self.ring_attn_group = get_ring_attn_group()
|
||||||
|
|
||||||
|
def _create_sequence_parallel_sampler(
|
||||||
|
self,
|
||||||
|
dataset: Dataset,
|
||||||
|
shuffle: bool = True,
|
||||||
|
is_eval: bool = False,
|
||||||
|
) -> DistributedSampler:
|
||||||
|
"""
|
||||||
|
Helper method to create sampler for sequence parallelism (SP).
|
||||||
|
|
||||||
|
We create a distributed sampler with rank equal to the SP group ID, which
|
||||||
|
means that all ranks in the SP group receive the same sample / set of samples
|
||||||
|
per training step. We also set the number of replicas equal to the number of
|
||||||
|
SP groups, which is a bit of a hack / unintended use, but works!
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Dataset to sample from.
|
||||||
|
shuffle: Whether to shuffle the dataset.
|
||||||
|
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Distributed sampler.
|
||||||
|
"""
|
||||||
|
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||||
|
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
return DistributedSampler(
|
||||||
|
dataset,
|
||||||
|
num_replicas=num_sp_groups,
|
||||||
|
rank=sp_group_id,
|
||||||
|
seed=self.args.seed if shuffle else None,
|
||||||
|
shuffle=shuffle,
|
||||||
|
drop_last=not is_eval,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Get a training sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The training dataset
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self._create_sequence_parallel_sampler(
|
||||||
|
dataset,
|
||||||
|
shuffle=not self.args.curriculum_sampling,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Get an evaluation sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eval_dataset: The evaluation dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self._create_sequence_parallel_sampler(
|
||||||
|
eval_dataset, shuffle=False, is_eval=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
|
||||||
|
"""
|
||||||
|
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||||
|
the substituted ring_flash_attn. This is accomplished by using the passed
|
||||||
|
`input_ids`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Current batch of inputs.
|
||||||
|
"""
|
||||||
|
# At this point, inputs should already be partitioned by the sequence
|
||||||
|
# parallel data collator
|
||||||
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
|
seq_len = inputs["input_ids"].shape[1]
|
||||||
|
packed_seq_lens = [seq_len] * batch_size
|
||||||
|
|
||||||
|
# Calculate the full sequence length across all GPUs in this SP group
|
||||||
|
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
cu_seqlens = torch.cumsum(
|
||||||
|
torch.tensor(
|
||||||
|
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
cu_seqlens = F.pad(
|
||||||
|
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||||
43
src/axolotl/core/trainers/relora.py
Normal file
43
src/axolotl/core/trainers/relora.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""Module for ReLoRA trainer"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRATrainer(AxolotlTrainer):
|
||||||
|
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "relora"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lr_scheduler = None
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self,
|
||||||
|
num_training_steps: int,
|
||||||
|
optimizer: torch.optim.Optimizer | None = None,
|
||||||
|
):
|
||||||
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
|
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
|
||||||
|
if self.args.relora_steps:
|
||||||
|
warmup_steps = (
|
||||||
|
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||||
|
)
|
||||||
|
anneal_steps = (
|
||||||
|
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||||
|
)
|
||||||
|
self.lr_scheduler = ReLoRAScheduler(
|
||||||
|
optimizer,
|
||||||
|
lr_scheduler,
|
||||||
|
self.args.relora_steps,
|
||||||
|
anneal_steps,
|
||||||
|
warmup_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.lr_scheduler = lr_scheduler
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
@@ -1,16 +1,23 @@
|
|||||||
"""
|
"""Module for TRL PPO trainer"""
|
||||||
module for TRL PPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from trl import PPOTrainer
|
from trl import (
|
||||||
|
CPOTrainer,
|
||||||
|
KTOTrainer,
|
||||||
|
ORPOTrainer,
|
||||||
|
PPOTrainer,
|
||||||
|
PRMTrainer,
|
||||||
|
RewardTrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
class TRLPPOTrainer(PPOTrainer):
|
class TRLPPOTrainer(PPOTrainer):
|
||||||
"""
|
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||||
wrapper for ppo trainer to handle customizations
|
|
||||||
"""
|
tag_names = ["axolotl", "ppo"]
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
@@ -31,9 +38,7 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
"batch_size": 16,
|
"batch_size": 16,
|
||||||
}
|
}
|
||||||
|
|
||||||
for epoch, batch in tqdm( # pylint: disable=unused-variable
|
for _, batch in tqdm(enumerate(self.dataloader)):
|
||||||
enumerate(self.dataloader)
|
|
||||||
):
|
|
||||||
query_tensors = batch["input_ids"]
|
query_tensors = batch["input_ids"]
|
||||||
|
|
||||||
# generate model response
|
# generate model response
|
||||||
@@ -65,3 +70,43 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
rewards,
|
rewards,
|
||||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base RewardTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base trl.PRMTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "prm"]
|
||||||
|
|||||||
33
src/axolotl/core/trainers/utils.py
Normal file
33
src/axolotl/core/trainers/utils.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""Utils for Axolotl trainers"""
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||||
|
if isinstance(tag_names, str):
|
||||||
|
tag_names = [tag_names]
|
||||||
|
|
||||||
|
if kwargs is not None:
|
||||||
|
if "tags" not in kwargs:
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||||
|
kwargs["tags"].extend(tag_names)
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||||
|
tag_names.append(kwargs["tags"])
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||||
|
if isinstance(dataset_tags, str):
|
||||||
|
dataset_tags = [dataset_tags]
|
||||||
|
|
||||||
|
if (dataset_tags is not None) and (kwargs is not None):
|
||||||
|
if "dataset_tags" not in kwargs:
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||||
|
kwargs["dataset_tags"].extend(dataset_tags)
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||||
|
dataset_tags.append(kwargs["dataset_tags"])
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
|
||||||
|
return kwargs
|
||||||
@@ -207,14 +207,19 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sequence_parallel_degree: Optional[int] = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
"""
|
"""
|
||||||
Training arguments for Causal trainer
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a
|
||||||
so it can't be used as a mixin.
|
default value so it can't be used as a mixin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
89
src/axolotl/monkeypatch/attention/ring_attn.py
Normal file
89
src/axolotl/monkeypatch/attention/ring_attn.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""
|
||||||
|
Ring attention group registration and flash attention patching.
|
||||||
|
|
||||||
|
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
||||||
|
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
||||||
|
their sequence parallel version of Flash Attention 2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
|
configure_logging()
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
RING_ATTN_GROUP = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_ring_attn_group() -> dist.ProcessGroup:
|
||||||
|
"""
|
||||||
|
Getter for ring attention group on this rank.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The process group for ring attention for this rank.
|
||||||
|
"""
|
||||||
|
return RING_ATTN_GROUP
|
||||||
|
|
||||||
|
|
||||||
|
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
|
||||||
|
"""
|
||||||
|
Setter for ring attention group on this rank.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
Process group for ring attention.
|
||||||
|
"""
|
||||||
|
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||||
|
RING_ATTN_GROUP = ring_attn_group
|
||||||
|
|
||||||
|
|
||||||
|
def register_ring_attn(sequence_parallel_degree: int):
|
||||||
|
"""
|
||||||
|
Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequence_parallel_degree: Sequence parallelism factor.
|
||||||
|
"""
|
||||||
|
LOG.info(
|
||||||
|
"Enabling ring attention sequence parallelism: "
|
||||||
|
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||||
|
)
|
||||||
|
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
assert sequence_parallel_degree <= world_size, (
|
||||||
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
|
f"must be less than or equal to world_size ({world_size})"
|
||||||
|
)
|
||||||
|
assert world_size % sequence_parallel_degree == 0, (
|
||||||
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
|
f"must evenly divide world_size ({world_size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Detailed logging of group formation
|
||||||
|
rank = dist.get_rank()
|
||||||
|
group_assignments = {}
|
||||||
|
|
||||||
|
for i in range(world_size // sequence_parallel_degree):
|
||||||
|
ring_attn_ranks = list(
|
||||||
|
range(
|
||||||
|
i * sequence_parallel_degree,
|
||||||
|
(i + 1) * sequence_parallel_degree,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||||
|
|
||||||
|
# Track which GPUs are in which groups
|
||||||
|
for r in ring_attn_ranks:
|
||||||
|
group_assignments[r] = i
|
||||||
|
|
||||||
|
if rank in ring_attn_ranks:
|
||||||
|
set_ring_attn_group(group)
|
||||||
|
|
||||||
|
# Log the GPU group assignments
|
||||||
|
if rank == 0:
|
||||||
|
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||||
|
|
||||||
|
from ring_flash_attn import substitute_hf_flash_attn
|
||||||
|
|
||||||
|
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
||||||
@@ -169,7 +169,7 @@ def execute_training(
|
|||||||
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Execute the training process with appropriate backend configurations.
|
Execute the training process with appropriate SDP kernel configurations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
@@ -177,9 +177,6 @@ def execute_training(
|
|||||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||||
"""
|
"""
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
if cfg.group_by_length:
|
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with torch.backends.cuda.sdp_kernel(
|
||||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||||
|
|||||||
@@ -1,14 +1,59 @@
|
|||||||
"""
|
"""
|
||||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
|
||||||
|
includes logic for handling sequence parallelism collation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_position_ids_for_slice(
|
||||||
|
position_ids: torch.Tensor, start_idx: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
||||||
|
This handles the case where position IDs might not be contiguous due to sample
|
||||||
|
packing.
|
||||||
|
"""
|
||||||
|
# Convert to tensor if not already
|
||||||
|
# Find the boundaries between samples (where position_ids reset)
|
||||||
|
adjusted_pos_ids = position_ids.clone()
|
||||||
|
|
||||||
|
# Process each sequence in the batch
|
||||||
|
for i in range(position_ids.shape[0]):
|
||||||
|
seq = position_ids[i]
|
||||||
|
|
||||||
|
# Find sample boundaries
|
||||||
|
boundaries = []
|
||||||
|
for j in range(1, len(seq)):
|
||||||
|
if seq[j] < seq[j - 1]:
|
||||||
|
boundaries.append(j)
|
||||||
|
|
||||||
|
# No need to adjust if there are no boundaries or this is a single sample
|
||||||
|
if not boundaries:
|
||||||
|
adjusted_pos_ids[i] = seq - start_idx
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Adjust each segment separately
|
||||||
|
prev_boundary = 0
|
||||||
|
for boundary in boundaries:
|
||||||
|
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
||||||
|
prev_boundary = boundary
|
||||||
|
|
||||||
|
# Last segment
|
||||||
|
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
||||||
|
|
||||||
|
return adjusted_pos_ids
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -43,6 +88,8 @@ class DataCollatorForSeq2Seq:
|
|||||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||||
return_tensors (`str`):
|
return_tensors (`str`):
|
||||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||||
|
sequence_parallel_degree (`int`):
|
||||||
|
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
@@ -53,6 +100,16 @@ class DataCollatorForSeq2Seq:
|
|||||||
label_pad_token_id: int = -100
|
label_pad_token_id: int = -100
|
||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
|
sequence_parallel_degree: int = 1
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.sequence_parallel_degree > 1:
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
|
# Get information about our position in the SP group
|
||||||
|
sp_group = get_ring_attn_group()
|
||||||
|
self.local_rank = dist.get_rank(group=sp_group)
|
||||||
|
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
labels = None
|
labels = None
|
||||||
@@ -119,8 +176,43 @@ class DataCollatorForSeq2Seq:
|
|||||||
)
|
)
|
||||||
features["decoder_input_ids"] = decoder_input_ids
|
features["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
|
if self.sequence_parallel_degree > 1:
|
||||||
|
features = self.apply_sequence_parallelism(features)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
def apply_sequence_parallelism(
|
||||||
|
self, batch: dict[str, torch.Tensor]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply sequence parallelism slicing to a batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: Batch dictionary from parent collator.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sliced batch dictionary.
|
||||||
|
"""
|
||||||
|
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||||
|
|
||||||
|
for key in keys_to_slice:
|
||||||
|
if key in batch:
|
||||||
|
seq_len = batch[key].shape[1]
|
||||||
|
slice_size = seq_len // self.local_world_size
|
||||||
|
start_idx = self.local_rank * slice_size
|
||||||
|
end_idx = (
|
||||||
|
start_idx + slice_size
|
||||||
|
if self.local_rank < self.local_world_size - 1
|
||||||
|
else seq_len
|
||||||
|
)
|
||||||
|
batch[key] = batch[key][:, start_idx:end_idx]
|
||||||
|
|
||||||
|
# Special handling for position_ids
|
||||||
|
if key == "position_ids" and self.local_rank > 0:
|
||||||
|
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
@@ -148,6 +240,7 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
@@ -177,6 +270,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,9 @@ def normalize_config(cfg):
|
|||||||
with open(ds_config_path, encoding="utf-8") as f:
|
with open(ds_config_path, encoding="utf-8") as f:
|
||||||
cfg.deepspeed = json.load(f)
|
cfg.deepspeed = json.load(f)
|
||||||
|
|
||||||
|
if cfg.sequence_parallel_degree is None:
|
||||||
|
cfg.sequence_parallel_degree = 1
|
||||||
|
|
||||||
if cfg.saves_per_epoch:
|
if cfg.saves_per_epoch:
|
||||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||||
if save_steps < 1.0: # prevent saves on every step
|
if save_steps < 1.0: # prevent saves on every step
|
||||||
|
|||||||
@@ -67,7 +67,12 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap
|
|||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MULTIMODEL_AUTO_MODEL_MAPPING = {
|
||||||
|
"llava": LlavaForConditionalGeneration,
|
||||||
|
"mllama": MllamaForConditionalGeneration,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# copied from accelerator.FullyShardedDataParallelPlugin
|
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||||
@@ -476,7 +481,7 @@ class ModelLoader:
|
|||||||
else:
|
else:
|
||||||
self.text_model_config = self.model_config
|
self.text_model_config = self.model_config
|
||||||
|
|
||||||
self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
|
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
|
||||||
def apply_patches(self) -> None:
|
def apply_patches(self) -> None:
|
||||||
# load any patches from plugins
|
# load any patches from plugins
|
||||||
@@ -547,6 +552,14 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
|
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||||
|
|
||||||
|
# Initialize ring attn for sequence parallelism. This must be done after
|
||||||
|
# model init but before the first forward pass, since it modifies flash
|
||||||
|
# attn to use ring comm for SP training across multiple GPUs.
|
||||||
|
register_ring_attn(self.cfg.sequence_parallel_degree)
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||||
@@ -603,7 +616,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora()
|
patch_self_attn_lora()
|
||||||
|
|
||||||
def patch_llama_derived_model(self) -> None:
|
def patch_llama_derived_model(self):
|
||||||
"""Modify all llama derived models in one block"""
|
"""Modify all llama derived models in one block"""
|
||||||
self.patch_loss_llama()
|
self.patch_loss_llama()
|
||||||
|
|
||||||
@@ -653,24 +666,15 @@ class ModelLoader:
|
|||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_auto_model_loader(self) -> None:
|
def set_auto_model_loader(self):
|
||||||
"""set self.AutoModelLoader
|
"""
|
||||||
- default value: AutoModelForCausalLM (set at __init__)
|
Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
|
||||||
- when using a multi modality model, self.AutoModelLoader should
|
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
||||||
be set according to model type of the model
|
should be set according to the type of the model.
|
||||||
"""
|
"""
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
if self.model_config.model_type == "llava":
|
self.auto_model_loader = MULTIMODEL_AUTO_MODEL_MAPPING.get(
|
||||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
self.model_config.model_type, AutoModelForVision2Seq
|
||||||
LlavaForConditionalGeneration
|
|
||||||
)
|
|
||||||
elif self.model_config.model_type == "mllama":
|
|
||||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
|
||||||
MllamaForConditionalGeneration
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.AutoModelLoader = (
|
|
||||||
AutoModelForVision2Seq # pylint: disable=invalid-name
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_device_map_config(self) -> None:
|
def set_device_map_config(self) -> None:
|
||||||
@@ -695,7 +699,7 @@ class ModelLoader:
|
|||||||
from accelerate import infer_auto_device_map
|
from accelerate import infer_auto_device_map
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model_canvas = self.AutoModelLoader.from_config(
|
model_canvas = self.auto_model_loader.from_config(
|
||||||
self.model_config,
|
self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
)
|
)
|
||||||
@@ -916,7 +920,23 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
|
||||||
|
# Load model with random initialization if specified
|
||||||
|
if self.cfg.random_init_weights:
|
||||||
|
# AutoModel classes support the from_config method
|
||||||
|
if self.auto_model_loader in [
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForVision2Seq,
|
||||||
|
]:
|
||||||
|
self.model = self.auto_model_loader.from_config(
|
||||||
|
config=self.model_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = self.auto_model_loader(
|
||||||
|
config=self.model_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
@@ -958,7 +978,7 @@ class ModelLoader:
|
|||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -991,7 +1011,7 @@ class ModelLoader:
|
|||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1011,7 +1031,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1307,7 +1327,7 @@ def load_model(
|
|||||||
"""
|
"""
|
||||||
Load a model for a given configuration and tokenizer.
|
Load a model for a given configuration and tokenizer.
|
||||||
"""
|
"""
|
||||||
loader = ModelLoader(
|
model_loader = ModelLoader(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
@@ -1315,7 +1335,7 @@ def load_model(
|
|||||||
reference_model=reference_model,
|
reference_model=reference_model,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return loader.load_model()
|
return model_loader.load_model()
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(model, cfg, adapter, inference=False):
|
||||||
|
|||||||
@@ -104,9 +104,7 @@ def allocate(
|
|||||||
|
|
||||||
|
|
||||||
class MultipackBatchSampler(BatchSampler):
|
class MultipackBatchSampler(BatchSampler):
|
||||||
"""
|
"""Batch sampler class for multipack"""
|
||||||
Batch Sampler class for multipack
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Main Axolotl input configuration Pydantic models"""
|
"""Module with Pydantic models for configuration."""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
@@ -245,6 +245,8 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
val_set_size: float | None = Field(default=0.0)
|
val_set_size: float | None = Field(default=0.0)
|
||||||
|
|
||||||
|
sequence_parallel_degree: int | None = None
|
||||||
|
|
||||||
special_tokens: SpecialTokensConfig | None = None
|
special_tokens: SpecialTokensConfig | None = None
|
||||||
tokens: list[str] | None = None
|
tokens: list[str] | None = None
|
||||||
added_tokens_overrides: dict[int, str] | None = None
|
added_tokens_overrides: dict[int, str] | None = None
|
||||||
@@ -1102,6 +1104,29 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@field_validator("sequence_parallel_degree", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_sequence_parallel_config(cls, value, info):
|
||||||
|
if not value:
|
||||||
|
value = 1
|
||||||
|
|
||||||
|
if value > 1:
|
||||||
|
if not info.data.get("flash_attention"):
|
||||||
|
raise ValueError(
|
||||||
|
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||||
|
except ImportError as exception:
|
||||||
|
raise ImportError(
|
||||||
|
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||||
|
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||||
|
"or `pip install ring-flash-attn>=0.1.4`."
|
||||||
|
) from exception
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||||
|
|||||||
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
elif cfg.sample_packing:
|
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||||
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
**filter_map_kwargs,
|
**filter_map_kwargs,
|
||||||
**drop_long_kwargs,
|
**drop_long_kwargs,
|
||||||
)
|
)
|
||||||
if cfg.eval_sample_packing is not False:
|
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
|
* cfg.sequence_parallel_degree
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
||||||
@@ -473,7 +474,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||||
# on the agreed on value for sample_packing_eff_est
|
# on the agreed on value for sample_packing_eff_est
|
||||||
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
total_num_steps = int(
|
||||||
|
math.floor(
|
||||||
|
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||||
@@ -494,7 +499,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(
|
||||||
|
len(train_dataset)
|
||||||
|
* cfg.num_epochs
|
||||||
|
* cfg.sequence_parallel_degree
|
||||||
|
/ cfg.batch_size
|
||||||
|
)
|
||||||
)
|
)
|
||||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
||||||
return total_num_steps
|
return total_num_steps
|
||||||
|
|||||||
207
tests/e2e/patched/test_sp.py
Normal file
207
tests/e2e/patched/test_sp.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""Tests for sequence parallelism functionality."""
|
||||||
|
|
||||||
|
# pylint: disable=redefined-outer-name,unused-argument
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from accelerate.state import PartialState
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
# Use a single patch for ring_flash_attn if it's not available
|
||||||
|
ring_flash_attn_mock = MagicMock()
|
||||||
|
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def partial_state():
|
||||||
|
"""Create a real PartialState instance for testing."""
|
||||||
|
state = PartialState()
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="cfg")
|
||||||
|
def fixture_cfg():
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"learning_rate": 1e-3,
|
||||||
|
"output_dir": "./model-out",
|
||||||
|
"sequence_len": 512,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceParallelHelpers:
|
||||||
|
"""Test helper functions used in sequence parallelism."""
|
||||||
|
|
||||||
|
def test_adjust_position_ids_for_slice(self, partial_state):
|
||||||
|
"""Test position_ids adjustment for sequence slices."""
|
||||||
|
# Create sample position_ids with multiple sequences
|
||||||
|
position_ids = torch.tensor(
|
||||||
|
[
|
||||||
|
# First sequence with 2 samples
|
||||||
|
[0, 1, 2, 3, 4, 0, 1, 2, 3],
|
||||||
|
# Second sequence with 3 samples
|
||||||
|
[0, 1, 2, 0, 1, 2, 3, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Adjust as if this was the second slice (start_idx = 4)
|
||||||
|
adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
|
||||||
|
|
||||||
|
# For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
|
||||||
|
# For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
|
||||||
|
expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
|
||||||
|
expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
|
||||||
|
|
||||||
|
assert torch.all(adjusted[0] == expected_first_seq)
|
||||||
|
assert torch.all(adjusted[1] == expected_second_seq)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRingAttention:
|
||||||
|
"""Tests for the ring attention functionality."""
|
||||||
|
|
||||||
|
@patch("torch.distributed.new_group")
|
||||||
|
@patch("torch.distributed.get_rank")
|
||||||
|
@patch("torch.distributed.get_world_size")
|
||||||
|
def test_register_ring_attn(
|
||||||
|
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
||||||
|
):
|
||||||
|
"""Test that ring attention groups are created correctly."""
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_world_size.return_value = 8 # 8 GPUs total
|
||||||
|
mock_rank.return_value = 3 # GPU #3
|
||||||
|
mock_group = MagicMock()
|
||||||
|
mock_new_group.return_value = mock_group
|
||||||
|
|
||||||
|
# Call register_ring_attn with size 4
|
||||||
|
register_ring_attn(sequence_parallel_degree=4)
|
||||||
|
|
||||||
|
# Verify the number of calls without examining the arguments
|
||||||
|
assert mock_new_group.call_count == 2
|
||||||
|
|
||||||
|
# Just verify that new_group was called
|
||||||
|
mock_new_group.assert_called()
|
||||||
|
|
||||||
|
@patch("torch.distributed.get_rank")
|
||||||
|
@patch("torch.distributed.get_world_size")
|
||||||
|
def test_get_ring_attn_group_no_registration(
|
||||||
|
self, mock_world_size, mock_rank, partial_state
|
||||||
|
):
|
||||||
|
"""Test that get_ring_attn_group returns None when no group has been registered."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_world_size.return_value = 4
|
||||||
|
mock_rank.return_value = 0
|
||||||
|
|
||||||
|
# Get the group without registration
|
||||||
|
group = get_ring_attn_group()
|
||||||
|
|
||||||
|
# Verify that None was returned
|
||||||
|
assert group is None
|
||||||
|
|
||||||
|
|
||||||
|
# Mock a simplified DataCollator test
|
||||||
|
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
|
||||||
|
@patch("torch.distributed.get_rank")
|
||||||
|
@patch("torch.distributed.get_world_size")
|
||||||
|
def test_sequence_parallel_slicing(
|
||||||
|
mock_world_size, mock_rank, mock_get_group, partial_state
|
||||||
|
):
|
||||||
|
"""Test the basic sequence slicing logic without full collator instantiation."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_get_group.return_value = MagicMock()
|
||||||
|
mock_rank.return_value = 1 # Second GPU
|
||||||
|
mock_world_size.return_value = 4 # 4 GPUs total
|
||||||
|
|
||||||
|
# Create a sample batch
|
||||||
|
batch = {
|
||||||
|
"input_ids": torch.tensor(
|
||||||
|
[
|
||||||
|
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
|
||||||
|
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"attention_mask": torch.ones(2, 12),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Simplified slicing logic from SequenceParallelDataCollator
|
||||||
|
def slice_batch(batch, rank, world_size):
|
||||||
|
result = {}
|
||||||
|
for key in batch:
|
||||||
|
seq_len = batch[key].shape[1]
|
||||||
|
slice_size = seq_len // world_size
|
||||||
|
start_idx = rank * slice_size
|
||||||
|
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
|
||||||
|
result[key] = batch[key][:, start_idx:end_idx]
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Slice the batch
|
||||||
|
result = slice_batch(
|
||||||
|
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check slicing
|
||||||
|
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
|
||||||
|
expected_input_ids = torch.tensor(
|
||||||
|
[
|
||||||
|
[104, 105, 106], # Second slice of first sequence
|
||||||
|
[204, 205, 206], # Second slice of second sequence
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert torch.all(result["input_ids"] == expected_input_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
|
||||||
|
def test_config_validation_with_valid_inputs(cfg):
|
||||||
|
"""Test that valid sequence parallelism configurations pass validation."""
|
||||||
|
# Import the actual model class with appropriate mocks
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
|
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
|
||||||
|
cfg = cfg | {
|
||||||
|
"sequence_parallel_degree": 2,
|
||||||
|
"flash_attention": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Should validate without errors
|
||||||
|
config = AxolotlInputConfig(**cfg)
|
||||||
|
assert config.sequence_parallel_degree == 2
|
||||||
|
assert config.flash_attention is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_validation_with_invalid_inputs(cfg):
|
||||||
|
"""Test that invalid sequence parallelism configurations fail validation."""
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
|
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
|
||||||
|
cfg = cfg | {
|
||||||
|
"sequence_parallel_degree": 2,
|
||||||
|
"flash_attention": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Should raise ValidationError
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
AxolotlInputConfig(**cfg)
|
||||||
|
|
||||||
|
# Verify error message
|
||||||
|
assert "flash_attention: true must be set" in str(excinfo.value)
|
||||||
@@ -12,6 +12,7 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||||
@@ -262,6 +263,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||||
self.cfg_1 = DictDefault(
|
self.cfg_1 = DictDefault(
|
||||||
{
|
{
|
||||||
|
"base_model": "huggyllama/llama-7b",
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"dataset_exact_deduplication": True,
|
"dataset_exact_deduplication": True,
|
||||||
@@ -282,6 +284,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
normalize_config(self.cfg_1)
|
||||||
|
|
||||||
def test_prepare_dataset_with_deduplication_train(self):
|
def test_prepare_dataset_with_deduplication_train(self):
|
||||||
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
||||||
|
|||||||
Reference in New Issue
Block a user