progress on ring attn impl

This commit is contained in:
Dan Saunders
2025-03-04 21:42:34 +00:00
parent bd952de9d2
commit dce61cdab1
7 changed files with 33 additions and 21 deletions

View File

@@ -761,7 +761,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"kd_top_k_before_softmax"
] = self.cfg.kd_top_k_before_softmax
training_arguments_kwargs[
"sequence_parallel_size"
] = self.cfg.sequence_parallel_size
@@ -797,7 +797,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
@@ -849,7 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1:
if (
self.cfg.pretraining_sample_concatenation is False
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None

View File

@@ -10,7 +10,6 @@ import os
from collections import defaultdict
from functools import wraps
from typing import Any, Dict, Literal, Optional
from typing_extensions import override
import torch
import torch.nn.functional as F
@@ -25,6 +24,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -800,17 +800,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
@override
def training_step(
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch=None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Note: we are subclassing `transformers.trainer.Trainer` in order to compute
parameters needed for the ring flash attention implementation we're using.
Args:
model (`nn.Module`):
The model to train.
@@ -827,7 +830,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
if "attention_mask" in inputs:
# Calculate sequence lengths from attention mask
seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
total_seq_len = inputs["attention_mask"].shape[0] * inputs["attention_mask"].shape[1]
total_seq_len = (
inputs["attention_mask"].shape[0]
* inputs["attention_mask"].shape[1]
)
else:
# Assume all sequences are the same length if no mask is provided
batch_size = inputs["input_ids"].shape[0]
@@ -838,18 +844,22 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
self._update_ring_flash_attn_params(seq_lens, total_seq_len)
return super().training_step(model, inputs, num_items_in_batch)
def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn.
"""
cu_seqlens = torch.cumsum(
torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())

View File

@@ -206,12 +206,10 @@ class AxolotlTrainingMixins:
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
},
)
sequence_parallel_size: Optional[int] = field(
default=1,
metadata={
"help": "The number of workers to use in sequence parallelism"
},
metadata={"help": "The number of workers to use in sequence parallelism"},
)

View File

@@ -749,6 +749,6 @@ def forward(
def apply_dist_flash_attn_monkey_patch_llama():
initialize_distributed()
LlamaModel.forward = forward
LlamaDecoderLayer.forward = llama_layer_forward

View File

@@ -11,7 +11,6 @@ from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict
from axolotl.utils.ring_attn import register_ring_attn
import bitsandbytes as bnb
import torch
import transformers
@@ -67,6 +66,7 @@ from axolotl.utils.distributed import (
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
from axolotl.utils.ring_attn import register_ring_attn
LOG = logging.getLogger("axolotl")
@@ -558,7 +558,7 @@ class ModelLoader:
# # Apply the monkey patch
# apply_seq_parallel_monkey_patch(method, model_type)
register_ring_attn(self.cfg.sequence_parallel_size)
def patch_attention(self) -> None:

View File

@@ -22,9 +22,10 @@ def register_ring_attn(sequence_parallel_size):
return
world_size = dist.get_world_size()
assert world_size % sequence_parallel_size == 0, \
f"sequence_parallel_size ({sequence_parallel_size}) " \
assert world_size % sequence_parallel_size == 0, (
f"sequence_parallel_size ({sequence_parallel_size}) "
f"must evenly divide world_size ({world_size})"
)
for i in range(world_size // sequence_parallel_size):
ring_attn_ranks = list(

View File

@@ -8,7 +8,6 @@ from contextlib import contextmanager
from functools import partial
from typing import List, Optional
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
import numpy as np
import torch
import torch.cuda
@@ -18,6 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths