diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4fe065671..5fe5b84a8 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -80,10 +80,7 @@ from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, MambaDataCollator, - SequenceParallelDataCollator, - SequenceParallelPackedDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, - V2SequenceParallelPackedDataCollator, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.models import ensure_dtype @@ -871,13 +868,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator: Type[ Union[ V2BatchSamplerDataCollatorForSeq2Seq, - V2SequenceParallelPackedDataCollator, - SequenceParallelPackedDataCollator, BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, DataCollatorWithFlattening, RewardDataCollatorWithPadding, - SequenceParallelDataCollator, ] ] collator_args = [self.tokenizer] @@ -890,15 +884,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.model_config_type in ["llama"] and self.cfg.flash_attention is not True ): - if self.cfg.sequence_parallel_size > 1: - collator = V2SequenceParallelPackedDataCollator - else: - collator = V2BatchSamplerDataCollatorForSeq2Seq + collator = V2BatchSamplerDataCollatorForSeq2Seq else: - if self.cfg.sequence_parallel_size > 1: - collator = SequenceParallelPackedDataCollator - else: - collator = BatchSamplerDataCollatorForSeq2Seq + collator = BatchSamplerDataCollatorForSeq2Seq else: if self.cfg.processor_type and self.processor: collator = MultiModalChatDataCollator @@ -920,12 +908,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): else: collator = DataCollatorForKD else: - if self.cfg.sequence_parallel_size > 1: - collator = SequenceParallelDataCollator - else: - collator = DataCollatorForSeq2Seq + collator = DataCollatorForSeq2Seq kwargs["return_tensors"] = "pt" + kwargs["sequence_parallel_size"] = training_args.sequence_parallel_size return collator( *collator_args, diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index a5215114a..891f154ec 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -9,6 +9,7 @@ from collections import defaultdict from functools import wraps from typing import Any, Dict, Literal, Optional +import datasets import torch import torch.distributed as dist import torch.nn.functional as F @@ -20,7 +21,7 @@ from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import Trainer from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker -from transformers.utils import is_sagemaker_mp_enabled +from transformers.utils import is_datasets_available, is_sagemaker_mp_enabled from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer from trl.trainer.utils import pad_to_length from typing_extensions import override @@ -415,17 +416,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): generator = None if self.args.sequence_parallel_size > 1: generator = torch.Generator() - generator.manual_seed(self.args.getattr("seed", 0)) - - sampler = RandomSampler(self.train_dataset) - - # if dist.get_rank() == 0: - # import ipdb; ipdb.set_trace() - # dist.barrier() - - # if dist.get_rank() == 1: - # import ipdb; ipdb.set_trace() - # dist.barrier() + generator.manual_seed(self.args.seed) + sampler = RandomSampler( + self.train_dataset, generator=generator + ) return MultipackBatchSampler( sampler, @@ -443,7 +437,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): sampler = super()._get_train_sampler() if self.args.sequence_parallel_size > 1: generator = torch.Generator() - generator.manual_seed(self.args.getattr("seed", 0)) + generator.manual_seed(self.args.seed) sampler.generator = generator return sampler @@ -473,17 +467,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): return super()._get_eval_sampler(eval_dataset) def get_train_dataloader(self) -> DataLoader: + train_dataset = self.train_dataset + data_collator = self.data_collator + + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + if self.args.sample_packing and not self.args.pretraining: - train_dataset = self.train_dataset if "length" in train_dataset.features.keys(): train_dataset = train_dataset.remove_columns(["length"]) - data_collator = self.data_collator - dataloader_params = { - "batch_size": self._train_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } if self.args.dataloader_prefetch_factor: dataloader_params["prefetch_factor"] = ( self.args.dataloader_prefetch_factor @@ -498,17 +495,31 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["worker_init_fn"] = seed_worker - if dist.get_rank() == 0: - import ipdb - - ipdb.set_trace() - dist.barrier() - self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader( - DataLoader(train_dataset, **dataloader_params) - ) - return super().get_train_dataloader() + if self.args.sequence_parallel_size > 1: + return DataLoader(train_dataset, **dataloader_params) + else: + return self.accelerator.prepare_data_loader( + DataLoader(train_dataset, **dataloader_params) + ) + else: + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + if self.args.sequence_parallel_size > 1: + return DataLoader(train_dataset, **dataloader_params) + else: + return self.accelerator.prepare( + DataLoader(train_dataset, **dataloader_params) + ) def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: if self.args.sample_packing and self.args.eval_sample_packing is False: diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py index 66105d20d..8c60f223c 100644 --- a/src/axolotl/utils/collators/__init__.py +++ b/src/axolotl/utils/collators/__init__.py @@ -9,8 +9,3 @@ from .batching import ( # noqa: F401 V2BatchSamplerDataCollatorForSeq2Seq, ) from .mamba import MambaDataCollator # noqa: F401 -from .sequence_parallel import ( # noqa: F401 - SequenceParallelDataCollator, - SequenceParallelPackedDataCollator, - V2SequenceParallelPackedDataCollator, -) diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 7cf771421..ee4679230 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -1,14 +1,66 @@ """ -DataCollator for axolotl to pad labels and position_ids for packed sequences +Data collators for axolotl to pad labels and position_ids for packed sequences. Also +includes logic for handling sequence parallelism collation. """ +import logging from dataclasses import dataclass from typing import Any, Optional, Union import numpy as np +import torch +import torch.distributed as dist from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy +from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group + +logger = logging.getLogger(__name__) + + +def adjust_position_ids_for_slice( + position_ids: list | torch.Tensor, start_idx: int +) -> torch.Tensor: + """ + Adjust position IDs for a sliced sequence to maintain proper relative positions. + This handles the case where position IDs might not be contiguous due to sample packing. + """ + # Convert to tensor if not already + if not isinstance(position_ids, torch.Tensor): + position_ids = torch.tensor( + position_ids, + device=position_ids.device if hasattr(position_ids, "device") else None, + ) + + # Find the boundaries between samples (where position_ids reset) + adjusted_pos_ids = position_ids.clone() + + # Process each sequence in the batch + for i in range(position_ids.shape[0]): + seq = position_ids[i] + + # Find sample boundaries + boundaries = [] + for j in range(1, len(seq)): + if seq[j] < seq[j - 1]: + boundaries.append(j) + + # No need to adjust if there are no boundaries or this is a single sample + if not boundaries: + adjusted_pos_ids[i] = seq - start_idx + continue + + # Adjust each segment separately + prev_boundary = 0 + for boundary in boundaries: + adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx + prev_boundary = boundary + + # Last segment + adjusted_pos_ids[i, prev_boundary:] -= start_idx + + return adjusted_pos_ids + @dataclass class DataCollatorForSeq2Seq: @@ -43,6 +95,8 @@ class DataCollatorForSeq2Seq: The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). return_tensors (`str`): The type of Tensor to return. Allowable values are "np", "pt" and "tf". + sequence_parallel_size (`int`): + The degree of sequence parallelism. Default to 1 for no sequence parallelism. """ tokenizer: PreTrainedTokenizerBase @@ -53,6 +107,14 @@ class DataCollatorForSeq2Seq: label_pad_token_id: int = -100 position_pad_token_id: int = 0 return_tensors: str = "pt" + sequence_parallel_size: int = 1 + + def __post_init__(self): + if self.sequence_parallel_size > 1: + # Get information about our position in the SP group + sp_group = get_ring_attn_group() + self.rank = dist.get_rank(group=sp_group) + self.world_size = dist.get_world_size(group=sp_group) def __call__(self, features, return_tensors=None): labels = None @@ -119,8 +181,78 @@ class DataCollatorForSeq2Seq: ) features["decoder_input_ids"] = decoder_input_ids + if self.sequence_parallel_size > 1: + features = self.apply_sequence_parallelism(features) + return features + def apply_sequence_parallelism( + self, batch: dict[str, torch.Tensor] + ) -> torch.Tensor: + """ + Apply sequence parallelism slicing to a batch. + + Args: + batch: Batch dictionary from parent collator. + + Returns: + Sliced batch dictionary. + """ + # Process keys that need to be sliced + for key in ["input_ids", "attention_mask", "labels"]: + if key in batch: + seq_len = batch[key].shape[1] + slice_size = seq_len // self.world_size + start_idx = self.rank * slice_size + end_idx = ( + start_idx + slice_size + if self.rank < self.world_size - 1 + else seq_len + ) + + if key == "input_ids": + # Before slicing + non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item() + logger.info( + f"GPU {self.rank}: Total sequence length: {seq_len}, " + f"Non-padding tokens: {non_pad_tokens_total}" + ) + logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}") + logger.info(f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}") + + batch[key] = batch[key][:, start_idx:end_idx] + + if key == "input_ids": + # After slicing + non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item() + logger.info( + f"GPU {self.rank}: Slice {start_idx}-{end_idx}, " + f"Non-padding tokens in slice: {non_pad_tokens_slice}" + ) + logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}") + + dist.barrier() + + # Handle position_ids if present + if "position_ids" in batch: + pos_ids = batch["position_ids"] + seq_len = pos_ids.shape[1] + slice_size = seq_len // self.world_size + start_idx = self.rank * slice_size + end_idx = ( + start_idx + slice_size if self.rank < self.world_size - 1 else seq_len + ) + + batch["position_ids"] = pos_ids[:, start_idx:end_idx] + + # Adjust position_ids to be relative to the slice start + if self.rank > 0: + batch["position_ids"] = adjust_position_ids_for_slice( + batch["position_ids"], start_idx + ) + + return batch + @dataclass class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): @@ -148,6 +280,7 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): np.array(item[feature]) for item in features_ if feature in item ] out_features[i][feature] = np.concatenate(arrays) + return super().__call__(out_features, return_tensors=return_tensors) @@ -177,6 +310,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): np.array(item[feature]) for item in features_ if feature in item ] out_features[i][feature] = np.concatenate(arrays) + return super().__call__(out_features, return_tensors=return_tensors) diff --git a/src/axolotl/utils/collators/sequence_parallel.py b/src/axolotl/utils/collators/sequence_parallel.py deleted file mode 100644 index 1ed2957ca..000000000 --- a/src/axolotl/utils/collators/sequence_parallel.py +++ /dev/null @@ -1,177 +0,0 @@ -"""Module for sequence parallelism data collators.""" - -import logging -from dataclasses import dataclass - -import torch -import torch.distributed as dist - -from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group -from axolotl.utils.collators.batching import ( - BatchSamplerDataCollatorForSeq2Seq, - DataCollatorForSeq2Seq, - V2BatchSamplerDataCollatorForSeq2Seq, -) - -logger = logging.getLogger(__name__) - - -def adjust_position_ids_for_slice( - position_ids: list | torch.Tensor, start_idx: int -) -> torch.Tensor: - """ - Adjust position IDs for a sliced sequence to maintain proper relative positions. - This handles the case where position IDs might not be contiguous due to sample packing. - """ - # Convert to tensor if not already - if not isinstance(position_ids, torch.Tensor): - position_ids = torch.tensor( - position_ids, - device=position_ids.device if hasattr(position_ids, "device") else None, - ) - - # Find the boundaries between samples (where position_ids reset) - adjusted_pos_ids = position_ids.clone() - - # Process each sequence in the batch - for i in range(position_ids.shape[0]): - seq = position_ids[i] - - # Find sample boundaries - boundaries = [] - for j in range(1, len(seq)): - if seq[j] < seq[j - 1]: - boundaries.append(j) - - # No need to adjust if there are no boundaries or this is a single sample - if not boundaries: - adjusted_pos_ids[i] = seq - start_idx - continue - - # Adjust each segment separately - prev_boundary = 0 - for boundary in boundaries: - adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx - prev_boundary = boundary - - # Last segment - adjusted_pos_ids[i, prev_boundary:] -= start_idx - - return adjusted_pos_ids - - -class SequenceParallelMixin: - """ - Mixin to add sequence parallelism slicing to data collators. - """ - - def __post_init__(self): - # Get information about our position in the SP group - sp_group = get_ring_attn_group() - self.rank = dist.get_rank(group=sp_group) - self.world_size = dist.get_world_size(group=sp_group) - - def apply_sequence_parallelism( - self, batch: dict[str, torch.Tensor] - ) -> torch.Tensor: - """ - Apply sequence parallelism slicing to a batch. - - Args: - batch: Batch dictionary from parent collator. - - Returns: - Sliced batch dictionary. - """ - # Process keys that need to be sliced - for key in ["input_ids", "attention_mask", "labels"]: - if key in batch: - seq_len = batch[key].shape[1] - slice_size = seq_len // self.world_size - start_idx = self.rank * slice_size - end_idx = ( - start_idx + slice_size - if self.rank < self.world_size - 1 - else seq_len - ) - - if key == "input_ids": - # Before slicing - non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item() - logger.info( - f"GPU {self.rank}: Total sequence length: {seq_len}, " - f"Non-padding tokens: {non_pad_tokens_total}" - ) - logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}") - - # After slicing - non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item() - logger.info( - f"GPU {self.rank}: Slice {start_idx}-{end_idx}, " - f"Non-padding tokens in slice: {non_pad_tokens_slice}" - ) - - dist.barrier() - - batch[key] = batch[key][:, start_idx:end_idx] - - # Handle position_ids if present - if "position_ids" in batch: - pos_ids = batch["position_ids"] - seq_len = pos_ids.shape[1] - slice_size = seq_len // self.world_size - start_idx = self.rank * slice_size - end_idx = ( - start_idx + slice_size if self.rank < self.world_size - 1 else seq_len - ) - - batch["position_ids"] = pos_ids[:, start_idx:end_idx] - - # Adjust position_ids to be relative to the slice start - if self.rank > 0: - batch["position_ids"] = adjust_position_ids_for_slice( - batch["position_ids"], start_idx - ) - - return batch - - -@dataclass -class SequenceParallelPackedDataCollator( - SequenceParallelMixin, BatchSamplerDataCollatorForSeq2Seq -): - """ - Data collator for sequence parallelism with sample packing. Combines multiple - samples into a packed sequence, then slices it for each GPU. - """ - - def __call__(self, features, return_tensors=None): - # Use the parent collator to handle sample packing and padding - batch = super().__call__(features, return_tensors=return_tensors) - return self.apply_sequence_parallelism(batch) - - -@dataclass -class V2SequenceParallelPackedDataCollator( - SequenceParallelMixin, V2BatchSamplerDataCollatorForSeq2Seq -): - """ - Data collator for sequence parallelism with V2 sample packing. - """ - - def __call__(self, features, return_tensors=None): - # Use the parent collator to handle sample packing and padding - batch = super().__call__(features, return_tensors=return_tensors) - return self.apply_sequence_parallelism(batch) - - -@dataclass -class SequenceParallelDataCollator(SequenceParallelMixin, DataCollatorForSeq2Seq): - """ - Data collator for sequence parallelism without sample packing. - """ - - def __call__(self, features, return_tensors=None): - # Use the parent collator to pad everything correctly - batch = super().__call__(features, return_tensors=return_tensors) - return self.apply_sequence_parallelism(batch)