progress on ring attn impl
This commit is contained in:
@@ -849,7 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1:
|
if (
|
||||||
|
self.cfg.pretraining_sample_concatenation is False
|
||||||
|
or self.cfg.micro_batch_size > 1
|
||||||
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import os
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Literal, Optional
|
from typing import Any, Dict, Literal, Optional
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -25,6 +24,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.integrations.base import BaseOptimizerFactory
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
@@ -803,7 +803,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def training_step(
|
def training_step(
|
||||||
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
|
num_items_in_batch=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Perform a training step on a batch of inputs.
|
Perform a training step on a batch of inputs.
|
||||||
@@ -827,7 +830,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
if "attention_mask" in inputs:
|
if "attention_mask" in inputs:
|
||||||
# Calculate sequence lengths from attention mask
|
# Calculate sequence lengths from attention mask
|
||||||
seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
|
seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
|
||||||
total_seq_len = inputs["attention_mask"].shape[0] * inputs["attention_mask"].shape[1]
|
total_seq_len = (
|
||||||
|
inputs["attention_mask"].shape[0]
|
||||||
|
* inputs["attention_mask"].shape[1]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Assume all sequences are the same length if no mask is provided
|
# Assume all sequences are the same length if no mask is provided
|
||||||
batch_size = inputs["input_ids"].shape[0]
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
@@ -845,11 +851,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
the substituted ring_flash_attn.
|
the substituted ring_flash_attn.
|
||||||
"""
|
"""
|
||||||
cu_seqlens = torch.cumsum(
|
cu_seqlens = torch.cumsum(
|
||||||
torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
|
torch.tensor(
|
||||||
|
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||||
|
),
|
||||||
dim=-1,
|
dim=-1,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
|
cu_seqlens = F.pad(
|
||||||
|
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||||
|
|
||||||
|
|||||||
@@ -209,9 +209,7 @@ class AxolotlTrainingMixins:
|
|||||||
|
|
||||||
sequence_parallel_size: Optional[int] = field(
|
sequence_parallel_size: Optional[int] = field(
|
||||||
default=1,
|
default=1,
|
||||||
metadata={
|
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||||
"help": "The number of workers to use in sequence parallelism"
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from functools import cached_property
|
|||||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
from axolotl.utils.ring_attn import register_ring_attn
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@@ -67,6 +66,7 @@ from axolotl.utils.distributed import (
|
|||||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
from axolotl.utils.ring_attn import register_ring_attn
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|||||||
@@ -22,9 +22,10 @@ def register_ring_attn(sequence_parallel_size):
|
|||||||
return
|
return
|
||||||
|
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
assert world_size % sequence_parallel_size == 0, \
|
assert world_size % sequence_parallel_size == 0, (
|
||||||
f"sequence_parallel_size ({sequence_parallel_size}) " \
|
f"sequence_parallel_size ({sequence_parallel_size}) "
|
||||||
f"must evenly divide world_size ({world_size})"
|
f"must evenly divide world_size ({world_size})"
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(world_size // sequence_parallel_size):
|
for i in range(world_size // sequence_parallel_size):
|
||||||
ring_attn_ranks = list(
|
ring_attn_ranks = list(
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from contextlib import contextmanager
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
@@ -18,6 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|||||||
Reference in New Issue
Block a user