Compare commits


22 Commits

Author      SHA1        Message  Date
Wing Lian   b708a1cc45  validate config to set defaults  2025-04-26 13:11:25 -04:00
Rahul Tuli  daa9a58f83  Add: line about further optimizations using llmcompressor  2025-04-24 14:06:25 -04:00
            Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Rahul Tuli  ae7069e15b  Merge branch 'main' into llmcompressor-sft  2025-04-24 12:37:14 -05:00
Rahul Tuli  20d48cd617  Address Review Comments:  2025-04-24 13:36:09 -04:00
            * deleted redundant docs/llm_compressor.qmd
            * incorporated feedback in integration README.md
            * added llmcompressor integration to docs/custom_integrations.qmd
            Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Rahul Tuli  e766a730ba  Add: .qmd file  2025-04-24 12:45:57 -04:00
Rahul Tuli  7dc797860e  Tests, Style, Updates  2025-04-24 12:45:57 -04:00
Rahul Tuli  ff4904c8c4  Rebase and updates!  2025-04-24 12:45:57 -04:00
Rahul Tuli  45b7293793  Add: llm_compressor integration documentation  2025-04-24 12:45:57 -04:00
Rahul Tuli  279c7178bc  Move: LLMCompressorPlugin into it's own submodule  2025-04-24 12:45:57 -04:00
Rahul Tuli  e73c3709f9  Update model config  2025-04-24 12:45:57 -04:00
Rahul Tuli  33562189f8  Use: absolute import  2025-04-24 12:45:57 -04:00
Rahul Tuli  c057a2268f  Rename: sft.yaml to sparse-finetuning.yaml  2025-04-24 12:45:57 -04:00
Rahul Tuli  9d7a3809b5  Add: llcompressor installable  2025-04-24 12:45:57 -04:00
Rahul Tuli  b7b24d6a64  Address review comments from @markurtz  2025-04-24 12:45:57 -04:00
Rahul Tuli  8b82b8f7a1  Apply suggestions from @markurtz  2025-04-24 12:45:57 -04:00
            Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
Rahul Tuli  81da58c0a1  Update llmcompressor version to latest  2025-04-24 12:45:57 -04:00
Rahul Tuli  2cd5a234a7  Revert: TODO's  2025-04-24 12:45:57 -04:00
Rahul Tuli  8c1af0747d  Use: warning over warn  2025-04-24 12:45:57 -04:00
Rahul Tuli  a06b360d99  pre commit hooks  2025-04-24 12:45:57 -04:00
Rahul Tuli  0f6456a14f  Add:llmcompressor instalable  2025-04-24 12:45:57 -04:00
Rahul Tuli  47a333ce49  Update: review comments!  2025-04-24 12:45:57 -04:00
Rahul Tuli  f9d6776c28  Add: SFTPlugin with llmcompressor  2025-04-24 12:45:57 -04:00
30 changed files with 259 additions and 843 deletions

View File

@@ -24,7 +24,7 @@ jobs:
           cuda_version: 12.4.1
           python_version: "3.11"
           pytorch: 2.5.1
-          axolotl_extras:
+          axolotl_extras: vllm
         - cuda: 124
           cuda_version: 12.4.1
           python_version: "3.11"

View File

@@ -8,7 +8,6 @@ on:
       - 'setup.py'
       - 'pyproject.toml'
       - '.github/workflows/multi-gpu-e2e.yml'
-      - 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
   workflow_dispatch:
   schedule:
     - cron: '0 0 * * 1,4'  # Runs at 00:00 UTC every monday & thursday
@@ -43,7 +42,7 @@ jobs:
           cuda_version: 12.4.1
           python_version: "3.11"
           pytorch: 2.5.1
-          axolotl_extras:
+          axolotl_extras: vllm
           num_gpus: 2
           nightly_build: "true"
        - cuda: 126

View File

@@ -258,12 +258,6 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.6.0
-            num_gpus: 1
-            axolotl_extras: llmcompressor
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
@@ -275,7 +269,7 @@
             python_version: "3.11"
             pytorch: 2.5.1
             num_gpus: 1
-            axolotl_extras:
+            axolotl_extras: vllm
           - cuda: 126
             cuda_version: 12.6.3
             python_version: "3.11"

View File

@@ -52,4 +52,4 @@ pytest -v --durations=10 \
   --cov-append \
   --cov-report=xml:e2e-coverage.xml
-codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true
+codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}

View File

@@ -20,4 +20,4 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
   --cov-report=xml:multigpu-coverage.xml
 # Upload coverage to Codecov
-codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
+codecov upload-process -t $CODECOV_TOKEN -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}

View File

@@ -1,7 +1,5 @@
 codecov:
   require_ci_to_pass: yes
-  notify:
-    wait_for_ci: true

 coverage:
   precision: 2

View File

@@ -28,8 +28,6 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
 Tags examples:

-- `main-base-py3.11-cu128-2.7.0`
-- `main-base-py3.11-cu126-2.7.0`
 - `main-base-py3.11-cu124-2.6.0`
 - `main-base-py3.11-cu124-2.5.1`
 - `main-base-py3.11-cu124-2.4.1`
@@ -52,7 +50,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
 # on push to main
 main-py{python_version}-cu{cuda_version}-{pytorch_version}

-# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
+# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
 main-latest

 # nightly build
@@ -70,7 +68,6 @@ There may be some extra tags appended to the image, like `-vllm` which installs
 Tags examples:

-- `main-py3.11-cu126-2.7.0`
 - `main-py3.11-cu124-2.6.0`
 - `main-py3.11-cu124-2.5.1`
 - `main-py3.11-cu124-2.4.1`

View File

@@ -10,6 +10,7 @@ plugins:
 liger_glu_activation: true
 liger_rms_norm: true
 liger_layer_norm: true
+cut_cross_entropy: true

 llama4_linearized_experts: true  # needed with custom linearized experts model
 load_in_4bit: true

View File

@@ -11,13 +11,13 @@ liger-kernel==0.5.8
 packaging==23.2
-peft==0.15.2
+peft==0.15.1
 transformers==4.51.3
 tokenizers>=0.21.1
 accelerate==1.6.0
 datasets==3.5.0
 deepspeed>=0.15.4
-trl==0.17.0
+trl==0.16.1
 hf_xet==1.0.0
 hqq==0.2.5

View File

@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
     if (major, minor) >= (2, 7):
         _install_requires.pop(_install_requires.index(xformers_version))
         # _install_requires.append("xformers==0.0.29.post3")  # xformers seems to be hard pinned to 2.6.0
-        extras_require_map["vllm"] = ["vllm==0.8.4"]
+        extras_require_map["vllm"] = ["vllm==0.8.3"]
     elif (major, minor) >= (2, 6):
         _install_requires.pop(_install_requires.index(xformers_version))
         _install_requires.append(
             "xformers==0.0.29.post2"
         )  # vllm needs post2 w torch 2.6
-        extras_require_map["vllm"] = ["vllm==0.8.4"]
+        extras_require_map["vllm"] = ["vllm==0.8.3"]
     elif (major, minor) >= (2, 5):
         _install_requires.pop(_install_requires.index(xformers_version))
         if patch == 0:

View File

@@ -932,6 +932,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             collator = DataCollatorForSeq2Seq
             kwargs["return_tensors"] = "pt"
+            if issubclass(collator, DataCollatorForSeq2Seq):
+                kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
+                kwargs["ring_attn_func"] = training_args.ring_attn_func

         return collator(
             *collator_args,
@@ -1048,9 +1051,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.rpo_alpha is not None:
             training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha

-        if self.cfg.use_wandb:
-            training_args_kwargs["run_name"] = self.cfg.wandb_name
-
         training_args_cls = None
         blocklist_args_kwargs = []
         if self.cfg.rl == "simpo":
@@ -1121,12 +1121,6 @@
             **training_args_kwargs,
         )

-        # unset run_name so wandb sets up experiment names
-        if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
-            training_args.run_name = (  # pylint: disable=attribute-defined-outside-init
-                None
-            )
-
         return training_args

    def build(self, total_num_steps):
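
Note: the first hunk above threads `sequence_parallel_degree` and `ring_attn_func` into the collator kwargs only when the selected collator derives from Axolotl's `DataCollatorForSeq2Seq`. A minimal standalone sketch of that guard pattern, with illustrative class names rather than the real collator hierarchy:

```python
from dataclasses import dataclass


@dataclass
class PlainCollator:
    return_tensors: str = "pt"


@dataclass
class SeqParallelCollator(PlainCollator):
    sequence_parallel_degree: int = 1


def build_collator(collator_cls, **kwargs):
    # Only SP-aware collators accept the extra kwargs; passing them to a
    # plain collator would raise TypeError at construction time.
    if issubclass(collator_cls, SeqParallelCollator):
        kwargs["sequence_parallel_degree"] = 2
    return collator_cls(**kwargs)


print(build_collator(SeqParallelCollator))  # sequence_parallel_degree=2 injected
print(build_collator(PlainCollator))        # no SP kwargs injected
```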

View File

@@ -371,15 +371,13 @@ class AxolotlTrainer(
                 num_items_in_batch=num_items_in_batch,
             )

-        loss = super().compute_loss(
+        return super().compute_loss(
             model,
             inputs,
             return_outputs=return_outputs,
             num_items_in_batch=num_items_in_batch,
         )
-        return loss

     @staticmethod
     def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
         concatenated_batch = {}

View File

@@ -135,9 +135,7 @@ class GRPOStrategy:
             try:
                 # use importlib to dynamically load the reward function from the module
                 reward_func_module_name = reward_func_fqn.split(".")[-1]
-                reward_func_module = importlib.import_module(
-                    ".".join(reward_func_fqn.split(".")[:-1])
-                )
+                reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
                 reward_func = getattr(reward_func_module, reward_func_module_name)
                 if not len(inspect.signature(reward_func).parameters) >= 2:
                     raise ValueError(
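
Note: the two sides resolve the reward function's module differently. The outgoing version imports the full parent path, which handles nested packages; the incoming version imports only the second-to-last path segment, which must then be importable as a top-level module. A quick stdlib illustration of the difference (`os.path.join` stands in for a reward-function path):

```python
import importlib

fqn = "os.path.join"  # fully-qualified name of the target function

# Full parent path: "os.path" resolves even though it is a nested module
mod = importlib.import_module(".".join(fqn.split(".")[:-1]))
assert getattr(mod, fqn.split(".")[-1]) is importlib.import_module("os").path.join

# Second-to-last segment alone: bare "path" is not importable by itself
# in a standard environment
try:
    importlib.import_module(fqn.split(".")[-2])
except ModuleNotFoundError as err:
    print(err)  # No module named 'path'
```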

View File

@@ -6,4 +6,4 @@
 from .optimizer import OptimizerMixin
 from .rng_state_loader import RngLoaderMixin
 from .scheduler import SchedulerMixin
-from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin
+from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,86 +1,16 @@
-"""
-Module for Axolotl trainer sequence parallelism mixin and training context manager
-"""
+"""Module for Axolotl trainer sequence parallelism mixin"""

-import functools
 import logging

-import torch
 import torch.distributed as dist
 from datasets import Dataset
-from torch import nn
 from torch.utils.data import DistributedSampler, Sampler
-from torch.utils.hooks import RemovableHandle

-from axolotl.monkeypatch.attention.ring_attn import (
-    RingAttnFunc,
-    get_ring_attn_group,
-    update_ring_attn_params,
-)
+from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group

 LOG = logging.getLogger(__name__)


-def apply_sequence_parallelism(
-    batch: dict[str, torch.Tensor],
-    local_rank: int,
-    local_world_size: int,
-    ring_attn_func: RingAttnFunc,
-) -> dict[str, torch.Tensor]:
-    """
-    Apply sequence parallelism slicing to a batch.
-
-    Args:
-        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
-        local_rank: Local rank in the sequence parallel group
-        local_world_size: World size of the sequence parallel group
-        ring_attn_func: The ring attention function to use
-
-    Returns:
-        Sliced batch dictionary.
-    """
-    # Update ring attention params if needed
-    if batch.get("position_ids") is not None:
-        update_ring_attn_params(position_ids=batch["position_ids"])
-
-    # Slice batch for sequence parallel processing
-    total_seq_len = batch["input_ids"].size(1)
-    for key in batch:
-        if (
-            key in batch
-            and isinstance(batch[key], torch.Tensor)
-            and batch[key].dim() > 1
-            and batch[key].size(1) == total_seq_len
-        ):
-            if ring_attn_func in [
-                RingAttnFunc.VARLEN_LLAMA3,
-                RingAttnFunc.BATCH_RING,
-            ]:
-                # Split in sequential fashion and grab this rank's chunk
-                batch[key] = (
-                    batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
-                )
-            elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
-                chunks = batch[key].chunk(2 * local_world_size, dim=1)
-
-                # Take rank's chunk and opposing chunk for zigzag pattern
-                selected_chunks = [
-                    chunks[local_rank],
-                    chunks[2 * local_world_size - local_rank - 1],
-                ]
-                batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
-            elif ring_attn_func is RingAttnFunc.BATCH_STRIPE:
-                # Split into striped data and stack
-                tensor = torch.stack(
-                    batch[key].split(local_world_size, dim=1),
-                    dim=1,
-                ).transpose(1, 2)
-                batch[key] = tensor[:, local_rank].contiguous()
-
-    return batch
-
-
 class SequenceParallelMixin:
     """
     Mixin class for sequence parallelism support in trainers.
@@ -157,157 +87,3 @@ class SequenceParallelMixin:
         return self._create_sequence_parallel_sampler(
             eval_dataset, shuffle=False, is_eval=True
         )
-
-
-class SequenceParallelContextManager:
-    """
-    Context manager for sequence parallelism operations.
-
-    This class provides a context that will automatically apply sequence parallelism
-    during model forward passes using a pre-forward hook, and gather outputs from
-    across the sequence parallelism group using a post-forward hook.
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        sequence_parallel_degree: int,
-        ring_attn_func: RingAttnFunc,
-    ):
-        self.model = model
-        self.sequence_parallel_degree = sequence_parallel_degree
-        self.ring_attn_func = ring_attn_func
-        self.process_group = get_ring_attn_group()
-
-        # Initialize sequence parallel group details
-        self.local_rank = dist.get_rank(self.process_group)
-        self.local_world_size = dist.get_world_size(self.process_group)
-
-        # Will store hook handles for removal
-        self.hook_handles: list[RemovableHandle] = []
-
-        # Create a partially applied version of the apply_sequence_parallelism function
-        # with pre-configured params
-        self.apply_sequence_parallelism = functools.partial(
-            apply_sequence_parallelism,
-            local_rank=self.local_rank,
-            local_world_size=self.local_world_size,
-            ring_attn_func=self.ring_attn_func,
-        )
-
-    def __enter__(self):
-        # Forward pre-hook to apply sequence parallelism
-        def sequence_parallel_pre_hook(_, args, kwargs):
-            # Apply sequence parallelism to kwargs
-            kwargs = self.apply_sequence_parallelism(batch=kwargs)
-            return args, kwargs
-
-        # Forward post-hook to gather outputs
-        def sequence_parallel_post_hook(_, __, output):
-            # Gather the sharded outputs
-            return self.gather_outputs(output)
-
-        # Register both hooks
-        self.hook_handles.append(
-            self.model.register_forward_pre_hook(
-                sequence_parallel_pre_hook, with_kwargs=True
-            )
-        )
-        self.hook_handles.append(
-            self.model.register_forward_hook(sequence_parallel_post_hook)
-        )
-
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        # Remove all hooks
-        for handle in self.hook_handles:
-            handle.remove()
-        self.hook_handles = []
-
-    def gather_outputs(self, output):
-        """Gather sharded outputs from all ranks and reconstruct the full tensor."""
-        # Handle different output formats (dict, tensor, etc.)
-        if isinstance(output, dict):
-            gathered_output = {}
-            for key, value in output.items():
-                if isinstance(value, torch.Tensor) and value.dim() > 1:
-                    # Gather logits or other sequence-sharded tensors
-                    gathered_value = self.gather_tensor(value)
-                    gathered_output[key] = gathered_value
-                else:
-                    gathered_value = value.clone()
-                    dist.all_reduce(
-                        gathered_value, op=dist.ReduceOp.SUM, group=self.process_group
-                    )
-                    gathered_output[key] = gathered_value
-            return gathered_output
-        if isinstance(output, torch.Tensor):
-            return self.gather_tensor(output)
-        return output
-
-    def gather_tensor(self, tensor):
-        """Gather a sharded tensor from all ranks."""
-        # Prepare tensors for all_gather
-        world_size = self.local_world_size
-
-        # Create list to store tensors from all ranks
-        gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
-
-        # All-gather operation
-        dist.all_gather(gathered_tensors, tensor, group=self.process_group)
-
-        # Concatenate along sequence dimension (typically dim=1)
-        if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
-            # Simple concatenation for standard sharding
-            return torch.cat(gathered_tensors, dim=1)
-        if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
-            # Each rank has a pattern of (rank, world_size*2-rank-1)
-            reconstituted_tensors = [None] * (world_size * 2)
-
-            # First, split each gathered tensor into its two chunks
-            for rank, gathered_tensor in enumerate(gathered_tensors):
-                # Each tensor contains two chunks in the sequence dimension
-                chunk_size = gathered_tensor.size(1) // 2
-                chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1)
-
-                # Place chunks in their original positions
-                reconstituted_tensors[rank] = chunk1
-                reconstituted_tensors[world_size * 2 - rank - 1] = chunk2
-
-            # Concatenate the reconstituted tensors in the correct order
-            return torch.cat(reconstituted_tensors, dim=1)
-
-        # Otherwise, RingAttnFunc.BATCH_STRIPE
-        # In striping, each rank has every world_size-th slice
-        batch_size = tensor.size(0)
-        hidden_dim = tensor.size(-1)
-
-        # First, determine the full sequence length
-        total_seq_len = 0
-        for t in gathered_tensors:
-            total_seq_len += t.size(1)
-
-        # Create a tensor to hold the unstriped result
-        result = torch.zeros(
-            batch_size,
-            total_seq_len,
-            hidden_dim,
-            dtype=tensor.dtype,
-            device=tensor.device,
-        )
-
-        # For each rank's tensor, distribute its slices to the correct positions
-        for rank, gathered_tensor in enumerate(gathered_tensors):
-            # The rank's tensor contains every world_size-th slice
-            # starting from its rank position
-            seq_len = gathered_tensor.size(1)
-            for i in range(seq_len):
-                # Calculate the position in the full tensor
-                pos = i * world_size + rank
-                if pos < total_seq_len:
-                    result[:, pos] = gathered_tensor[:, i]
-
-        return result
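
Note: the deleted `SequenceParallelContextManager` pairs the zigzag slicing in `apply_sequence_parallelism` with an inverse reassembly in `gather_tensor`. A single-process sketch of that round trip on plain tensors (no process group; `world_size` and the toy sequence are illustrative):

```python
import torch

world_size = 2
seq = torch.arange(8).reshape(1, 8)  # one sequence of 8 tokens


def zigzag_slice(x: torch.Tensor, rank: int) -> torch.Tensor:
    # Rank r keeps chunk r plus its mirror chunk 2 * world_size - r - 1
    chunks = x.chunk(2 * world_size, dim=1)
    return torch.cat([chunks[rank], chunks[2 * world_size - rank - 1]], dim=1)


shards = [zigzag_slice(seq, r) for r in range(world_size)]
# shards[0] == [[0, 1, 6, 7]], shards[1] == [[2, 3, 4, 5]]

# Reassembly: split each shard back into its two chunks and re-place them,
# mirroring the loop in gather_tensor above
slots = [None] * (2 * world_size)
for rank, shard in enumerate(shards):
    first, second = shard.split(shard.size(1) // 2, dim=1)
    slots[rank] = first
    slots[2 * world_size - rank - 1] = second

assert torch.equal(torch.cat(slots, dim=1), seq)
```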

View File

@@ -27,6 +27,8 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
 ```yaml
 plugins:
   - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+cut_cross_entropy: true
 ```

 ## Supported Models

View File

@@ -28,7 +28,7 @@ class CutCrossEntropyArgs(BaseModel):
     Input args for Cut Cross Entropy.
     """

-    cut_cross_entropy: Optional[bool] = True
+    cut_cross_entropy: Optional[bool] = None

     @model_validator(mode="before")
     @classmethod

View File

@@ -6,7 +6,6 @@ import os
 import signal
 import sys
 import weakref
-from contextlib import nullcontext
 from pathlib import Path
 from typing import Any, Dict
@@ -26,9 +25,6 @@ from axolotl.contribs.lgpl import (  # pylint: disable = no-name-in-module
     fix_untrained_tokens,
 )
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
-from axolotl.core.trainers.mixins.sequence_parallel import (
-    SequenceParallelContextManager,
-)
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed
@@ -189,28 +185,16 @@
         trainer: The configured trainer object.
         resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
     """
-    # Define the context managers to use
-    flash_context = (
-        torch.backends.cuda.sdp_kernel(
-            enable_flash=True,
-            enable_math=True,
-            enable_mem_efficient=True,
-        )
-        if cfg.flash_optimum
-        else nullcontext()
-    )
-    sequence_parallel_context = (
-        SequenceParallelContextManager(
-            model=trainer.model,
-            sequence_parallel_degree=cfg.sequence_parallel_degree,
-            ring_attn_func=cfg.ring_attn_func,
-        )
-        if cfg.sequence_parallel_degree > 1
-        else nullcontext()
-    )
-
-    LOG.info("Starting trainer...")
-    with flash_context, sequence_parallel_context:
+    LOG.info("Starting trainer...")
+    if cfg.flash_optimum:
+        with torch.backends.cuda.sdp_kernel(
+            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
+            enable_flash=True,
+            enable_math=True,
+            enable_mem_efficient=True,
+        ):
+            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
@@ -287,19 +271,7 @@
             os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
         except FileNotFoundError:
             pass
-    elif cfg.local_rank == 0:
-        if cfg.flash_optimum and BetterTransformer:
-            model = BetterTransformer.reverse(model)
-
-        if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
-            trainer.model.save_pretrained(
-                cfg.output_dir, safe_serialization=safe_serialization
-            )
-        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
-
-    if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
-        # TODO: add integration support so this can be implemented completely within the plugin
+    elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
         from axolotl.integrations.llm_compressor.utils import (
             save_compressed_model,
         )
@@ -312,6 +284,17 @@
             save_compressed=cfg.llmcompressor.save_compressed,
         )
+    elif cfg.local_rank == 0:
+        if cfg.flash_optimum and BetterTransformer:
+            model = BetterTransformer.reverse(model)
+
+        if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
+            trainer.model.save_pretrained(
+                cfg.output_dir, safe_serialization=safe_serialization
+            )
+        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)


 def create_model_card(cfg: DictDefault, trainer: Trainer):
     """

View File

@@ -1,12 +1,20 @@
-"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
+"""
+Data collators for axolotl to pad labels and position_ids for packed sequences. Also
+includes logic for handling sequence parallelism collation.
+"""

 from dataclasses import dataclass
 from typing import Any

 import numpy as np
+import torch
+import torch.distributed as dist
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy

+from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
+from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc

 @dataclass
 class DataCollatorForSeq2Seq:
@@ -41,6 +49,8 @@ class DataCollatorForSeq2Seq:
             The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
         return_tensors (`str`):
             The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+        sequence_parallel_degree (`int`):
+            The degree of sequence parallelism. Default to 1 for no sequence parallelism.
     """

     tokenizer: PreTrainedTokenizerBase
@@ -51,6 +61,17 @@ class DataCollatorForSeq2Seq:
     label_pad_token_id: int = -100
     position_pad_token_id: int = 0
     return_tensors: str = "pt"
+    sequence_parallel_degree: int = 1
+    ring_attn_func: RingAttnFunc | None = None
+
+    def __post_init__(self):
+        if self.sequence_parallel_degree > 1:
+            from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
+
+            # Get information about our position in the SP group
+            sp_group = get_ring_attn_group()
+            self.local_rank = dist.get_rank(group=sp_group)
+            self.local_world_size = dist.get_world_size(group=sp_group)

     def __call__(self, features, return_tensors=None):
         has_attn_mask = "attention_mask" in features[0].keys()
@@ -120,8 +141,62 @@ class DataCollatorForSeq2Seq:
                 )
                 features["decoder_input_ids"] = decoder_input_ids

+        if self.sequence_parallel_degree > 1:
+            features = self.apply_sequence_parallelism(features)
+
         return features

+    def apply_sequence_parallelism(
+        self, batch: dict[str, torch.Tensor]
+    ) -> torch.Tensor:
+        """
+        Apply sequence parallelism slicing to a batch.
+
+        Args:
+            batch: Batch dictionary from parent collator.
+
+        Returns:
+            Sliced batch dictionary.
+        """
+        # Get local (start, end) for sequence parallelism slicing
+        total_seq_len = batch["input_ids"].size(1)
+
+        # Update params for varlen ring attention calculation
+        if batch.get("position_ids") is not None:
+            update_ring_attn_params(position_ids=batch["position_ids"])
+
+        # Slice batch for sequence parallel processing
+        for key in batch:
+            if batch[key].size(1) == total_seq_len:
+                if self.ring_attn_func in [
+                    RingAttnFunc.VARLEN_LLAMA3,
+                    RingAttnFunc.BATCH_RING,
+                ]:
+                    batch[key] = (
+                        batch[key]
+                        .chunk(self.local_world_size, dim=1)[self.local_rank]
+                        .contiguous()
+                    )
+                elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
+                    chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
+
+                    # Take rank's chunk and opposing chunk for zigzag pattern
+                    selected_chunks = [
+                        chunks[self.local_rank],
+                        chunks[2 * self.local_world_size - self.local_rank - 1],
+                    ]
+                    batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
+                elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
+                    # TODO(djsaunde): This doesn't seem to work as expected
+                    # Split into striped data and stack
+                    tensor = torch.stack(
+                        batch[key].split(self.local_world_size, dim=1),
+                        dim=1,
+                    ).transpose(1, 2)
+                    batch[key] = tensor[:, self.local_rank].contiguous()
+
+        return batch

 @dataclass
 class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
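
Note: the `BATCH_STRIPE` branch above (flagged with a TODO on the incoming side) deals every `world_size`-th token position to each rank via the stack/transpose trick. What that indexing selects, on a toy tensor:

```python
import torch

local_world_size = 2
batch = torch.arange(8).reshape(1, 8)  # token positions 0..7

# Split into world_size-wide slices, stack them, and transpose so that
# index 1 walks over ranks rather than slices
striped = torch.stack(batch.split(local_world_size, dim=1), dim=1).transpose(1, 2)

for local_rank in range(local_world_size):
    print(local_rank, striped[:, local_rank].contiguous())
# 0 tensor([[0, 2, 4, 6]])  -> rank 0 gets even positions
# 1 tensor([[1, 3, 5, 7]])  -> rank 1 gets odd positions
```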

View File

@@ -126,6 +126,9 @@ def normalize_config(cfg):
         with open(ds_config_path, encoding="utf-8") as f:
             cfg.deepspeed = json.load(f)

+    if cfg.sequence_parallel_degree is None:
+        cfg.sequence_parallel_degree = 1
+
     if cfg.saves_per_epoch:
         save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
         if save_steps < 1.0:  # prevent saves on every step

View File

@@ -134,9 +134,10 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
                     "csv", data_files=f.name, split="train", streaming=True
                 )
             else:
-                iter_ds = load_dataset(
-                    path, streaming=True, split=split, name=name, data_files=data_files
-                )
+                if is_local_main_process():
+                    iter_ds = load_dataset(
+                        path, streaming=True, split=split, name=name, data_files=data_files
+                    )

         if skip:
             LOG.info(f"Skipping {skip} samples from the dataset")

View File

@@ -1,7 +1,5 @@
 """custom checkpointing utils"""

-from functools import partial
-
 from axolotl.utils.gradient_checkpointing.unsloth import (
     Unsloth_Offloaded_Gradient_Checkpointer,
 )
@@ -11,10 +9,6 @@ def hf_grad_checkpoint_offload_wrapper(
     decoder_layer, *args, use_reentrant=None
 ):  # pylint: disable=unused-argument
     return Unsloth_Offloaded_Gradient_Checkpointer.apply(
-        (
-            decoder_layer.func.__self__
-            if isinstance(decoder_layer, partial)
-            else decoder_layer.__self__
-        ),
+        decoder_layer.__self__,
         *args,
     )

View File

@@ -1149,17 +1149,22 @@ class AxolotlInputConfig(
         return data

-    @model_validator(mode="after")
-    def check_sequence_parallel_degree(self):
-        if not self.sequence_parallel_degree:
-            self.sequence_parallel_degree = 1
-        elif self.sequence_parallel_degree > 1:
-            if not self.flash_attention:
+    @field_validator("sequence_parallel_degree", mode="after")
+    @classmethod
+    def check_sequence_parallel_degree(cls, value, info):
+        if not value:
+            value = 1
+
+        if value > 1:
+            if not info.data.get("flash_attention"):
                 raise ValueError(
                     "flash_attention: true must be set with sequence_parallel_degree > 1"
                 )

-            if self.sample_packing and self.micro_batch_size > 1:
+            if (
+                info.data.get("sample_packing")
+                and not info.data["micro_batch_size"] == 1
+            ):
                 raise ValueError(
                     "micro_batch_size must be set to 1 when sample_packing is enabled"
                     "due to a `ring-flash-attn` requirement"
@@ -1179,40 +1184,42 @@
             # according to the proportion of non-padding tokens per rank.
             LOG.warning(
                 "Sequence parallelism (SP) is enabled with "
-                f"sequence_parallel_degree={self.sequence_parallel_degree}. "
-                "Please note that logged losses may differ slightly to the non-SP "
-                "losses due to transformers Trainer implementation details. "
-                "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
+                f"sequence_parallel_degree={value}. Please note that logged losses may "
+                "differ slightly to the non-SP losses due to transformers Trainer "
+                "implementation details. Please see "
+                "https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
                 "for more details."
             )

-        return self
+        return value

-    @model_validator(mode="after")
-    def validate_ring_attn_func(self):
-        if getattr(self, "sequence_parallel_degree", 1) == 1:
-            return self
+    @field_validator("ring_attn_func", mode="after")
+    @classmethod
+    def check_ring_attn_func(cls, value, info):
+        if not info.data.get("sequence_parallel_degree", 1) > 1:
+            return value

         from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc

-        if self.ring_attn_func is not None:
-            # Set the ring attention function if passed in config
+        if value is not None:
             valid_funcs = list(RingAttnFunc)
-            if self.ring_attn_func in valid_funcs:
-                self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
+            if value in valid_funcs:
+                value = RingAttnFunc(value)
             else:
                 raise ValueError(
-                    f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
+                    f"ring_attn_func: {value} must be one of {valid_funcs}"
                 )
         else:
             # Default ring attention function selection
-            sample_packing = getattr(self, "sample_packing", False)
-            self.ring_attn_func = (
+            sample_packing = info.data.get("sample_packing")
+            value = (
                 RingAttnFunc.VARLEN_LLAMA3
                 if sample_packing
                 else RingAttnFunc.BATCH_RING
             )

-        return self
+        return value

     @model_validator(mode="before")
     @classmethod
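
Note: moving from `@model_validator(mode="after")` to `@field_validator(..., mode="after")` changes what the validator can see: `info.data` only holds fields validated before the target field, so the pattern depends on field declaration order. A minimal pydantic v2 sketch of the idiom (field names trimmed down from the real config):

```python
from typing import Optional

from pydantic import BaseModel, ValidationInfo, field_validator


class MiniConfig(BaseModel):
    flash_attention: bool = False  # declared (and validated) first
    sequence_parallel_degree: Optional[int] = None

    @field_validator("sequence_parallel_degree", mode="after")
    @classmethod
    def check_sequence_parallel_degree(cls, value, info: ValidationInfo):
        value = value or 1
        # info.data contains only previously validated fields
        if value > 1 and not info.data.get("flash_attention"):
            raise ValueError(
                "flash_attention: true must be set with sequence_parallel_degree > 1"
            )
        return value


print(MiniConfig(flash_attention=True, sequence_parallel_degree=2).sequence_parallel_degree)  # 2
print(MiniConfig(sequence_parallel_degree=None).sequence_parallel_degree)  # coerced to 1

try:
    MiniConfig(sequence_parallel_degree=2)  # flash_attention defaults to False
except ValueError as err:
    print(err)
```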

View File

@@ -348,7 +348,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             load_from_cache_file=not cfg.is_preprocess,
             desc="Add position_id column (PoSE)",
         )
-    elif cfg.sample_packing:
+    elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
         drop_long_kwargs = {}
         if filter_map_kwargs:
             drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -358,7 +358,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             **filter_map_kwargs,
             **drop_long_kwargs,
         )
-        if cfg.eval_sample_packing:
+        if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
             if eval_dataset:
                 eval_dataset = eval_dataset.map(
                     add_position_ids,
@@ -528,13 +528,6 @@ def setup_torch_compile_env(cfg):
 def setup_deepspeed_env(cfg, stage=None):
     from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

-    from axolotl.utils.distributed import distributed_state
-
-    if distributed_state and distributed_state.initialized:
-        raise RuntimeError(
-            "Distributed State already initialized before Deepspeed setup"
-        )
-
     os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
     os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
     if stage:

View File

@@ -12,11 +12,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
 from axolotl.utils.dict import DictDefault
-from tests.e2e.utils import (
-    check_model_output_exists,
-    require_llmcompressor,
-    require_torch_2_4_1,
-)
+from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1

 MODELS = [
     "nm-testing/llama2.c-stories42M-pruned2.4-compressed",
@@ -30,7 +26,6 @@ MODELS = [
 @pytest.mark.parametrize(
     "save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"]
 )
-@require_llmcompressor
 class TestLLMCompressorIntegration:
     """
     e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin
@@ -95,9 +90,12 @@
     def _check_llmcompressor_model_outputs(temp_dir, save_compressed):
-        if save_compressed:
-            assert (Path(temp_dir) / "recipe.yaml").exists()
+        # recipe.yaml should exist
+        assert (Path(temp_dir) / "recipe.yaml").exists()
+
+        # sparsity config exists if save_compressed
+        if save_compressed:
             from compressed_tensors import ModelCompressor
             from compressed_tensors.config import Sparse24BitMaskConfig

View File

@@ -4,14 +4,11 @@ GRPO test suite

 import os
 import random
-import shutil
 import subprocess  # nosec B404
 import sys
-import tempfile
 import time
 from pathlib import Path

-import psutil
 import pytest
 import requests
 import yaml
@@ -24,8 +21,8 @@ from tests.e2e.utils import require_vllm

 def start_vllm(
-    model: str, env: dict, wait: int | None = None, quiet=False, **kwargs
-) -> subprocess.Popen:
+    model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
+) -> int:
     """
     helper function to start the VLLM server in the background, mostly for testing purposes
     """
@@ -49,41 +46,10 @@ def start_vllm(
     # print out the command to be executed
     print(" ".join(cmd))

-    vllm_logging_json = Path(tempfile.mkdtemp()) / "vllm_logging.json"
-    with open(vllm_logging_json, "w", encoding="utf-8") as temp_file:
-        temp_file.write(
-            """{
-    "formatters": {
-        "json": {
-            "class": "pythonjsonlogger.jsonlogger.JsonFormatter"
-        }
-    },
-    "handlers": {
-        "file": {
-            "class": "logging.FileHandler",
-            "formatter": "json",
-            "level": "DEBUG",
-            "filename": "/tmp/vllm.log",
-            "mode": "a"
-        }
-    },
-    "loggers": {
-        "vllm": {
-            "handlers": ["file"],
-            "level": "DEBUG",
-            "propagate": false
-        }
-    },
-    "version": 1
-}"""
-        )
-    cmd_env = env.copy()
-    cmd_env.update({"VLLM_LOGGING_CONFIG_PATH": vllm_logging_json})
     # start `trl vllm-serve` command in the background and capture the process id
     process = subprocess.Popen(  # pylint: disable=consider-using-with
         cmd,
-        env=cmd_env,
+        env=env,
         stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
         stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
     )  # nosec B603
@@ -92,51 +58,32 @@
     print(f"VLLM server process started (PID: {process.pid})")

     # wait until the http server is ready, even if it 404s, but timeout after 60 seconds
-    period_seconds = 5
     started = False
     if wait and host and port:
-        for i in range(0, int(wait), period_seconds):
+        for _ in range(int(wait)):
             try:
                 response = requests.get(f"http://{host}:{port}", timeout=1)
-                print(f"{i}: VLLM server (status: {response.status_code})")
                 if int(response.status_code) in [200, 404]:
                     started = True
                     break
-            except requests.exceptions.RequestException as exc:
-                print(f"{i}: VLLM server failed to start: {str(exc)}")
+            except requests.exceptions.RequestException:
+                pass

             # also check if the process.pid is still running
             if not process.poll() is None:
                 break
-            time.sleep(period_seconds)
+            time.sleep(1)

     if wait and not started:
         print(
             f"VLLM server process did not start within {wait} seconds. Please check your server logs."
         )
-        recursive_kill(process)
-        with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file:
-            print(log_file.read())
-        shutil.rmtree("/tmp/vllm.log")
+        process.kill()
         raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")

-    # return the process
-    return process
-
-
-def recursive_kill(process: subprocess.Popen):
-    """
-    Recursively kill a process and its children
-    """
-    process = psutil.Process(process.pid)
-    for child in psutil.Process(process.pid).children(recursive=True):
-        child.terminate()
-        child.kill()
-        os.kill(child.pid, 9)
-    process.terminate()
-    process.kill()
-    os.kill(process.pid, 9)
+    # return the process id
+    return process.pid


 class TestGRPO:
@@ -227,17 +174,16 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
         current_env = os.environ.copy()
         env = {
-            "NCCL_P2P_LEVEL": "NVL",
+            "NCCL_P2P_LEVEL": "LOC",
             **current_env,
             "CUDA_VISIBLE_DEVICES": "1",
-            "VLLM_DISABLE_COMPILE_CACHE": "1",
-            # "VLLM_USE_V1": "0",
+            "VLLM_USE_V1": "0",
         }
-        vllm_process = start_vllm(
+        vllm_process_id = start_vllm(
             cfg.base_model,
             env=env,
             quiet=True,
-            wait=300,
+            wait=120,
             gpu_memory_utilization=0.15,
             max_model_len=cfg.vllm.max_model_len,
             enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -256,14 +202,10 @@
                 "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
             ],
-            env={
-                "NCCL_P2P_LEVEL": "NVL",
-                "NCCL_DEBUG": "INFO",
-                **current_env,
-            },
+            env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
         )
     finally:
-        recursive_kill(vllm_process)
+        os.kill(vllm_process_id, 9)

     @pytest.mark.parametrize(
         "num_gpus",
@@ -320,17 +262,16 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
         current_env = os.environ.copy()
         env = {
-            "NCCL_P2P_LEVEL": "NVL",  # nccl can be brittle, assume P2P isn't reliable
+            "NCCL_P2P_LEVEL": "LOC",  # nccl can be brittle, assume P2P isn't reliable
             **current_env,
             "CUDA_VISIBLE_DEVICES": "1",
-            "VLLM_DISABLE_COMPILE_CACHE": "1",
-            # "VLLM_USE_V1": "0",
+            "VLLM_USE_V1": "0",
        }
-        vllm_process = start_vllm(
+        vllm_process_id = start_vllm(
             cfg.base_model,
             env=env,
             quiet=True,
-            wait=300,
+            wait=120,
             gpu_memory_utilization=0.15,
             max_model_len=cfg.vllm.max_model_len,
             enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -349,11 +290,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
                 "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
             ],
-            env={
-                "NCCL_P2P_LEVEL": "NVL",
-                "NCCL_DEBUG": "INFO",
-                **current_env,
-            },
+            env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
         )
     finally:
-        recursive_kill(vllm_process)
+        os.kill(vllm_process_id, 9)

View File

@@ -1,77 +0,0 @@
-"""
-E2E tests for activation checkpointing
-"""
-
-import pytest
-import transformers
-from torch.utils.checkpoint import checkpoint
-
-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
-from axolotl.train import train
-from axolotl.utils.config import normalize_config, validate_config
-from axolotl.utils.dict import DictDefault
-
-from ..utils import check_model_output_exists
-
-
-@pytest.fixture()
-def fix_checkpoint_after_test():
-    yield
-    transformers.modeling_utils.checkpoint = checkpoint
-
-
-class TestActivationCheckpointing:
-    """
-    E2E tests for activation checkpointing
-    """
-
-    def test_activation_checkpointing_offload(
-        self,
-        temp_dir,
-        fix_checkpoint_after_test,  # pylint: disable=unused-argument,redefined-outer-name
-    ):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
-                "sequence_len": 1024,
-                "val_set_size": 0.0,
-                "special_tokens": {
-                    "pad_token": "<|endoftext|>",
-                    "eos_token": "<|im_end|>",
-                },
-                "datasets": [
-                    {
-                        "chat_template": "chatml",
-                        "path": "mlabonne/FineTome-100k",
-                        "type": "chat_template",
-                        "split": "train[:10%]",
-                        "field_messages": "conversations",
-                        "message_field_role": "from",
-                        "message_field_content": "value",
-                    },
-                ],
-                "num_epochs": 1,
-                "max_steps": 5,
-                "micro_batch_size": 1,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_8bit",
-                "lr_scheduler": "cosine",
-                "flash_attention": True,
-                "sample_packing": True,
-                "bf16": True,
-                "save_safetensors": True,
-                "gradient_checkpointing": "offload",
-            }
-        )
-
-        cfg = validate_config(cfg)
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)

View File

@@ -99,7 +99,6 @@ class TestMixtral(unittest.TestCase):
                 "bf16": "auto",
             }
         )
-        cfg = validate_config(cfg)
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -2,19 +2,14 @@
 # pylint: disable=redefined-outer-name,unused-argument

-import functools
-import sys
 from unittest.mock import MagicMock, patch

 import pytest
 import torch
 from accelerate.state import PartialState

-from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
 from axolotl.monkeypatch.attention.ring_attn import (
-    RingAttnFunc,
     get_ring_attn_group,
-    register_ring_attn,
     set_ring_attn_group,
 )
 from axolotl.utils.dict import DictDefault
@@ -52,27 +47,6 @@ def fixture_cfg():
     return cfg

-
-@pytest.fixture
-def sequence_parallel_batch():
-    """Create a test batch for sequence parallelism tests."""
-    batch_size = 1
-    seq_len = 8
-
-    # Create test tensors
-    input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
-    attention_mask = torch.ones(batch_size, seq_len)
-    position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
-
-    # Create test batch
-    batch = {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "position_ids": position_ids,
-    }
-
-    return batch
-

 class TestRingAttention:
     """Tests for the ring attention functionality."""
@@ -99,6 +73,11 @@ class TestRingAttention:
         self, mock_world_size, mock_rank, mock_new_group, partial_state
     ):
         """Test that ring attention groups are created correctly."""
+        from axolotl.monkeypatch.attention.ring_attn import (
+            RingAttnFunc,
+            register_ring_attn,
+        )
+
         # Setup mocks
         mock_world_size.return_value = 8  # 8 GPUs total
         mock_rank.return_value = 3  # GPU #3
@@ -122,303 +101,88 @@ class TestRingAttention:
         set_ring_attn_group(None)

-class TestConfigValidation:
-    """Tests for validating sequence parallelism configurations."""
-
-    @pytest.fixture(autouse=True)
-    def setup_mocks(self, monkeypatch):
-        """Set up mocks for all tests in this class."""
-        # Mock the ring_flash_attn module
-        monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
-
-    @pytest.fixture
-    def base_cfg(self):
-        """Create a base configuration for testing."""
-        return DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
-                "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
-                "micro_batch_size": 1,
-                "gradient_accumulation_steps": 1,
-                "learning_rate": 1e-3,
-                "output_dir": "./model-out",
-                "sequence_len": 512,
-                "special_tokens": {"pad_token": "<|endoftext|>"},
-            }
-        )
-
-    @pytest.mark.parametrize(
-        "config_updates, expected_values, should_pass, error_msg",
-        [
-            # Valid configuration
-            (
-                {"sequence_parallel_degree": 2, "flash_attention": True},
-                {"sequence_parallel_degree": 2, "flash_attention": True},
-                True,
-                None,
-            ),
-            # Default sequence_parallel_degree
-            ({}, {"sequence_parallel_degree": 1}, True, None),
-            # Invalid: sequence_parallel_degree > 1 without flash_attention
-            (
-                {"sequence_parallel_degree": 2, "flash_attention": False},
-                None,
-                False,
-                "flash_attention: true must be set",
-            ),
-            # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
-            (
-                {
-                    "sequence_parallel_degree": 2,
-                    "flash_attention": True,
-                    "sample_packing": True,
-                    "micro_batch_size": 2,
-                    "pad_to_sequence_len": True,
-                },
-                None,
-                False,
-                "micro_batch_size must be set to 1",
-            ),
-        ],
-        ids=[
-            "valid_config",
-            "default_sp_degree",
-            "without_flash_attention",
-            "sample_packing_with_large_batch",
-        ],
-    )
-    def test_sequence_parallel_config_validation(
-        self, base_cfg, config_updates, expected_values, should_pass, error_msg
-    ):
-        """Test various sequence parallelism configuration scenarios."""
-        from axolotl.utils.schemas.config import AxolotlInputConfig
-
-        # Apply updates to base config
-        cfg = base_cfg
-        cfg.update(config_updates)
-
-        if should_pass:
-            # Should validate without errors
-            config = AxolotlInputConfig(**cfg)
-
-            # Check expected values
-            for key, value in expected_values.items():
-                assert getattr(config, key) == value
-        else:
-            # Should raise exception
-            with pytest.raises(ValueError) as excinfo:
-                AxolotlInputConfig(**cfg)
-            assert error_msg in str(excinfo.value)
-
-    @pytest.mark.parametrize(
-        "ring_attn_func, sample_packing, expected_func",
-        [
-            (None, True, RingAttnFunc.VARLEN_LLAMA3),
-            (None, False, RingAttnFunc.BATCH_RING),
-        ],
-        ids=["default_with_sample_packing", "default_without_sample_packing"],
-    )
-    def test_ring_attn_func_validation(
-        self, base_cfg, ring_attn_func, sample_packing, expected_func
-    ):
-        """Test ring_attn_func validation and defaults."""
-        from axolotl.utils.schemas.config import AxolotlInputConfig
-
-        # Apply updates to base config
-        cfg = base_cfg | {
-            "sequence_parallel_degree": 2,
-            "flash_attention": True,
-            "sample_packing": sample_packing,
-        }
-        if ring_attn_func is not None:
-            cfg["ring_attn_func"] = ring_attn_func
-
-        # Should validate without errors
-        config = AxolotlInputConfig(**cfg)
-
-        # Check ring_attn_func value
-        assert config.ring_attn_func.value == expected_func
-
-    def test_invalid_ring_attn_func(self, base_cfg):
-        """Test that an invalid ring_attn_func is rejected."""
-        from axolotl.utils.schemas.config import AxolotlInputConfig
-
-        # Invalid configuration with invalid ring_attn_func
-        cfg = base_cfg | {
-            "sequence_parallel_degree": 2,
-            "flash_attention": True,
-            "ring_attn_func": "INVALID_FUNC",
-        }
-
-        # Should raise ValidationError
-        with pytest.raises(ValueError) as excinfo:
-            AxolotlInputConfig(**cfg)
-
-        # Verify error message
-        assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value)
-
-
-class TestApplySequenceParallelism:
-    """Tests for the apply_sequence_parallelism function."""
-
-    @pytest.fixture(autouse=True)
-    def mock_distributed(self, monkeypatch):
-        """Mock torch.distributed functions for testing."""
-        # Mock is_initialized to return True
-        monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
-
-        # Mock get_rank to return 0 by default
-        monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
-
-        # Mock get_world_size to return 2 by default
-        monkeypatch.setattr(
-            torch.distributed, "get_world_size", lambda *args, **kwargs: 2
-        )
-
-        # Mock the process group
-        monkeypatch.setattr(
-            "axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
-            MagicMock,
-        )
-
-        # Mock update_ring_attn_params
-        monkeypatch.setattr(
-            "axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
-            lambda **kwargs: None,
-        )
-
-    def test_world_size_one(self, sequence_parallel_batch):
-        """Test that function returns original batch when world size is 1."""
-        result = apply_sequence_parallelism(
-            batch=sequence_parallel_batch,
-            local_rank=0,
-            local_world_size=1,
-            ring_attn_func=RingAttnFunc.BATCH_RING,
-        )
-
-        # Should return the original batch unchanged
-        assert result == sequence_parallel_batch
-
-    def test_batch_ring_rank0(self, sequence_parallel_batch):
-        """Test BATCH_RING sharding for rank 0 in a 2-process group."""
-        batch = sequence_parallel_batch
-        seq_len = batch["input_ids"].size(1)
-
-        result = apply_sequence_parallelism(
-            batch=batch,
-            local_rank=0,
-            local_world_size=2,
-            ring_attn_func=RingAttnFunc.BATCH_RING,
-        )
-
-        # Check that sequence dimension was sharded correctly
-        assert result["input_ids"].shape[1] == seq_len // 2
-        assert result["attention_mask"].shape[1] == seq_len // 2
-
-        # Verify content: rank 0 should get the first half of the sequence
-        assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
-        assert torch.equal(
-            result["position_ids"], batch["position_ids"][:, : seq_len // 2]
-        )
-
-    def test_batch_ring_rank1(self, sequence_parallel_batch):
-        """Test BATCH_RING sharding for rank 1 in a 2-process group."""
-        batch = sequence_parallel_batch
-        seq_len = batch["input_ids"].size(1)
-        original_input_ids = batch["input_ids"].clone()
-
-        result = apply_sequence_parallelism(
-            batch=batch,
-            local_rank=1,
-            local_world_size=2,
-            ring_attn_func=RingAttnFunc.BATCH_RING,
-        )
-
-        # Verify content: rank 1 should get the second half of the sequence
-        assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
-
-    def test_batch_zigzag(self, sequence_parallel_batch):
-        """Test BATCH_ZIGZAG sharding pattern."""
-        batch = sequence_parallel_batch
-        original_input_ids = batch["input_ids"].clone()
-        seq_len = batch["input_ids"].size(1)
-
-        # Test rank 0
-        result_rank0 = apply_sequence_parallelism(
-            batch={k: v.clone() for k, v in batch.items()},
-            local_rank=0,
-            local_world_size=2,
-            ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
-        )
-
-        # Test rank 1
-        result_rank1 = apply_sequence_parallelism(
-            batch={k: v.clone() for k, v in batch.items()},
-            local_rank=1,
-            local_world_size=2,
-            ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
-        )
-
-        # Checks for both ranks
-        assert result_rank0["input_ids"].shape[1] == seq_len // 2
-        assert result_rank1["input_ids"].shape[1] == seq_len // 2
-
-        # For a 2-rank system with 8 tokens, check specific zigzag pattern
-        # Rank 0 should get chunks [0, 1] and [6, 7]
-        # Rank 1 should get chunks [2, 3] and [4, 5]
-        if seq_len == 8:
-            # Create expected tensors for comparison
-            rank0_expected = torch.cat(
-                [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
-            )
-            rank1_expected = torch.cat(
-                [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
-            )
-
-            assert torch.equal(result_rank0["input_ids"], rank0_expected)
-            assert torch.equal(result_rank1["input_ids"], rank1_expected)
-
-    def test_partial_application(self, sequence_parallel_batch):
-        """Test that we can create a partially applied version of the function."""
-        batch = sequence_parallel_batch
-        original_input_ids = batch["input_ids"].clone()
-
-        # Create a partially applied function
-        rank0_ring_parallel = functools.partial(
-            apply_sequence_parallelism,
-            local_rank=0,
-            local_world_size=2,
-            ring_attn_func=RingAttnFunc.BATCH_RING,
-        )
-
-        # Use the partially applied function
-        result = rank0_ring_parallel(batch=batch)
-
-        # Verify it works as expected
-        assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
-        assert torch.equal(
-            result["input_ids"],
-            original_input_ids[:, : original_input_ids.shape[1] // 2],
-        )
-
-    def test_missing_position_ids(self, sequence_parallel_batch):
-        """Test handling of batch without position_ids."""
-        # Create a batch without position_ids
-        batch = {
-            k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
-        }
-        original_input_ids = batch["input_ids"].clone()
-
-        # This should run without error even though position_ids is missing
-        result = apply_sequence_parallelism(
-            batch=batch,
-            local_rank=0,
-            local_world_size=2,
-            ring_attn_func=RingAttnFunc.BATCH_RING,
-        )
-
-        # Verification should pass
-        assert "position_ids" not in result
-        assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
+# Mock a simplified DataCollator test
+@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
+@patch("torch.distributed.get_rank")
+@patch("torch.distributed.get_world_size")
+def test_sequence_parallel_slicing(
+    mock_world_size, mock_rank, mock_get_group, partial_state
+):
+    """Test the basic sequence slicing logic without full collator instantiation."""
+    # Setup mocks
+    mock_get_group.return_value = MagicMock()
+    mock_rank.return_value = 1  # Second GPU
+    mock_world_size.return_value = 4  # 4 GPUs total
+
+    # Create a sample batch
+    batch = {
+        "input_ids": torch.tensor(
+            [
+                [101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
+                [201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
+            ]
+        ),
+        "attention_mask": torch.ones(2, 12),
+    }
+
+    # Simplified slicing logic from SequenceParallelDataCollator
+    def slice_batch(batch, rank, world_size):
+        result = {}
+        for key in batch:
+            seq_len = batch[key].shape[1]
+            slice_size = seq_len // world_size
+            start_idx = rank * slice_size
+            end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
+            result[key] = batch[key][:, start_idx:end_idx]
+        return result
+
+    # Slice the batch
+    result = slice_batch(
+        batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
+    )
+
+    # Check slicing
+    assert result["input_ids"].shape == (2, 3)  # 12 tokens / 4 GPUs = 3 tokens per GPU
+    expected_input_ids = torch.tensor(
+        [
+            [104, 105, 106],  # Second slice of first sequence
+            [204, 205, 206],  # Second slice of second sequence
+        ]
+    )
+    assert torch.all(result["input_ids"] == expected_input_ids)
+
+
+@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
+def test_config_validation_with_valid_inputs(cfg):
+    """Test that valid sequence parallelism configurations pass validation."""
+    # Import the actual model class with appropriate mocks
+    from axolotl.utils.schemas.config import AxolotlInputConfig
+
+    # Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
+    cfg = cfg | {
+        "sequence_parallel_degree": 2,
+        "flash_attention": True,
+    }
+
+    # Should validate without errors
+    config = AxolotlInputConfig(**cfg)
+    assert config.sequence_parallel_degree == 2
+    assert config.flash_attention is True
+
+
+def test_config_validation_with_invalid_inputs(cfg):
+    """Test that invalid sequence parallelism configurations fail validation."""
+    from axolotl.utils.schemas.config import AxolotlInputConfig
+
+    # Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
+    cfg = cfg | {
+        "sequence_parallel_degree": 2,
+        "flash_attention": False,
+    }
+
+    # Should raise ValidationError
+    with pytest.raises(ValueError) as excinfo:
+        AxolotlInputConfig(**cfg)
+
+    # Verify error message
+    assert "flash_attention: true must be set" in str(excinfo.value)

View File

@@ -109,24 +109,6 @@ def require_vllm(test_case):
     )(test_case)

-
-def require_llmcompressor(test_case):
-    """
-    Decorator marking a test that requires a llmcompressor to be installed
-    """
-
-    def is_llmcompressor_installed():
-        try:
-            import llmcompressor  # pylint: disable=unused-import  # noqa: F401
-
-            return True
-        except ImportError:
-            return False
-
-    return unittest.skipUnless(
-        is_llmcompressor_installed(), "test requires a llmcompressor to be installed"
-    )(test_case)
-

 def is_hopper():
     compute_capability = torch.cuda.get_device_capability()
     return compute_capability == (9, 0)