Working multi-group sequence parallelism (SP)

This commit is contained in:
Dan Saunders
2025-03-12 19:33:40 +00:00
parent 698e599bf7
commit b7738d57c4
4 changed files with 143 additions and 80 deletions

View File

@@ -1,8 +1,8 @@
"""Module for customized trainers."""
"""Module for customized trainers"""
# pylint: disable=too-many-lines
from __future__ import annotations
# pylint: disable=too-many-lines
import logging
import os
from collections import defaultdict
@@ -361,9 +361,7 @@ class OptimizerMixin(Trainer):
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
@@ -380,7 +378,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
@@ -425,6 +425,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset)
@@ -455,9 +456,28 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator
# Handle dataset preprocessing
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
if (
self.args.sample_packing
and not self.args.pretraining
and "length" in train_dataset.features
):
train_dataset = train_dataset.remove_columns(["length"])
if not (self.args.sample_packing and not self.args.pretraining):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
# Build common dataloader parameters
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
@@ -466,50 +486,67 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"persistent_workers": self.args.dataloader_persistent_workers,
}
if self.args.sample_packing and not self.args.pretraining:
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
# Handle sequence parallelism (takes precedence over other sampling methods)
if self.args.sequence_parallel_size > 1:
return self._get_sequence_parallel_dataloader(
train_dataset, dataloader_params
)
# Handle other sampler cases
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
# batch_size and batch_sampler are mutually exclusive
if "batch_size" in dataloader_params:
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
# Create and possibly prepare the dataloader
dataloader = DataLoader(train_dataset, **dataloader_params)
# Sample packing with accelerator preparation
if self.args.sample_packing and not self.args.pretraining:
self.accelerator.even_batches = False
if self.args.sequence_parallel_size > 1:
return DataLoader(train_dataset, **dataloader_params)
return self.accelerator.prepare_data_loader(dataloader)
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return self.accelerator.prepare(dataloader)
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
def _get_sequence_parallel_dataloader(self, dataset, dataloader_params):
"""Create a dataloader with sequence-parallel-aware sampling."""
# Calculate SP group information
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
global_rank = dist.get_rank()
sp_group_id = global_rank // self.args.sequence_parallel_size
# Create SP-group-aware sampler
if isinstance(dataset, torch.utils.data.IterableDataset):
sampler = None
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
# Use different seeds for different SP groups to ensure they get different samples
sampler = torch.utils.data.distributed.DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed,
drop_last=self.args.dataloader_drop_last,
)
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if self.args.sequence_parallel_size > 1:
return DataLoader(train_dataset, **dataloader_params)
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
# Create dataloader without accelerator preparation
return DataLoader(
dataset,
sampler=sampler,
worker_init_fn=seed_worker,
**dataloader_params,
)
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is False:
@@ -584,6 +621,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
@override
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
@@ -600,6 +638,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
@@ -843,16 +882,8 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
# Get rank and SP information
sp_group = get_ring_attn_group()
world_size = (
dist.get_world_size(group=sp_group)
if sp_group
else dist.get_world_size()
)
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * world_size
total_seq_len = seq_len * self.args.sequence_parallel_size
# Pass the partitioned sequence information to ring flash attention
self._update_ring_flash_attn_params(

View File

@@ -60,11 +60,8 @@ def register_ring_attn(sequence_parallel_size: int):
if rank in ring_attn_ranks:
set_ring_attn_group(group)
LOG.info(
f"GPU {rank} assigned to sequence parallel group {i} with ranks {ring_attn_ranks}"
)
# Log the full group assignment structure
# Log the GPU group assignments
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")

View File

@@ -1,5 +1,6 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import contextlib
import importlib
import inspect
import os
@@ -165,6 +166,18 @@ def setup_signal_handler(
)
def train_context_manager(enable=False) -> contextlib.AbstractContextManager:
    """Return the context manager to wrap a training run in.

    When ``enable`` is truthy, returns the CUDA scaled-dot-product
    attention kernel configuration context with flash, math and
    memory-efficient kernels all allowed; otherwise returns a no-op
    ``contextlib.nullcontext``.
    """
    # Guard clause: nothing to configure when the feature is off.
    if not enable:
        return contextlib.nullcontext()
    return torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=True,
        enable_mem_efficient=True,
    )
def execute_training(
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
):
@@ -177,18 +190,8 @@ def execute_training(
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
context_manager = train_context_manager(cfg.flash_optimum)
with context_manager:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -113,8 +113,10 @@ class DataCollatorForSeq2Seq:
if self.sequence_parallel_size > 1:
# Get information about our position in the SP group
sp_group = get_ring_attn_group()
self.rank = dist.get_rank(group=sp_group)
self.world_size = dist.get_world_size(group=sp_group)
self.rank = dist.get_rank()
self.local_rank = dist.get_rank(group=sp_group)
self.world_size = dist.get_world_size()
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
labels = None
@@ -202,11 +204,11 @@ class DataCollatorForSeq2Seq:
for key in ["input_ids", "attention_mask", "labels"]:
if key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // self.world_size
start_idx = self.rank * slice_size
slice_size = seq_len // self.local_world_size
start_idx = self.local_rank * slice_size
end_idx = (
start_idx + slice_size
if self.rank < self.world_size - 1
if self.local_rank < self.local_world_size - 1
else seq_len
)
@@ -214,45 +216,75 @@ class DataCollatorForSeq2Seq:
# Before slicing
non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
logger.info(
f"GPU {self.rank}: Total sequence length: {seq_len}, "
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
f"Total sequence length: {seq_len}, "
f"Non-padding tokens: {non_pad_tokens_total}"
)
logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
logger.info(
f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}"
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
f"GPU {self.rank} token IDs: {batch['input_ids']}"
)
logger.info(
f"GPU {self.rank}, SP Rank {self.local_rank}/{self.local_world_size}: "
f"Slicing {key} from {seq_len} tokens to "
f"indices {start_idx}:{end_idx}"
)
batch[key] = batch[key][:, start_idx:end_idx]
if key == "input_ids":
# After slicing
non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
logger.info(
f"GPU {self.rank}: Slice {start_idx}-{end_idx}, "
f"Non-padding tokens in slice: {non_pad_tokens_slice}"
)
logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
# if key == "input_ids":
# # After slicing
# non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
# logger.info(
# f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
# f"Slice {start_idx}:{end_idx}, "
# f"Non-padding tokens in slice: {non_pad_tokens_slice}"
# )
# logger.info(
# f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
# f"GPU {self.rank} token IDs: {batch['input_ids']}"
# )
dist.barrier()
# if key == "labels":
# min_label = batch["labels"][batch["labels"] != -100].min().item() if (batch["labels"] != -100).any() else -100
# max_label = batch["labels"][batch["labels"] != -100].max().item() if (batch["labels"] != -100).any() else -100
# logger.info(f"GPU {self.rank}: Label range: {min_label} to {max_label}, Vocab size: {self.tokenizer.vocab_size}, labels: {batch['labels']}")
# # Find any labels that are outside the valid vocabulary range (but not -100 which is the ignore index)
# invalid_mask = (batch["labels"] >= self.tokenizer.vocab_size) & (batch["labels"] != -100)
# if invalid_mask.any():
# # Log this for debugging
# num_invalid = invalid_mask.sum().item()
# logger.warning(f"GPU {self.rank}: Found {num_invalid} invalid labels (>= vocab_size), setting to -100")
# # Replace invalid labels with -100 (ignore index)
# batch["labels"][invalid_mask] = -100
# Handle position_ids if present
if "position_ids" in batch:
pos_ids = batch["position_ids"]
seq_len = pos_ids.shape[1]
slice_size = seq_len // self.world_size
start_idx = self.rank * slice_size
slice_size = seq_len // self.local_world_size
start_idx = self.local_rank * slice_size
end_idx = (
start_idx + slice_size if self.rank < self.world_size - 1 else seq_len
start_idx + slice_size
if self.local_rank < self.local_world_size - 1
else seq_len
)
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
# Adjust position_ids to be relative to the slice start
if self.rank > 0:
if self.local_rank > 0:
batch["position_ids"] = adjust_position_ids_for_slice(
batch["position_ids"], start_idx
)
# if dist.get_rank() == 0:
# import ipdb; ipdb.set_trace()
# dist.barrier()
return batch