Progress on ring attention implementation

This commit is contained in:
Dan Saunders
2025-03-04 21:42:34 +00:00
parent bd952de9d2
commit dce61cdab1
7 changed files with 33 additions and 21 deletions

View File

@@ -849,7 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
): ):
if training_args.pretraining: if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1: if (
self.cfg.pretraining_sample_concatenation is False
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None return None

View File

@@ -10,7 +10,6 @@ import os
from collections import defaultdict from collections import defaultdict
from functools import wraps from functools import wraps
from typing import Any, Dict, Literal, Optional from typing import Any, Dict, Literal, Optional
from typing_extensions import override
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -25,6 +24,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.integrations.base import BaseOptimizerFactory from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -803,7 +803,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
@override @override
def training_step( def training_step(
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch=None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Perform a training step on a batch of inputs. Perform a training step on a batch of inputs.
@@ -827,7 +830,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
if "attention_mask" in inputs: if "attention_mask" in inputs:
# Calculate sequence lengths from attention mask # Calculate sequence lengths from attention mask
seq_lens = inputs["attention_mask"].sum(dim=1).tolist() seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
total_seq_len = inputs["attention_mask"].shape[0] * inputs["attention_mask"].shape[1] total_seq_len = (
inputs["attention_mask"].shape[0]
* inputs["attention_mask"].shape[1]
)
else: else:
# Assume all sequences are the same length if no mask is provided # Assume all sequences are the same length if no mask is provided
batch_size = inputs["input_ids"].shape[0] batch_size = inputs["input_ids"].shape[0]
@@ -845,11 +851,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
the substituted ring_flash_attn. the substituted ring_flash_attn.
""" """
cu_seqlens = torch.cumsum( cu_seqlens = torch.cumsum(
torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32), torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1, dim=-1,
dtype=torch.int32, dtype=torch.int32,
) )
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len) cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())

View File

@@ -209,9 +209,7 @@ class AxolotlTrainingMixins:
sequence_parallel_size: Optional[int] = field( sequence_parallel_size: Optional[int] = field(
default=1, default=1,
metadata={ metadata={"help": "The number of workers to use in sequence parallelism"},
"help": "The number of workers to use in sequence parallelism"
},
) )

View File

@@ -11,7 +11,6 @@ from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401 from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict import addict
from axolotl.utils.ring_attn import register_ring_attn
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
@@ -67,6 +66,7 @@ from axolotl.utils.distributed import (
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
from axolotl.utils.ring_attn import register_ring_attn
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")

View File

@@ -22,9 +22,10 @@ def register_ring_attn(sequence_parallel_size):
return return
world_size = dist.get_world_size() world_size = dist.get_world_size()
assert world_size % sequence_parallel_size == 0, \ assert world_size % sequence_parallel_size == 0, (
f"sequence_parallel_size ({sequence_parallel_size}) " \ f"sequence_parallel_size ({sequence_parallel_size}) "
f"must evenly divide world_size ({world_size})" f"must evenly divide world_size ({world_size})"
)
for i in range(world_size // sequence_parallel_size): for i in range(world_size // sequence_parallel_size):
ring_attn_ranks = list( ring_attn_ranks = list(

View File

@@ -8,7 +8,6 @@ from contextlib import contextmanager
from functools import partial from functools import partial
from typing import List, Optional from typing import List, Optional
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
import numpy as np import numpy as np
import torch import torch
import torch.cuda import torch.cuda
@@ -18,6 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths