diff --git a/setup.py b/setup.py
index d19c14828..aef8182af 100644
--- a/setup.py
+++ b/setup.py
@@ -106,7 +106,11 @@ def get_package_version():
 
 extras_require = {
     "flash-attn": ["flash-attn==2.7.4.post1"],
-    "ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
+    "ring-flash-attn": [
+        "flash-attn==2.7.4.post1",
+        "ring-flash-attn>=0.1.4",
+        "yunchang==0.6.0",
+    ],
     "deepspeed": [
         "deepspeed==0.16.4",
         "deepspeed-kernels",
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 436d3a073..9fed78eb7 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -8,12 +8,11 @@ import logging
 import os
 from collections import defaultdict
 from functools import wraps
-from typing import Any, Literal
+from typing import Literal
 
 import datasets
 import torch
 from datasets import Dataset
-from torch import nn
 from torch.utils.data import (
     BatchSampler,
     DataLoader,
@@ -593,27 +592,3 @@ class AxolotlTrainer(
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)
-
-    def training_step(
-        self,
-        model: nn.Module,
-        inputs: dict[str, torch.Tensor | Any],
-        num_items_in_batch: int | None = None,
-    ) -> torch.Tensor:
-        """
-        Perform a training step on a batch of inputs. Overrides the
-        `transformers.trainer.Trainer` method to handle sequence parallelism if
-        enabled.
-
-        Args:
-            model: Model to perform training step for.
-            inputs: Dictionary mapping.
-        """
-        # Set up sequence parallelism for this step if enabled
-        if self.args.sequence_parallel_degree > 1:
-            self._update_ring_flash_attn_params(inputs)
-
-        # Proceed with normal training step
-        loss = super().training_step(model, inputs, num_items_in_batch)
-
-        return loss
diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py
index f52c044b6..9bcd5db57 100644
--- a/src/axolotl/core/trainers/mixins/sequence_parallel.py
+++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py
@@ -7,6 +7,7 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from datasets import Dataset
+from torch import nn
 from torch.utils.data import DistributedSampler, Sampler
 
 from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
@@ -129,3 +130,53 @@ class SequenceParallelMixin:
         )
 
         update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
+
+    def training_step(
+        self,
+        model: nn.Module,
+        inputs: dict[str, torch.Tensor | Any],
+        num_items_in_batch: int | None = None,
+    ) -> torch.Tensor:
+        """
+        Perform a training step on a batch of inputs. Overrides the
+        `transformers.trainer.Trainer` method to handle sequence parallelism if
+        enabled.
+
+        Args:
+            model: Model to perform training step for.
+            inputs: Dictionary mapping of inputs.
+        """
+        # Set up sequence parallelism for this step if enabled
+        if self.args.sequence_parallel_degree > 1:
+            self._update_ring_flash_attn_params(inputs)
+
+        # Proceed with normal training step
+        return super().training_step(model, inputs, num_items_in_batch)  # type: ignore
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: dict[str, torch.Tensor | Any],
+        prediction_loss_only: bool,
+        ignore_keys: list[str] | None = None,
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
+        """
+        Perform a prediction step on a batch of inputs. Overrides the
+        `transformers.trainer.Trainer` method to handle sequence parallelism if
+        enabled.
+
+        Args:
+            model: Model to perform prediction step for.
+            inputs: Dictionary mapping of inputs.
+            prediction_loss_only: Whether to return only the loss.
+            ignore_keys: Keys to ignore in the inputs.
+
+        Returns:
+            Tuple of (loss, logits, labels).
+        """
+        # Set up sequence parallelism for this prediction step if enabled
+        if self.args.sequence_parallel_degree > 1:
+            self._update_ring_flash_attn_params(inputs)
+
+        # Proceed with normal prediction step
+        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)  # type: ignore
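
The apparent effect of moving these overrides off `AxolotlTrainer` and into `SequenceParallelMixin`: both the train and eval paths now refresh the ring-flash-attn `cu_seqlens` before each step, whereas previously only `training_step` did, so evaluation under sequence parallelism could run with stale parameters. A minimal sketch of the cooperative-mixin dispatch this relies on, using stand-in classes rather than the real `transformers`/axolotl implementations:

```python
# Minimal sketch of cooperative mixin dispatch. `Trainer` and
# `SequenceParallelMixin` here are stand-ins, not the real
# transformers/axolotl classes.
class Trainer:
    def training_step(self, model, inputs, num_items_in_batch=None):
        print("Trainer.training_step: forward/backward pass")
        return 0.0  # stand-in for the loss tensor

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        print("Trainer.prediction_step: eval forward pass")
        return (0.0, None, None)  # (loss, logits, labels)


class SequenceParallelMixin:
    def _update_ring_flash_attn_params(self, inputs):
        print("refresh ring-flash-attn cu_seqlens for this batch")

    def training_step(self, model, inputs, num_items_in_batch=None):
        self._update_ring_flash_attn_params(inputs)  # runs for training...
        return super().training_step(model, inputs, num_items_in_batch)

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        self._update_ring_flash_attn_params(inputs)  # ...and now for eval too
        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)


# Listing the mixin before the base class puts it first in the MRO, so its
# overrides run first and super() falls through to Trainer.
class AxolotlTrainer(SequenceParallelMixin, Trainer):
    pass


trainer = AxolotlTrainer()
trainer.training_step(model=None, inputs={})
trainer.prediction_step(model=None, inputs={}, prediction_loss_only=True)
```

The setup.py hunk is related plumbing: `pip install "axolotl[ring-flash-attn]"` now pins `flash-attn==2.7.4.post1` explicitly, matching the `flash-attn` extra, rather than relying on `ring-flash-attn` to pull in a compatible version.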