progress on ring attn impl
This commit is contained in:
@@ -761,7 +761,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"kd_top_k_before_softmax"
|
"kd_top_k_before_softmax"
|
||||||
] = self.cfg.kd_top_k_before_softmax
|
] = self.cfg.kd_top_k_before_softmax
|
||||||
|
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"sequence_parallel_size"
|
"sequence_parallel_size"
|
||||||
] = self.cfg.sequence_parallel_size
|
] = self.cfg.sequence_parallel_size
|
||||||
@@ -797,7 +797,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
trainer_cls = self._get_trainer_cls()
|
trainer_cls = self._get_trainer_cls()
|
||||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||||
trainer_kwargs, trainer_cls
|
trainer_kwargs, trainer_cls
|
||||||
@@ -849,7 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1:
|
if (
|
||||||
|
self.cfg.pretraining_sample_concatenation is False
|
||||||
|
or self.cfg.micro_batch_size > 1
|
||||||
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import os
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Literal, Optional
|
from typing import Any, Dict, Literal, Optional
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -25,6 +24,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.integrations.base import BaseOptimizerFactory
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
@@ -800,17 +800,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def training_step(
|
def training_step(
|
||||||
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
|
num_items_in_batch=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Perform a training step on a batch of inputs.
|
Perform a training step on a batch of inputs.
|
||||||
|
|
||||||
Note: we are subclassing `transformers.trainer.Trainer` in order to compute
|
Note: we are subclassing `transformers.trainer.Trainer` in order to compute
|
||||||
parameters needed for the ring flash attention implementation we're using.
|
parameters needed for the ring flash attention implementation we're using.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (`nn.Module`):
|
model (`nn.Module`):
|
||||||
The model to train.
|
The model to train.
|
||||||
@@ -827,7 +830,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
if "attention_mask" in inputs:
|
if "attention_mask" in inputs:
|
||||||
# Calculate sequence lengths from attention mask
|
# Calculate sequence lengths from attention mask
|
||||||
seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
|
seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
|
||||||
total_seq_len = inputs["attention_mask"].shape[0] * inputs["attention_mask"].shape[1]
|
total_seq_len = (
|
||||||
|
inputs["attention_mask"].shape[0]
|
||||||
|
* inputs["attention_mask"].shape[1]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Assume all sequences are the same length if no mask is provided
|
# Assume all sequences are the same length if no mask is provided
|
||||||
batch_size = inputs["input_ids"].shape[0]
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
@@ -838,18 +844,22 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
self._update_ring_flash_attn_params(seq_lens, total_seq_len)
|
self._update_ring_flash_attn_params(seq_lens, total_seq_len)
|
||||||
|
|
||||||
return super().training_step(model, inputs, num_items_in_batch)
|
return super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
|
||||||
def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
|
def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
|
||||||
"""
|
"""
|
||||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||||
the substituted ring_flash_attn.
|
the substituted ring_flash_attn.
|
||||||
"""
|
"""
|
||||||
cu_seqlens = torch.cumsum(
|
cu_seqlens = torch.cumsum(
|
||||||
torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
|
torch.tensor(
|
||||||
|
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||||
|
),
|
||||||
dim=-1,
|
dim=-1,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
|
cu_seqlens = F.pad(
|
||||||
|
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||||
|
|
||||||
|
|||||||
@@ -206,12 +206,10 @@ class AxolotlTrainingMixins:
|
|||||||
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
sequence_parallel_size: Optional[int] = field(
|
sequence_parallel_size: Optional[int] = field(
|
||||||
default=1,
|
default=1,
|
||||||
metadata={
|
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||||
"help": "The number of workers to use in sequence parallelism"
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -749,6 +749,6 @@ def forward(
|
|||||||
|
|
||||||
def apply_dist_flash_attn_monkey_patch_llama():
|
def apply_dist_flash_attn_monkey_patch_llama():
|
||||||
initialize_distributed()
|
initialize_distributed()
|
||||||
|
|
||||||
LlamaModel.forward = forward
|
LlamaModel.forward = forward
|
||||||
LlamaDecoderLayer.forward = llama_layer_forward
|
LlamaDecoderLayer.forward = llama_layer_forward
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from functools import cached_property
|
|||||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
from axolotl.utils.ring_attn import register_ring_attn
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@@ -67,6 +66,7 @@ from axolotl.utils.distributed import (
|
|||||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
from axolotl.utils.ring_attn import register_ring_attn
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -558,7 +558,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
# # Apply the monkey patch
|
# # Apply the monkey patch
|
||||||
# apply_seq_parallel_monkey_patch(method, model_type)
|
# apply_seq_parallel_monkey_patch(method, model_type)
|
||||||
|
|
||||||
register_ring_attn(self.cfg.sequence_parallel_size)
|
register_ring_attn(self.cfg.sequence_parallel_size)
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
|
|||||||
@@ -22,9 +22,10 @@ def register_ring_attn(sequence_parallel_size):
|
|||||||
return
|
return
|
||||||
|
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
assert world_size % sequence_parallel_size == 0, \
|
assert world_size % sequence_parallel_size == 0, (
|
||||||
f"sequence_parallel_size ({sequence_parallel_size}) " \
|
f"sequence_parallel_size ({sequence_parallel_size}) "
|
||||||
f"must evenly divide world_size ({world_size})"
|
f"must evenly divide world_size ({world_size})"
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(world_size // sequence_parallel_size):
|
for i in range(world_size // sequence_parallel_size):
|
||||||
ring_attn_ranks = list(
|
ring_attn_ranks = list(
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from contextlib import contextmanager
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
@@ -18,6 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|||||||
Reference in New Issue
Block a user