progress on ring attn impl
@@ -761,7 +761,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "kd_top_k_before_softmax"
             ] = self.cfg.kd_top_k_before_softmax
+            training_arguments_kwargs[
+                "sequence_parallel_size"
+            ] = self.cfg.sequence_parallel_size
@@ -797,7 +797,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

         if self.cfg.reward_model:
             data_collator_kwargs["max_length"] = self.cfg.sequence_len

         trainer_cls = self._get_trainer_cls()
         trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
             trainer_kwargs, trainer_cls
@@ -849,7 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):
         if training_args.pretraining:
-            if self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1:
+            if (
+                self.cfg.pretraining_sample_concatenation is False
+                or self.cfg.micro_batch_size > 1
+            ):
                 return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None

@@ -10,7 +10,6 @@ import os
 from collections import defaultdict
 from functools import wraps
 from typing import Any, Dict, Literal, Optional
-from typing_extensions import override

 import torch
 import torch.nn.functional as F
@@ -25,6 +24,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
 from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
 from trl.trainer.utils import pad_to_length
+from typing_extensions import override

 from axolotl.integrations.base import BaseOptimizerFactory
 from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -800,17 +800,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)

     @override
     def training_step(
-        self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
+        self,
+        model: nn.Module,
+        inputs: dict[str, torch.Tensor | Any],
+        num_items_in_batch=None,
     ) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.

         Note: we are subclassing `transformers.trainer.Trainer` in order to compute
         parameters needed for the ring flash attention implementation we're using.

         Args:
             model (`nn.Module`):
                 The model to train.
@@ -827,7 +830,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         if "attention_mask" in inputs:
             # Calculate sequence lengths from attention mask
             seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
-            total_seq_len = inputs["attention_mask"].shape[0] * inputs["attention_mask"].shape[1]
+            total_seq_len = (
+                inputs["attention_mask"].shape[0]
+                * inputs["attention_mask"].shape[1]
+            )
         else:
             # Assume all sequences are the same length if no mask is provided
             batch_size = inputs["input_ids"].shape[0]
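For a padded batch, the two values computed in this hunk are just the per-sample token counts and the flattened buffer size. A minimal standalone sketch of the same arithmetic (not part of the diff), using a toy batch of two sequences padded to length 4:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
seq_lens = attention_mask.sum(dim=1).tolist()                       # [3, 2]
total_seq_len = attention_mask.shape[0] * attention_mask.shape[1]   # 2 * 4 = 8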
@@ -838,18 +844,22 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             self._update_ring_flash_attn_params(seq_lens, total_seq_len)

         return super().training_step(model, inputs, num_items_in_batch)

     def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
         """
         Calculate the cu_seqlens for the current forward pass and pass the value to
         the substituted ring_flash_attn.
         """
         cu_seqlens = torch.cumsum(
-            torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
+            torch.tensor(
+                packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
+            ),
             dim=-1,
             dtype=torch.int32,
         )
-        cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
+        cu_seqlens = F.pad(
+            F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
+        )

         update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())

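The cumsum-plus-padding above produces the `cu_seqlens` offsets that flash-attention-style varlen kernels expect: a monotonically increasing int32 tensor whose consecutive entries bracket each packed sequence. A minimal sketch of the same computation on toy values (the `device` argument is dropped so this runs on CPU):

import torch
import torch.nn.functional as F

packed_seq_lens = [3, 5, 2]  # lengths of the sequences packed into one buffer
total_seq_len = 12           # buffer size, including any trailing padding

cu_seqlens = torch.cumsum(
    torch.tensor(packed_seq_lens, dtype=torch.int32), dim=-1, dtype=torch.int32
)
# prepend a leading 0 and append total_seq_len as the final boundary
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
print(cu_seqlens)  # tensor([ 0,  3,  8, 10, 12], dtype=torch.int32)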
@@ -206,12 +206,10 @@ class AxolotlTrainingMixins:
             "help": "Whether to apply top_k_before_softmax to the logits when using KD"
         },
     )

     sequence_parallel_size: Optional[int] = field(
         default=1,
-        metadata={
-            "help": "The number of workers to use in sequence parallelism"
-        },
+        metadata={"help": "The number of workers to use in sequence parallelism"},
     )

@@ -749,6 +749,6 @@ def forward(

 def apply_dist_flash_attn_monkey_patch_llama():
     initialize_distributed()

     LlamaModel.forward = forward
     LlamaDecoderLayer.forward = llama_layer_forward
@@ -11,7 +11,6 @@ from functools import cached_property
 from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401

 import addict
-from axolotl.utils.ring_attn import register_ring_attn
 import bitsandbytes as bnb
 import torch
 import transformers
@@ -67,6 +66,7 @@ from axolotl.utils.distributed import (
 from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
+from axolotl.utils.ring_attn import register_ring_attn

 LOG = logging.getLogger("axolotl")
@@ -558,7 +558,7 @@ class ModelLoader:

         # # Apply the monkey patch
         # apply_seq_parallel_monkey_patch(method, model_type)

         register_ring_attn(self.cfg.sequence_parallel_size)

     def patch_attention(self) -> None:
@@ -22,9 +22,10 @@ def register_ring_attn(sequence_parallel_size):
         return

     world_size = dist.get_world_size()
-    assert world_size % sequence_parallel_size == 0, \
-        f"sequence_parallel_size ({sequence_parallel_size}) " \
+    assert world_size % sequence_parallel_size == 0, (
+        f"sequence_parallel_size ({sequence_parallel_size}) "
         f"must evenly divide world_size ({world_size})"
+    )

     for i in range(world_size // sequence_parallel_size):
         ring_attn_ranks = list(
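The loop that follows the assert is cut off in this hunk, but its intent is to carve `world_size` ranks into `world_size // sequence_parallel_size` ring groups. Assuming consecutive ranks form each ring (an illustration of one common layout, not the committed code), the grouping would look like:

world_size, sequence_parallel_size = 8, 4
groups = [
    list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size))
    for i in range(world_size // sequence_parallel_size)
]
print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
# each inner list would back a process group used as the ring-attention group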
@@ -8,7 +8,6 @@ from contextlib import contextmanager
 from functools import partial
 from typing import List, Optional

-from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
 import numpy as np
 import torch
 import torch.cuda
@@ -18,6 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths