From b7738d57c4a54427bd4ea16adda664c2b1d207b6 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 12 Mar 2025 19:33:40 +0000
Subject: [PATCH] working multi-group SP

---
 src/axolotl/core/trainers/base.py        | 117 +++++++++++-------
 .../monkeypatch/attention/ring_attn.py   |   5 +-
 src/axolotl/train.py                     |  27 ++--
 src/axolotl/utils/collators/batching.py  |  74 +++++++----
 4 files changed, 143 insertions(+), 80 deletions(-)

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 0b969c3aa..83106868c 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -1,8 +1,8 @@
-"""Module for customized trainers."""
+"""Module for customized trainers"""
+# pylint: disable=too-many-lines
 
 from __future__ import annotations
 
-# pylint: disable=too-many-lines
 import logging
 import os
 from collections import defaultdict
@@ -361,9 +361,7 @@ class OptimizerMixin(Trainer):
 
 
 class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
-    """
-    Extend the base Trainer for axolotl helpers
-    """
+    """Extend the base Trainer for axolotl helpers"""
 
     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
     tag_names = ["axolotl"]
@@ -380,7 +378,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         self.eval_data_collator = eval_data_collator
         self.dataset_tags = dataset_tags
         self._signature_columns = None  # workaround for pylint
+
         super().__init__(*_args, **kwargs)
+
         self.train_data_collator = self.data_collator
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
         if self.args.orpo_alpha:
@@ -425,6 +425,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                 bin_size=self.args.sample_packing_bin_size,
                 drop_last=True,
             )
+
         if self.args.curriculum_sampling:
             return SequentialSampler(self.train_dataset)
 
@@ -455,9 +456,28 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return super()._get_eval_sampler(eval_dataset)
 
     def get_train_dataloader(self) -> DataLoader:
+        """Get dataloader for training"""
         train_dataset = self.train_dataset
        data_collator = self.data_collator
 
+        # Handle dataset preprocessing
+        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+            if (
+                self.args.sample_packing
+                and not self.args.pretraining
+                and "length" in train_dataset.features
+            ):
+                train_dataset = train_dataset.remove_columns(["length"])
+            if not (self.args.sample_packing and not self.args.pretraining):
+                train_dataset = self._remove_unused_columns(
+                    train_dataset, description="training"
+                )
+        else:
+            data_collator = self._get_collator_with_removed_columns(
+                data_collator, description="training"
+            )
+
+        # Build common dataloader parameters
         dataloader_params = {
             "batch_size": self._train_batch_size,
             "collate_fn": data_collator,
@@ -466,50 +486,67 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             "persistent_workers": self.args.dataloader_persistent_workers,
         }
 
-        if self.args.sample_packing and not self.args.pretraining:
-            if "length" in train_dataset.features.keys():
-                train_dataset = train_dataset.remove_columns(["length"])
-            if self.args.dataloader_prefetch_factor:
-                dataloader_params["prefetch_factor"] = (
-                    self.args.dataloader_prefetch_factor
-                )
+        if self.args.dataloader_prefetch_factor:
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
+        # Handle sequence parallelism (takes precedence over other sampling methods)
+        if self.args.sequence_parallel_size > 1:
+            return self._get_sequence_parallel_dataloader(
+                train_dataset, dataloader_params
+            )
+
+        # Handle other sampler cases
+        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
             sampler = self._get_train_sampler()
+
             if isinstance(sampler, BatchSampler):
                 dataloader_params["batch_sampler"] = sampler
-                del dataloader_params["batch_size"]
+                # batch_size and batch_sampler are mutually exclusive
+                if "batch_size" in dataloader_params:
+                    del dataloader_params["batch_size"]
             else:
                 dataloader_params["sampler"] = sampler
                 dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
             dataloader_params["worker_init_fn"] = seed_worker
 
+        # Create and possibly prepare the dataloader
+        dataloader = DataLoader(train_dataset, **dataloader_params)
+
+        # Sample packing with accelerator preparation
+        if self.args.sample_packing and not self.args.pretraining:
             self.accelerator.even_batches = False
-            if self.args.sequence_parallel_size > 1:
-                return DataLoader(train_dataset, **dataloader_params)
+            return self.accelerator.prepare_data_loader(dataloader)
 
-            return self.accelerator.prepare_data_loader(
-                DataLoader(train_dataset, **dataloader_params)
-            )
+        return self.accelerator.prepare(dataloader)
 
-        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
-            train_dataset = self._remove_unused_columns(
-                train_dataset, description="training"
-            )
+    def _get_sequence_parallel_dataloader(self, dataset, dataloader_params):
+        """Create a dataloader with sequence-parallel-aware sampling."""
+        # Calculate SP group information
+        num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
+        global_rank = dist.get_rank()
+        sp_group_id = global_rank // self.args.sequence_parallel_size
+
+        # Create SP-group-aware sampler
+        if isinstance(dataset, torch.utils.data.IterableDataset):
+            sampler = None
         else:
-            data_collator = self._get_collator_with_removed_columns(
-                data_collator, description="training"
+            # Shard across SP groups so each group reads a distinct subset of samples
+            sampler = torch.utils.data.distributed.DistributedSampler(
+                dataset,
+                num_replicas=num_sp_groups,
+                rank=sp_group_id,
+                seed=self.args.seed,
+                drop_last=self.args.dataloader_drop_last,
             )
 
-        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
-            dataloader_params["sampler"] = self._get_train_sampler()
-            dataloader_params["drop_last"] = self.args.dataloader_drop_last
-            dataloader_params["worker_init_fn"] = seed_worker
-            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
-        if self.args.sequence_parallel_size > 1:
-            return DataLoader(train_dataset, **dataloader_params)
-
-        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+        # Create dataloader without accelerator preparation
+        return DataLoader(
+            dataset,
+            sampler=sampler,
+            worker_init_fn=seed_worker,
+            **dataloader_params,
+        )
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         if self.args.sample_packing and self.args.eval_sample_packing is False:
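A note on the sampler arithmetic above: `_get_sequence_parallel_dataloader` shards the dataset across SP groups rather than across individual ranks, so every rank inside one group pulls identical batches and the per-rank work division happens later, in the collator. The following standalone sketch mirrors that mapping; the world size and SP degree are assumed toy values, and it runs without Axolotl or an initialized process group:

# Sketch: map global ranks to SP groups and give each group one
# DistributedSampler shard, as _get_sequence_parallel_dataloader does.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16))
world_size, sequence_parallel_size = 4, 2  # assumed toy values
num_sp_groups = world_size // sequence_parallel_size

for global_rank in range(world_size):
    sp_group_id = global_rank // sequence_parallel_size
    sampler = DistributedSampler(
        dataset,
        num_replicas=num_sp_groups,  # one shard per SP group, not per rank
        rank=sp_group_id,            # all ranks in a group share a shard
        seed=0,
        drop_last=True,
    )
    print(global_rank, sp_group_id, list(sampler))
# ranks 0 and 1 print identical indices; ranks 2 and 3 print the other shard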
@@ -584,6 +621,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return DataLoader(bench_dataset, **dataloader_params)
         # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
 
+    @override
     def compute_loss(
         self, model, inputs, return_outputs=False, num_items_in_batch=None
     ):
@@ -600,6 +638,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             return_outputs=return_outputs,
             num_items_in_batch=num_items_in_batch,
         )
+
         return super().compute_loss(
             model,
             inputs,
@@ -843,16 +882,8 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         batch_size = inputs["input_ids"].shape[0]
         seq_len = inputs["input_ids"].shape[1]
 
-        # Get rank and SP information
-        sp_group = get_ring_attn_group()
-        world_size = (
-            dist.get_world_size(group=sp_group)
-            if sp_group
-            else dist.get_world_size()
-        )
-
         # Calculate the full sequence length across all GPUs in this SP group
-        total_seq_len = seq_len * world_size
+        total_seq_len = seq_len * self.args.sequence_parallel_size
 
         # Pass the partitioned sequence information to ring flash attention
         self._update_ring_flash_attn_params(
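The replaced `world_size` lookup and the new `total_seq_len` line encode the same fact: the collator has already cut each rank's tensors down to 1/`sequence_parallel_size` of the packed sequence, so multiplying the local length by the SP degree recovers the full length without a collective call. A quick arithmetic check, with hypothetical lengths not taken from this patch:

# Check: local slice length times SP degree equals the original packed
# length, matching total_seq_len = seq_len * sequence_parallel_size.
full_seq_len = 4096           # assumed packed length before slicing
sequence_parallel_size = 4    # assumed SP degree

local_seq_len = full_seq_len // sequence_parallel_size
assert local_seq_len * sequence_parallel_size == full_seq_len
print(local_seq_len)          # 1024 tokens held per rank

One caveat: the collator hands the remainder tokens to the last SP rank when the length is not divisible, so on that rank the product slightly overstates the true total; padding sequence lengths to a multiple of the SP degree avoids the discrepancy.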
diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py
index d6e245820..e45ca249f 100644
--- a/src/axolotl/monkeypatch/attention/ring_attn.py
+++ b/src/axolotl/monkeypatch/attention/ring_attn.py
@@ -60,11 +60,8 @@ def register_ring_attn(sequence_parallel_size: int):
         if rank in ring_attn_ranks:
             set_ring_attn_group(group)
-            LOG.info(
-                f"GPU {rank} assigned to sequence parallel group {i} with ranks {ring_attn_ranks}"
-            )
 
-    # Log the full group assignment structure
+    # Log the GPU group assignments
     if rank == 0:
         LOG.info(f"Sequence parallel group assignments: {group_assignments}")
 
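For orientation, `register_ring_attn` carves the world into contiguous blocks of `sequence_parallel_size` ranks; the `group_assignments` structure logged on rank 0 is just that list of rank lists. A sketch of the grouping, reconstructed for illustration only (the real code additionally calls `dist.new_group` per block and stores the group containing the current rank via `set_ring_attn_group`):

# Sketch: contiguous rank blocks become sequence-parallel groups.
def make_sp_group_assignments(world_size: int, sequence_parallel_size: int):
    assert world_size % sequence_parallel_size == 0
    return [
        list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size))
        for i in range(world_size // sequence_parallel_size)
    ]

print(make_sp_group_assignments(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]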
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index ff486db29..b4dcd326a 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -1,5 +1,6 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
 
+import contextlib
 import importlib
 import inspect
 import os
@@ -165,6 +166,18 @@ def setup_signal_handler(
     )
 
 
+def train_context_manager(enable=False) -> contextlib.AbstractContextManager:
+    """Configure CUDA SDP kernel settings if enabled."""
+    if enable:
+        return torch.backends.cuda.sdp_kernel(
+            enable_flash=True,
+            enable_math=True,
+            enable_mem_efficient=True,
+        )
+
+    return contextlib.nullcontext()
+
+
 def execute_training(
     cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
 ):
@@ -177,18 +190,8 @@ def execute_training(
         resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
     """
     LOG.info("Starting trainer...")
-    if cfg.group_by_length:
-        LOG.info("hang tight... sorting dataset for group_by_length")
-
-    if cfg.flash_optimum:
-        with torch.backends.cuda.sdp_kernel(
-            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
-            enable_flash=True,
-            enable_math=True,
-            enable_mem_efficient=True,
-        ):
-            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
-    else:
+    context_manager = train_context_manager(cfg.flash_optimum)
+    with context_manager:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
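The `execute_training` cleanup above hinges on a small pattern: return a real context manager when the feature is on and `contextlib.nullcontext()` when it is off, so the call site keeps a single `with` block instead of duplicated branches. A runnable sketch of the same pattern, with a dummy context manager standing in for `torch.backends.cuda.sdp_kernel` (which needs CUDA):

# Pattern: one call site, two behaviors, selected by the returned
# context manager rather than by duplicated if/else branches.
import contextlib

@contextlib.contextmanager
def fake_sdp_kernel():  # stand-in for torch.backends.cuda.sdp_kernel(...)
    print("SDP kernel settings active")
    yield
    print("SDP kernel settings restored")

def train_context_manager(enable=False) -> contextlib.AbstractContextManager:
    return fake_sdp_kernel() if enable else contextlib.nullcontext()

for flash_optimum in (False, True):
    with train_context_manager(flash_optimum):
        print(f"training step (flash_optimum={flash_optimum})")

Worth noting while touching this code path: newer PyTorch releases deprecate `torch.backends.cuda.sdp_kernel` in favor of `torch.nn.attention.sdpa_kernel`, so isolating it in one helper keeps the eventual migration to a single function.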
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index eaf47c751..885338651 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -113,8 +113,10 @@ class DataCollatorForSeq2Seq:
         if self.sequence_parallel_size > 1:
             # Get information about our position in the SP group
             sp_group = get_ring_attn_group()
-            self.rank = dist.get_rank(group=sp_group)
-            self.world_size = dist.get_world_size(group=sp_group)
+            self.rank = dist.get_rank()
+            self.local_rank = dist.get_rank(group=sp_group)
+            self.world_size = dist.get_world_size()
+            self.local_world_size = dist.get_world_size(group=sp_group)
 
     def __call__(self, features, return_tensors=None):
         labels = None
@@ -202,11 +204,11 @@
         for key in ["input_ids", "attention_mask", "labels"]:
             if key in batch:
                 seq_len = batch[key].shape[1]
-                slice_size = seq_len // self.world_size
-                start_idx = self.rank * slice_size
+                slice_size = seq_len // self.local_world_size
+                start_idx = self.local_rank * slice_size
                 end_idx = (
                     start_idx + slice_size
-                    if self.rank < self.world_size - 1
+                    if self.local_rank < self.local_world_size - 1
                     else seq_len
                 )
 
@@ -214,45 +216,75 @@
                 if key == "input_ids":
                     # Before slicing
                     non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
                     logger.info(
-                        f"GPU {self.rank}: Total sequence length: {seq_len}, "
+                        f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
+                        f"Total sequence length: {seq_len}, "
                         f"Non-padding tokens: {non_pad_tokens_total}"
                     )
-                    logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
                     logger.info(
-                        f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}"
+                        f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
+                        f"GPU {self.rank} token IDs: {batch['input_ids']}"
+                    )
+                    logger.info(
+                        f"GPU {self.rank}, SP Rank {self.local_rank}/{self.local_world_size}: "
+                        f"Slicing {key} from {seq_len} tokens to "
+                        f"indices {start_idx}:{end_idx}"
                     )
 
                 batch[key] = batch[key][:, start_idx:end_idx]
 
-                if key == "input_ids":
-                    # After slicing
-                    non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
-                    logger.info(
-                        f"GPU {self.rank}: Slice {start_idx}-{end_idx}, "
-                        f"Non-padding tokens in slice: {non_pad_tokens_slice}"
-                    )
-                    logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
+                # if key == "input_ids":
+                #     # After slicing
+                #     non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
+                #     logger.info(
+                #         f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
+                #         f"Slice {start_idx}:{end_idx}, "
+                #         f"Non-padding tokens in slice: {non_pad_tokens_slice}"
+                #     )
+                #     logger.info(
+                #         f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
+                #         f"GPU {self.rank} token IDs: {batch['input_ids']}"
+                #     )
 
-                dist.barrier()
+                # if key == "labels":
+                #     min_label = batch["labels"][batch["labels"] != -100].min().item() if (batch["labels"] != -100).any() else -100
+                #     max_label = batch["labels"][batch["labels"] != -100].max().item() if (batch["labels"] != -100).any() else -100
+                #     logger.info(f"GPU {self.rank}: Label range: {min_label} to {max_label}, Vocab size: {self.tokenizer.vocab_size}, labels: {batch['labels']}")
+
+                #     # Find any labels that are outside the valid vocabulary range (but not -100 which is the ignore index)
+                #     invalid_mask = (batch["labels"] >= self.tokenizer.vocab_size) & (batch["labels"] != -100)
+
+                #     if invalid_mask.any():
+                #         # Log this for debugging
+                #         num_invalid = invalid_mask.sum().item()
+                #         logger.warning(f"GPU {self.rank}: Found {num_invalid} invalid labels (>= vocab_size), setting to -100")
+
+                #         # Replace invalid labels with -100 (ignore index)
+                #         batch["labels"][invalid_mask] = -100
 
         # Handle position_ids if present
         if "position_ids" in batch:
             pos_ids = batch["position_ids"]
             seq_len = pos_ids.shape[1]
-            slice_size = seq_len // self.world_size
-            start_idx = self.rank * slice_size
+            slice_size = seq_len // self.local_world_size
+            start_idx = self.local_rank * slice_size
             end_idx = (
-                start_idx + slice_size if self.rank < self.world_size - 1 else seq_len
+                start_idx + slice_size
+                if self.local_rank < self.local_world_size - 1
+                else seq_len
            )
 
             batch["position_ids"] = pos_ids[:, start_idx:end_idx]
 
             # Adjust position_ids to be relative to the slice start
-            if self.rank > 0:
+            if self.local_rank > 0:
                 batch["position_ids"] = adjust_position_ids_for_slice(
                     batch["position_ids"], start_idx
                 )
 
+        # if dist.get_rank() == 0:
+        #     import ipdb; ipdb.set_trace()
+        # dist.barrier()
+
         return batch
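To make the collator's slicing contract concrete: every SP rank keeps a contiguous 1/`local_world_size` slice of each packed tensor, the last rank absorbs any remainder, and non-leading ranks re-base `position_ids` to their slice origin. The sketch below reproduces that arithmetic end to end with toy sizes; a plain subtraction stands in for `adjust_position_ids_for_slice`, which in the real code also has to respect sample-packing boundaries:

# Sketch: slice a packed batch across SP ranks and re-base position_ids,
# mirroring the collator logic in this file.
import torch

def slice_for_sp_rank(batch, local_rank, local_world_size):
    sliced = {}
    for key, tensor in batch.items():
        seq_len = tensor.shape[1]
        slice_size = seq_len // local_world_size
        start_idx = local_rank * slice_size
        # The last SP rank keeps the remainder tokens
        end_idx = (
            start_idx + slice_size
            if local_rank < local_world_size - 1
            else seq_len
        )
        sliced[key] = tensor[:, start_idx:end_idx]
        if key == "position_ids" and local_rank > 0:
            # Stand-in for adjust_position_ids_for_slice
            sliced[key] = sliced[key] - start_idx
    return sliced

batch = {
    "input_ids": torch.arange(10).unsqueeze(0),
    "position_ids": torch.arange(10).unsqueeze(0),
}
for rank in range(2):
    print(rank, slice_for_sp_rank(batch, rank, local_world_size=2))
# rank 0 holds tokens 0..4; rank 1 holds tokens 5..9 with positions re-based to 0..4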