diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 7d0df8a45..9f7e1929b 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1043,6 +1043,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rpo_alpha is not None: training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha + training_args_kwargs["sequence_parallel_degree"] = ( + self.cfg.sequence_parallel_degree + ) + training_args_cls = None blocklist_args_kwargs = [] if self.cfg.rl == "simpo": @@ -1161,6 +1165,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): dpo_trainer_kwargs["dataset_tags"] = [ d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() ] + dpo_trainer = trainer_cls( *trainer_cls_args, args=training_args, @@ -1178,21 +1183,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase): dpo_trainer.add_callback(callback) return dpo_trainer - - -class HFPPOTrainerBuilder(TrainerBuilderBase): - """ - HF Factory class for PPO Trainer - """ - - def get_callbacks(self): - callbacks = super().get_callbacks() - return callbacks - - def get_post_trainer_create_callbacks(self, trainer): - callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) - return callbacks - - def build(self, total_num_steps): - # build PPOConfig - pass diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index a1cb819b6..710abf89c 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -8,10 +8,11 @@ import logging import os from collections import defaultdict from functools import wraps -from typing import Literal +from typing import Any, Literal import datasets import torch +import torch.nn as nn from datasets import Dataset from torch.utils.data import ( BatchSampler, @@ -25,6 +26,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker from trl.trainer.utils import pad_to_length from typing_extensions import override +from axolotl.core.trainers.handlers import SequenceParallelHandler from axolotl.core.trainers.mixins import TrainerMixins from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, @@ -61,9 +63,7 @@ class AxolotlTrainer(TrainerMixins, Trainer): if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - # Initialize sequence parallelism if enabled - if self.args.sequence_parallel_degree > 1: - self._setup_sequence_parallel() + self.sequence_parallel_handler = SequenceParallelHandler(self.args) def _wrap_model(self, model, training=True, dataloader=None): if self.args.torch_compile: @@ -124,7 +124,7 @@ class AxolotlTrainer(TrainerMixins, Trainer): # Determine the base sampler first if self.args.sequence_parallel_degree > 1: - base_sampler = self._sp_get_train_sampler(self.train_dataset) + base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset) elif self.args.curriculum_sampling: base_sampler = SequentialSampler(self.train_dataset) elif use_sample_packing: @@ -160,7 +160,7 @@ class AxolotlTrainer(TrainerMixins, Trainer): # Determine the base sampler if self.args.sequence_parallel_degree > 1: - base_sampler = self._sp_get_eval_sampler(eval_dataset) + base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset) elif use_multipack: base_sampler = SequentialSampler(eval_dataset) else: @@ -232,7 +232,10 @@ class AxolotlTrainer(TrainerMixins, Trainer): return dataloader # Otherwise prepare with accelerator - return self.accelerator.prepare_data_loader(dataloader) + dataloader = self.accelerator.prepare_data_loader(dataloader) + + return dataloader + def get_train_dataloader(self) -> DataLoader: """Get dataloader for training""" @@ -341,7 +344,57 @@ class AxolotlTrainer(TrainerMixins, Trainer): dataloader_params["drop_last"] = self.args.dataloader_drop_last return DataLoader(bench_dataset, **dataloader_params) - # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) + + def training_step( + self, + model: nn.Module, + inputs: dict[str, torch.Tensor | Any], + num_items_in_batch: int | None = None, + ) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. Overrides the + `transformers.trainer.Trainer` method to handle sequence parallelism if + enabled. + + Args: + model: Model to perform training step for. + inputs: Dictionary mapping of inputs. + num_items_in_batch: The number of items in the batch. + """ + # Set up sequence parallelism for this step if enabled + if self.args.sequence_parallel_degree > 1: + self.sequence_parallel_handler._update_ring_flash_attn_params(inputs) + + # Proceed with normal training step + return super().training_step(model, inputs, num_items_in_batch) # type: ignore + + def prediction_step( + self, + model: nn.Module, + inputs: dict[str, torch.Tensor | Any], + prediction_loss_only: bool, + ignore_keys: list[str] | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + """ + Perform a prediction step on a batch of inputs. Overrides the + `transformers.trainer.Trainer` method to handle sequence parallelism if + enabled. + + Args: + model: Model to perform prediction step for. + inputs: Dictionary mapping of inputs. + prediction_loss_only: Whether to return only the loss. + ignore_keys: Keys to ignore in the inputs. + + Returns: + Tuple of (loss, logits, labels). + """ + # Set up sequence parallelism for this prediction step if enabled + if self.args.sequence_parallel_degree > 1: + self.sequence_parallel_handler._update_ring_flash_attn_params(inputs) + + # Proceed with normal prediction step + return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore @override def compute_loss( diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index a6b8f56ba..5f2f782fd 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -1,10 +1,7 @@ -""" -DPO trainer for axolotl -""" +"""DPO trainer for axolotl""" -import gc from functools import wraps -from typing import Any, Dict, Union +from typing import Any import torch from peft.optimizers import create_loraplus_optimizer @@ -13,6 +10,7 @@ from transformers import Trainer from transformers.utils import is_sagemaker_mp_enabled from trl import DPOTrainer +from axolotl.core.trainers.handlers import SequenceParallelHandler from axolotl.core.trainers.mixins import TrainerMixins from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, @@ -24,17 +22,17 @@ if is_sagemaker_mp_enabled(): class AxolotlDPOTrainer(TrainerMixins, DPOTrainer): - """ - Extend the base DPOTrainer for axolotl helpers - """ + """Extend the base DPOTrainer for axolotl helpers""" tag_names = ["axolotl", "dpo"] def __init__(self, *args, dataset_tags=None, **kwargs): super().__init__(*args, **kwargs) + self.dataset_tags = dataset_tags self.optimizer = None self.model_accepts_loss_kwargs = False + self.sequence_parallel_handler = SequenceParallelHandler(args=self.args) def create_optimizer(self): # pylint: disable=duplicate-code @@ -88,7 +86,7 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer): max_prompt_length, max_completion_length, add_special_tokens, - ) -> Dict: + ) -> dict: res = DPOTrainer.tokenize_row( features, processing_class, @@ -117,10 +115,9 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer): def training_step( self, model: nn.Module, - inputs: Dict[str, Union[torch.Tensor, Any]], + inputs: dict[str, torch.Tensor | Any | None], num_items_in_batch=None, ) -> torch.Tensor: - loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch) - gc.collect() - torch.cuda.empty_cache() - return loss + self.sequence_parallel_handler.prepare_for_training_step(self, inputs) + + return super().training_step(model, inputs, num_items_in_batch) diff --git a/src/axolotl/core/trainers/handlers/__init__.py b/src/axolotl/core/trainers/handlers/__init__.py new file mode 100644 index 000000000..7b558f3d7 --- /dev/null +++ b/src/axolotl/core/trainers/handlers/__init__.py @@ -0,0 +1,3 @@ +"""Init for trainer handlers""" + +from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler diff --git a/src/axolotl/core/trainers/handlers/sequence_parallel.py b/src/axolotl/core/trainers/handlers/sequence_parallel.py new file mode 100644 index 000000000..f6629d7c1 --- /dev/null +++ b/src/axolotl/core/trainers/handlers/sequence_parallel.py @@ -0,0 +1,123 @@ +"""Handler class for sequence parallel trainer logic""" + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.utils.data import DistributedSampler + + +class SequenceParallelHandler: + """ + Handler class that encapsulates sequence parallelism functionality. + This replaces the SequenceParallelMixin with a composition-based approach. + """ + + def __init__(self, args=None): + """ + Initialize the sequence parallel handler. + + Args: + args: The arguments object containing sequence parallelism settings. + """ + self.args = args + self.ring_attn_group = None + + # Set up sequence parallelism if enabled + if self.args.sequence_parallel_degree > 1: + self._setup_sequence_parallel() + + def _setup_sequence_parallel(self): + """Set up sequence parallelism environment.""" + from ring_flash_attn import update_ring_flash_attn_params + from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group + + self.update_ring_flash_attn_params = update_ring_flash_attn_params + self.ring_attn_group = get_ring_attn_group() + + def create_sequence_parallel_sampler( + self, + dataset, + shuffle=True, + is_eval=False, + ): + """ + Helper method to create sampler for sequence parallelism (SP). + + Args: + dataset: Dataset to sample from. + shuffle: Whether to shuffle the dataset. + is_eval: Whether we are creating a sampler for evaluation or training. + + Returns: + Distributed sampler. + """ + num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree + sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree + + return DistributedSampler( + dataset, + num_replicas=num_sp_groups, + rank=sp_group_id, + seed=self.args.seed if shuffle else None, + shuffle=shuffle, + drop_last=not is_eval, + ) + + def _get_train_sampler(self, dataset): + """ + Get a training sampler configured for sequence parallelism. + + Args: + dataset: The training dataset. + + Returns: + Configured sequence parallel sampler. + """ + return self.create_sequence_parallel_sampler( + dataset, + shuffle=not self.args.curriculum_sampling, + ) + + def _get_eval_sampler(self, eval_dataset): + """ + Get an evaluation sampler configured for sequence parallelism. + + Args: + eval_dataset: The evaluation dataset. + + Returns: + Configured sequence parallel sampler. + """ + return self.create_sequence_parallel_sampler( + eval_dataset, shuffle=False, is_eval=True + ) + + def _update_ring_flash_attn_params(self, inputs): + """ + Calculate the cu_seqlens for the current forward pass and pass the value to + the substituted ring_flash_attn. + + Args: + inputs: Current batch of inputs. + """ + # At this point, inputs should already be partitioned by the sequence + # parallel data collator + batch_size = inputs["input_ids"].shape[0] + seq_len = inputs["input_ids"].shape[1] + packed_seq_lens = [seq_len] * batch_size + + # Calculate the full sequence length across all GPUs in this SP group + total_seq_len = seq_len * self.args.sequence_parallel_degree + + cu_seqlens = torch.cumsum( + torch.tensor( + packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32 + ), + dim=-1, + dtype=torch.int32, + ) + cu_seqlens = F.pad( + F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len + ) + + self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group) diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py index 052754a2f..8e9c36343 100644 --- a/src/axolotl/core/trainers/mixins/__init__.py +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -6,10 +6,9 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerMixin from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin from axolotl.core.trainers.mixins.scheduler import SchedulerMixin -from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelMixin class TrainerMixins( - OptimizerMixin, RngLoaderMixin, SchedulerMixin, SequenceParallelMixin + OptimizerMixin, RngLoaderMixin, SchedulerMixin ): """Stub class combining all mixins for Axolotl trainers.""" diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py index 9bcd5db57..af3273f87 100644 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -1,4 +1,5 @@ """Module for Axolotl trainer sequence parallelism mixin""" +# TODO(Dan): remove import logging from typing import Any @@ -7,7 +8,6 @@ import torch import torch.distributed as dist import torch.nn.functional as F from datasets import Dataset -from torch import nn from torch.utils.data import DistributedSampler, Sampler from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group @@ -71,12 +71,12 @@ class SequenceParallelMixin: drop_last=not is_eval, ) - def _sp_get_train_sampler(self, dataset) -> Sampler | None: + def _get_train_sampler(self, dataset) -> Sampler | None: """ Get a training sampler configured for sequence parallelism. Args: - dataset: The training dataset + dataset: The training dataset. Returns: Configured sequence parallel sampler. @@ -86,7 +86,7 @@ class SequenceParallelMixin: shuffle=not self.args.curriculum_sampling, ) - def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None: + def _get_eval_sampler(self, eval_dataset) -> Sampler | None: """ Get an evaluation sampler configured for sequence parallelism. @@ -130,53 +130,3 @@ class SequenceParallelMixin: ) update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group) - - def training_step( - self, - model: nn.Module, - inputs: dict[str, torch.Tensor | Any], - num_items_in_batch: int | None = None, - ) -> torch.Tensor: - """ - Perform a training step on a batch of inputs. Overrides the - `transformers.trainer.Trainer` method to handle sequence parallelism if - enabled. - - Args: - model: Model to perform training step for. - inputs: Dictionary mapping. - """ - # Set up sequence parallelism for this step if enabled - if self.args.sequence_parallel_degree > 1: - self._update_ring_flash_attn_params(inputs) - - # Proceed with normal training step - return super().training_step(model, inputs, num_items_in_batch) # type: ignore - - def prediction_step( - self, - model: nn.Module, - inputs: dict[str, torch.Tensor | Any], - prediction_loss_only: bool, - ignore_keys: list[str] | None = None, - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: - """ - Perform a prediction step on a batch of inputs. Overrides the - `transformers.trainer.Trainer` method to handle sequence parallelism if - enabled. - - Args: - model: Model to perform prediction step for. - inputs: Dictionary mapping of inputs. - prediction_loss_only: Whether to return only the loss. - ignore_keys: Keys to ignore in the inputs. - - Returns: - Tuple of (loss, logits, labels). - """ - # Set up sequence parallelism for this prediction step if enabled - if self.args.sequence_parallel_degree > 1: - self._update_ring_flash_attn_params(inputs) - - # Proceed with normal prediction step - return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py index 6c9d0b429..fd1e103e8 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn.py +++ b/src/axolotl/monkeypatch/attention/ring_attn.py @@ -6,11 +6,22 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc their sequence parallel version of Flash Attention 2. """ +import torch import torch.distributed as dist +import torch.nn.functional as F from accelerate.logging import get_logger +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from axolotl.logging_config import configure_logging +try: + from ring_flash_attn import update_ring_flash_attn_params +except ImportError: + # We pass silently here, but raise an ImportError in our Axolotl config validation + # if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed. + pass + + configure_logging() LOG = get_logger(__name__) @@ -32,12 +43,120 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): Setter for ring attention group on this rank. Args: - Process group for ring attention. + ring_attn_group: Process group for ring attention. """ global RING_ATTN_GROUP # pylint: disable=global-statement RING_ATTN_GROUP = ring_attn_group +def patch_flash_attention_for_sequential_batch(sequence_parallel_degree: int): + """ + Patch flash attention a second time to handle batched data. This is a hack to + accommodate certain RL trainers which batch data even when `micro_batch_size: 1` is + specified in the Axolotl config. + + Args: + sequence_parallel_degree: Sequence parallelism factor. + """ + # Store the original flash attention function + original_flash_attention = ALL_ATTENTION_FUNCTIONS["flash_attention_2"] + + def sequential_batch_flash_attention( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + dropout: float = 0.0, + scaling: float | None = None, + sliding_window: int | None = None, + softcap: float | None = None, + **kwargs, + ) -> tuple[torch.Tensor, None]: + # Check if we have a batch dimension > 1 + batch_size = query.shape[0] + + if batch_size <= 1: + return original_flash_attention( + module, + query, + key, + value, + attention_mask, + dropout, + scaling, + sliding_window, + softcap, + **kwargs + ) + + # Process each item in the batch separately + outputs = [] + + for i in range(batch_size): + # Extract single batch item + q_item = query[i:i+1] + k_item = key[i:i+1] + v_item = value[i:i+1] + + # Handle attention mask - it might be None or have different shapes + mask_item = None + if attention_mask is not None: + # The mask could have different formats depending on implementation + if attention_mask.dim() >= 3 and attention_mask.shape[0] == batch_size: + mask_item = attention_mask[i:i+1] + else: + # For broadcast masks that don't have a batch dimension + mask_item = attention_mask + + # At this point, inputs should already be partitioned by the sequence + # parallel data collator + batch_size = q_item.shape[0] + seq_len = q_item.shape[2] + packed_seq_lens = [seq_len] * batch_size + + # Calculate the full sequence length across all GPUs in this SP group + total_seq_len = seq_len * sequence_parallel_degree + + cu_seqlens = torch.cumsum( + torch.tensor( + packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32 + ), + dim=-1, + dtype=torch.int32, + ) + cu_seqlens = F.pad( + F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len + ) + + update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) + + # Call the original function for a single batch item + output, _ = original_flash_attention( + module, + q_item, + k_item, + v_item, + mask_item, + dropout, + scaling, + sliding_window, + softcap, + **kwargs + ) + + outputs.append(output) + + dist.barrier() + + # Concatenate results along batch dimension + concatenated_output = torch.cat(outputs, dim=0) + return concatenated_output, None + + # Replace the original function with our sequential version + ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = sequential_batch_flash_attention + + def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None): """ Create ring attention group and substitute flash attn with ring flash attn. @@ -98,3 +217,4 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None substitute_hf_flash_attn( process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride ) + patch_flash_attention_for_sequential_batch(sequence_parallel_degree) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 9611ffca2..8165ddeb5 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1351,9 +1351,7 @@ def load_model( reference_model: bool = False, **kwargs, # pylint: disable=unused-argument ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: - """ - Load a model for a given configuration and tokenizer. - """ + """Load a model for a given configuration and tokenizer.""" model_loader = ModelLoader( cfg, tokenizer, @@ -1362,12 +1360,16 @@ def load_model( reference_model=reference_model, **kwargs, ) + return model_loader.load_model() -def load_adapter(model, cfg, adapter, inference=False): - # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] - +def load_adapter( + model: PreTrainedModel, + cfg: DictDefault, + adapter: str | None, + inference: bool = False, +) -> tuple[PreTrainedModel, PeftConfig | None]: if adapter is None: return model, None if hasattr(model, "enable_input_require_grads"): @@ -1380,8 +1382,9 @@ def load_adapter(model, cfg, adapter, inference=False): raise NotImplementedError(f"{adapter} peft adapter not available") -def load_llama_adapter(model, cfg): - # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] +def load_llama_adapter( + model: PreTrainedModel, cfg: DictDefault +) -> tuple[PreTrainedModel, PeftConfig | None]: from peft import AdaptionPromptConfig, get_peft_model peft_config = AdaptionPromptConfig( @@ -1405,7 +1408,7 @@ def load_llama_adapter(model, cfg): return model, peft_config -def find_all_linear_names(model): +def find_all_linear_names(model: PreTrainedModel): cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear) lora_module_names = set() for name, module in model.named_modules():