diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 97de5a92d..a1b2ac27a 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -761,7 +761,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs[
             "kd_top_k_before_softmax"
         ] = self.cfg.kd_top_k_before_softmax
-
+
         training_arguments_kwargs[
             "sequence_parallel_size"
         ] = self.cfg.sequence_parallel_size
@@ -797,7 +797,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         if self.cfg.reward_model:
             data_collator_kwargs["max_length"] = self.cfg.sequence_len
-
+
         trainer_cls = self._get_trainer_cls()
         trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
             trainer_kwargs, trainer_cls
@@ -849,7 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):
         if training_args.pretraining:
-            if self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1:
+            if (
+                self.cfg.pretraining_sample_concatenation is False
+                or self.cfg.micro_batch_size > 1
+            ):
                 return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None
 
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 35cf2f30a..4a1ba5a02 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -10,7 +10,6 @@ import os
 from collections import defaultdict
 from functools import wraps
 from typing import Any, Dict, Literal, Optional
-from typing_extensions import override
 
 import torch
 import torch.nn.functional as F
@@ -25,6 +24,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
 from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
 from trl.trainer.utils import pad_to_length
+from typing_extensions import override
 
 from axolotl.integrations.base import BaseOptimizerFactory
 from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -800,17 +800,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)
-
+
     @override
     def training_step(
-        self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
+        self,
+        model: nn.Module,
+        inputs: dict[str, torch.Tensor | Any],
+        num_items_in_batch=None,
     ) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
 
         Note: we are subclassing `transformers.trainer.Trainer` in order to compute
         parameters needed for the ring flash attention implementation we're using.
-
+
         Args:
             model (`nn.Module`):
                 The model to train.
@@ -827,7 +830,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         if "attention_mask" in inputs:
             # Calculate sequence lengths from attention mask
             seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
-            total_seq_len = inputs["attention_mask"].shape[0] * inputs["attention_mask"].shape[1]
+            total_seq_len = (
+                inputs["attention_mask"].shape[0]
+                * inputs["attention_mask"].shape[1]
+            )
         else:
             # Assume all sequences are the same length if no mask is provided
             batch_size = inputs["input_ids"].shape[0]
@@ -838,18 +844,22 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             self._update_ring_flash_attn_params(seq_lens, total_seq_len)
 
         return super().training_step(model, inputs, num_items_in_batch)
-
+
     def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
         """
         Calculate the cu_seqlens for the current forward pass and pass the
         value to the substituted ring_flash_attn.
         """
         cu_seqlens = torch.cumsum(
-            torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
+            torch.tensor(
+                packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
+            ),
             dim=-1,
             dtype=torch.int32,
         )
-        cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
+        cu_seqlens = F.pad(
+            F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
+        )
 
         update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
index 2d3b0dda1..57ad638d6 100644
--- a/src/axolotl/core/training_args.py
+++ b/src/axolotl/core/training_args.py
@@ -206,12 +206,10 @@ class AxolotlTrainingMixins:
             "help": "Whether to apply top_k_before_softmax to the logits when using KD"
         },
     )
-
+
     sequence_parallel_size: Optional[int] = field(
         default=1,
-        metadata={
-            "help": "The number of workers to use in sequence parallelism"
-        },
+        metadata={"help": "The number of workers to use in sequence parallelism"},
     )
 
diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py b/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py
index 38d9f7f00..927671909 100644
--- a/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py
+++ b/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py
@@ -749,6 +749,6 @@ def forward(
 
 def apply_dist_flash_attn_monkey_patch_llama():
     initialize_distributed()
-
+
     LlamaModel.forward = forward
     LlamaDecoderLayer.forward = llama_layer_forward
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 173a1b99f..461324f32 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -11,7 +11,6 @@ from functools import cached_property
 from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
 
 import addict
-from axolotl.utils.ring_attn import register_ring_attn
 import bitsandbytes as bnb
 import torch
 import transformers
@@ -67,6 +66,7 @@ from axolotl.utils.distributed import (
 from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
+from axolotl.utils.ring_attn import register_ring_attn
 
 LOG = logging.getLogger("axolotl")
@@ -558,7 +558,7 @@ class ModelLoader:
 
             #     # Apply the monkey patch
             #     apply_seq_parallel_monkey_patch(method, model_type)
-
+
             register_ring_attn(self.cfg.sequence_parallel_size)
 
     def patch_attention(self) -> None:
diff --git a/src/axolotl/utils/ring_attn.py b/src/axolotl/utils/ring_attn.py
index 09a94b9fa..ddd70cf98 100644
--- a/src/axolotl/utils/ring_attn.py
+++ b/src/axolotl/utils/ring_attn.py
@@ -22,9 +22,10 @@ def register_ring_attn(sequence_parallel_size):
         return
 
     world_size = dist.get_world_size()
-    assert world_size % sequence_parallel_size == 0, \
-        f"sequence_parallel_size ({sequence_parallel_size}) " \
+    assert world_size % sequence_parallel_size == 0, (
+        f"sequence_parallel_size ({sequence_parallel_size}) "
         f"must evenly divide world_size ({world_size})"
+    )
 
     for i in range(world_size // sequence_parallel_size):
         ring_attn_ranks = list(
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index d9fbbbc40..501bf3f68 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -8,7 +8,6 @@ from contextlib import contextmanager
 from functools import partial
 from typing import List, Optional
 
-from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
 import numpy as np
 import torch
 import torch.cuda
@@ -18,6 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
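
Note on _update_ring_flash_attn_params: the cu_seqlens tensor built in the
base.py hunk above is the cumulative-lengths format used by varlen flash
attention. Below is a minimal CPU-only sketch of the same computation with
made-up lengths; the real code builds the tensor on
torch.cuda.current_device() and forwards the result through
update_ring_flash_attn_params to the ring attention group.

    import torch
    import torch.nn.functional as F

    packed_seq_lens = [3, 5, 4]  # hypothetical per-sequence lengths in one packed batch
    total_seq_len = 16           # batch_size * seq_len, padding included

    # Running totals of the sequence lengths: [3, 8, 12]
    cu_seqlens = torch.cumsum(
        torch.tensor(packed_seq_lens, dtype=torch.int32), dim=-1, dtype=torch.int32
    )
    # Prepend 0 and append the total length so each consecutive pair of
    # entries delimits one sequence's span in the packed buffer.
    cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
    print(cu_seqlens)  # tensor([ 0,  3,  8, 12, 16], dtype=torch.int32)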
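
Note on register_ring_attn: the loop body in the ring_attn.py hunk is
truncated in this patch, so the sketch below only illustrates the rank
partitioning that the assert and the loop header imply (one contiguous block
of sequence_parallel_size ranks per ring group); the actual group creation is
assumed, not shown by the diff.

    # Pure-Python sketch; the real code would create torch.distributed
    # process groups from these rank lists.
    world_size = 8               # hypothetical
    sequence_parallel_size = 4   # hypothetical
    assert world_size % sequence_parallel_size == 0

    groups = []
    for i in range(world_size // sequence_parallel_size):
        ring_attn_ranks = list(
            range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
        )
        groups.append(ring_attn_ranks)

    print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]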