working multi-group SP
@@ -1,8 +1,8 @@
-"""Module for customized trainers."""
+"""Module for customized trainers"""
+# pylint: disable=too-many-lines
 
 from __future__ import annotations
 
-# pylint: disable=too-many-lines
 import logging
 import os
 from collections import defaultdict
@@ -361,9 +361,7 @@ class OptimizerMixin(Trainer):
 
 
 class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
-    """
-    Extend the base Trainer for axolotl helpers
-    """
+    """Extend the base Trainer for axolotl helpers"""
 
     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
     tag_names = ["axolotl"]
@@ -380,7 +378,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         self.eval_data_collator = eval_data_collator
         self.dataset_tags = dataset_tags
         self._signature_columns = None  # workaround for pylint
+
         super().__init__(*_args, **kwargs)
+
         self.train_data_collator = self.data_collator
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
         if self.args.orpo_alpha:
@@ -425,6 +425,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                 bin_size=self.args.sample_packing_bin_size,
                 drop_last=True,
             )
+
         if self.args.curriculum_sampling:
             return SequentialSampler(self.train_dataset)
 
@@ -455,9 +456,28 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return super()._get_eval_sampler(eval_dataset)
 
     def get_train_dataloader(self) -> DataLoader:
+        """Get dataloader for training"""
         train_dataset = self.train_dataset
         data_collator = self.data_collator
 
+        # Handle dataset preprocessing
+        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+            if (
+                self.args.sample_packing
+                and not self.args.pretraining
+                and "length" in train_dataset.features
+            ):
+                train_dataset = train_dataset.remove_columns(["length"])
+            if not (self.args.sample_packing and not self.args.pretraining):
+                train_dataset = self._remove_unused_columns(
+                    train_dataset, description="training"
+                )
+        else:
+            data_collator = self._get_collator_with_removed_columns(
+                data_collator, description="training"
+            )
+
+        # Build common dataloader parameters
         dataloader_params = {
             "batch_size": self._train_batch_size,
             "collate_fn": data_collator,
@@ -466,50 +486,67 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             "persistent_workers": self.args.dataloader_persistent_workers,
         }
 
-        if self.args.sample_packing and not self.args.pretraining:
-            if "length" in train_dataset.features.keys():
-                train_dataset = train_dataset.remove_columns(["length"])
-            if self.args.dataloader_prefetch_factor:
-                dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+        if self.args.dataloader_prefetch_factor:
+            dataloader_params["prefetch_factor"] = (
+                self.args.dataloader_prefetch_factor
+            )
+
+        # Handle sequence parallelism (takes precedence over other sampling methods)
+        if self.args.sequence_parallel_size > 1:
+            return self._get_sequence_parallel_dataloader(
+                train_dataset, dataloader_params
+            )
+
+        # Handle other sampler cases
+        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
             sampler = self._get_train_sampler()
+
             if isinstance(sampler, BatchSampler):
                 dataloader_params["batch_sampler"] = sampler
-                del dataloader_params["batch_size"]
+                # batch_size and batch_sampler are mutually exclusive
+                if "batch_size" in dataloader_params:
+                    del dataloader_params["batch_size"]
             else:
                 dataloader_params["sampler"] = sampler
                 dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
             dataloader_params["worker_init_fn"] = seed_worker
+
+        # Create and possibly prepare the dataloader
+        dataloader = DataLoader(train_dataset, **dataloader_params)
+
+        # Sample packing with accelerator preparation
+        if self.args.sample_packing and not self.args.pretraining:
             self.accelerator.even_batches = False
-            if self.args.sequence_parallel_size > 1:
-                return DataLoader(train_dataset, **dataloader_params)
-            return self.accelerator.prepare_data_loader(
-                DataLoader(train_dataset, **dataloader_params)
-            )
+            return self.accelerator.prepare_data_loader(dataloader)
 
-        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
-            train_dataset = self._remove_unused_columns(
-                train_dataset, description="training"
-            )
-        else:
-            data_collator = self._get_collator_with_removed_columns(
-                data_collator, description="training"
-            )
+        return self.accelerator.prepare(dataloader)
 
-        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
-            dataloader_params["sampler"] = self._get_train_sampler()
-            dataloader_params["drop_last"] = self.args.dataloader_drop_last
-            dataloader_params["worker_init_fn"] = seed_worker
-            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
-        if self.args.sequence_parallel_size > 1:
-            return DataLoader(train_dataset, **dataloader_params)
-
-        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+    def _get_sequence_parallel_dataloader(self, dataset, dataloader_params):
+        """Create a dataloader with sequence-parallel-aware sampling."""
+        # Calculate SP group information
+        num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
+        global_rank = dist.get_rank()
+        sp_group_id = global_rank // self.args.sequence_parallel_size
+
+        # Create SP-group-aware sampler
+        if isinstance(dataset, torch.utils.data.IterableDataset):
+            sampler = None
+        else:
+            # Use different seeds for different SP groups to ensure they get different samples
+            sampler = torch.utils.data.distributed.DistributedSampler(
+                dataset,
+                num_replicas=num_sp_groups,
+                rank=sp_group_id,
+                seed=self.args.seed,
+                drop_last=self.args.dataloader_drop_last,
+            )
+
+        # Create dataloader without accelerator preparation
+        return DataLoader(
+            dataset,
+            sampler=sampler,
+            worker_init_fn=seed_worker,
+            **dataloader_params,
+        )
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         if self.args.sample_packing and self.args.eval_sample_packing is False:
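
Why num_replicas=num_sp_groups rather than world_size: ranks are grouped contiguously, so with world_size = 8 and sequence_parallel_size = 4 there are num_sp_groups = 2 groups (ranks 0-3 and 4-7). Sharding by group means every rank inside a group draws the same sample indices (each keeps only its slice of the sequence later, in the collator), while different groups get disjoint shards. A minimal self-contained sketch of that behavior, with an invented 16-sample dataset and no axolotl imports:

    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    world_size, sp_size = 8, 4
    num_sp_groups = world_size // sp_size  # 2 SP groups
    dataset = TensorDataset(torch.arange(16))

    for rank in range(world_size):
        sp_group_id = rank // sp_size  # ranks 0-3 -> group 0, ranks 4-7 -> group 1
        sampler = DistributedSampler(
            dataset, num_replicas=num_sp_groups, rank=sp_group_id, seed=0
        )
        # Ranks in the same group print identical index lists;
        # the two groups' lists are disjoint halves of the dataset.
        print(rank, list(sampler))

One nit on the added comment: every group is constructed with the same seed=self.args.seed, so it is the rank=sp_group_id argument, not the seed, that makes the groups' samples differ.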
@@ -584,6 +621,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return DataLoader(bench_dataset, **dataloader_params)
+        # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
 
     @override
     def compute_loss(
         self, model, inputs, return_outputs=False, num_items_in_batch=None
     ):
@@ -600,6 +638,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                 return_outputs=return_outputs,
                 num_items_in_batch=num_items_in_batch,
             )
+
         return super().compute_loss(
             model,
             inputs,
@@ -843,16 +882,8 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         batch_size = inputs["input_ids"].shape[0]
         seq_len = inputs["input_ids"].shape[1]
 
-        # Get rank and SP information
-        sp_group = get_ring_attn_group()
-        world_size = (
-            dist.get_world_size(group=sp_group)
-            if sp_group
-            else dist.get_world_size()
-        )
-
         # Calculate the full sequence length across all GPUs in this SP group
-        total_seq_len = seq_len * world_size
+        total_seq_len = seq_len * self.args.sequence_parallel_size
 
         # Pass the partitioned sequence information to ring flash attention
         self._update_ring_flash_attn_params(
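
The replacement computes the same quantity more directly: each rank in an SP group holds a seq_len-token shard of one logical sequence, so the full length is the shard length times the configured group size. A trivial worked example (invented values):

    seq_len = 2048                   # per-rank shard length
    sequence_parallel_size = 4       # ranks per SP group
    total_seq_len = seq_len * sequence_parallel_size  # 8192 tokens end-to-end

Using self.args.sequence_parallel_size also drops the dependency on get_ring_attn_group() here; the deleted fallback dist.get_world_size() (no group argument) returns the global world size, which overcounts the sequence whenever more than one SP group exists — presumably the multi-group bug this commit is working around.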
@@ -60,11 +60,8 @@ def register_ring_attn(sequence_parallel_size: int):
 
         if rank in ring_attn_ranks:
             set_ring_attn_group(group)
-            LOG.info(
-                f"GPU {rank} assigned to sequence parallel group {i} with ranks {ring_attn_ranks}"
-            )
 
     # Log the full group assignment structure
     # Log the GPU group assignments
     if rank == 0:
         LOG.info(f"Sequence parallel group assignments: {group_assignments}")
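
The surrounding loop is not shown in the hunk, but the log lines imply contiguous rank grouping. A hedged sketch of that structure (the names ring_attn_ranks and group_assignments mirror the f-strings above; this is illustrative, not the verbatim axolotl code):

    import torch.distributed as dist

    def sketch_register_groups(world_size: int, sequence_parallel_size: int):
        group_assignments = {}
        for i in range(world_size // sequence_parallel_size):
            ring_attn_ranks = list(
                range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
            )
            # new_group is collective: every rank must call it, in the same order
            group = dist.new_group(ranks=ring_attn_ranks)
            for r in ring_attn_ranks:
                group_assignments[r] = i
        return group_assignments

Dropping the per-rank LOG.info in favor of the rank-0 summary avoids world_size interleaved log lines at startup while preserving the same information in the group_assignments dump.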
@@ -1,5 +1,6 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
 
+import contextlib
 import importlib
 import inspect
 import os
@@ -165,6 +166,18 @@ def setup_signal_handler(
     )
 
 
+def train_context_manager(enable=False) -> contextlib.AbstractContextManager:
+    """Configure CUDA SDP kernel settings if enabled."""
+    if enable:
+        return torch.backends.cuda.sdp_kernel(
+            enable_flash=True,
+            enable_math=True,
+            enable_mem_efficient=True,
+        )
+
+    return contextlib.nullcontext()
+
+
 def execute_training(
     cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
 ):
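
Because the helper falls back to contextlib.nullcontext(), call sites no longer need an if/else around the with statement; the next hunk collapses execute_training to exactly this shape:

    context_manager = train_context_manager(cfg.flash_optimum)
    with context_manager:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

A hedged side note: torch.backends.cuda.sdp_kernel is deprecated in recent PyTorch releases in favor of torch.nn.attention.sdpa_kernel, so this wrapper is also a convenient single seam for that migration later.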
@@ -177,18 +190,8 @@ def execute_training(
         resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
     """
     LOG.info("Starting trainer...")
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")
 
-    if cfg.flash_optimum:
-        with torch.backends.cuda.sdp_kernel(
-            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
-            enable_flash=True,
-            enable_math=True,
-            enable_mem_efficient=True,
-        ):
-            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
-    else:
-        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    context_manager = train_context_manager(cfg.flash_optimum)
+    with context_manager:
+        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
@@ -113,8 +113,10 @@ class DataCollatorForSeq2Seq:
         if self.sequence_parallel_size > 1:
             # Get information about our position in the SP group
             sp_group = get_ring_attn_group()
-            self.rank = dist.get_rank(group=sp_group)
-            self.world_size = dist.get_world_size(group=sp_group)
+            self.rank = dist.get_rank()
+            self.local_rank = dist.get_rank(group=sp_group)
+            self.world_size = dist.get_world_size()
+            self.local_world_size = dist.get_world_size(group=sp_group)
 
     def __call__(self, features, return_tensors=None):
         labels = None
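
The rename separates two coordinate systems: rank/world_size are now global (used for log prefixes like "GPU 3/8"), while local_rank/local_world_size are positions within the SP process group and drive the slicing below. Given the contiguous grouping used at registration, the two are related as in this small sketch (illustrative values, not code from the commit):

    sequence_parallel_size = 4
    rank = 5                                      # example global rank
    local_rank = rank % sequence_parallel_size    # 1: second rank in its group
    sp_group_id = rank // sequence_parallel_size  # 1: second SP group

The previous code reused self.rank/self.world_size for the group-local values, which made the log output ambiguous in multi-group runs.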
@@ -202,11 +204,11 @@ class DataCollatorForSeq2Seq:
             for key in ["input_ids", "attention_mask", "labels"]:
                 if key in batch:
                     seq_len = batch[key].shape[1]
-                    slice_size = seq_len // self.world_size
-                    start_idx = self.rank * slice_size
+                    slice_size = seq_len // self.local_world_size
+                    start_idx = self.local_rank * slice_size
                     end_idx = (
                         start_idx + slice_size
-                        if self.rank < self.world_size - 1
+                        if self.local_rank < self.local_world_size - 1
                         else seq_len
                     )
 
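
The floor-division slicing gives every SP rank seq_len // local_world_size tokens, with the last rank absorbing the remainder so no tokens are dropped when the length doesn't divide evenly. A quick worked example (invented values):

    def slice_bounds(seq_len: int, local_rank: int, local_world_size: int):
        slice_size = seq_len // local_world_size
        start_idx = local_rank * slice_size
        end_idx = (
            start_idx + slice_size
            if local_rank < local_world_size - 1
            else seq_len  # last rank takes the remainder
        )
        return start_idx, end_idx

    # seq_len=10 across 4 SP ranks -> [(0, 2), (2, 4), (4, 6), (6, 10)]
    print([slice_bounds(10, r, 4) for r in range(4)])

Worth noting: the last rank can end up with as many as local_world_size - 1 extra tokens, so the shards are only balanced when the collator pads sequences to a multiple of the SP group size.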
@@ -214,45 +216,75 @@ class DataCollatorForSeq2Seq:
                     # Before slicing
                     non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
                     logger.info(
-                        f"GPU {self.rank}: Total sequence length: {seq_len}, "
+                        f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
+                        f"Total sequence length: {seq_len}, "
                         f"Non-padding tokens: {non_pad_tokens_total}"
                     )
-                    logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
                     logger.info(
-                        f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}"
+                        f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
+                        f"GPU {self.rank} token IDs: {batch['input_ids']}"
+                    )
+                    logger.info(
+                        f"GPU {self.rank}, SP Rank {self.local_rank}/{self.local_world_size}: "
+                        f"Slicing {key} from {seq_len} tokens to "
+                        f"indices {start_idx}:{end_idx}"
                     )
 
                     batch[key] = batch[key][:, start_idx:end_idx]
 
-                    if key == "input_ids":
-                        # After slicing
-                        non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
-                        logger.info(
-                            f"GPU {self.rank}: Slice {start_idx}-{end_idx}, "
-                            f"Non-padding tokens in slice: {non_pad_tokens_slice}"
-                        )
-                        logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
+                    # if key == "input_ids":
+                    #     # After slicing
+                    #     non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
+                    #     logger.info(
+                    #         f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
+                    #         f"Slice {start_idx}:{end_idx}, "
+                    #         f"Non-padding tokens in slice: {non_pad_tokens_slice}"
+                    #     )
+                    #     logger.info(
+                    #         f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
+                    #         f"GPU {self.rank} token IDs: {batch['input_ids']}"
+                    #     )
 
             dist.barrier()
             # if key == "labels":
             #     min_label = batch["labels"][batch["labels"] != -100].min().item() if (batch["labels"] != -100).any() else -100
             #     max_label = batch["labels"][batch["labels"] != -100].max().item() if (batch["labels"] != -100).any() else -100
             #     logger.info(f"GPU {self.rank}: Label range: {min_label} to {max_label}, Vocab size: {self.tokenizer.vocab_size}, labels: {batch['labels']}")
 
             #     # Find any labels that are outside the valid vocabulary range (but not -100 which is the ignore index)
             #     invalid_mask = (batch["labels"] >= self.tokenizer.vocab_size) & (batch["labels"] != -100)
 
             #     if invalid_mask.any():
             #         # Log this for debugging
             #         num_invalid = invalid_mask.sum().item()
             #         logger.warning(f"GPU {self.rank}: Found {num_invalid} invalid labels (>= vocab_size), setting to -100")
 
             #         # Replace invalid labels with -100 (ignore index)
             #         batch["labels"][invalid_mask] = -100
 
             # Handle position_ids if present
             if "position_ids" in batch:
                 pos_ids = batch["position_ids"]
                 seq_len = pos_ids.shape[1]
-                slice_size = seq_len // self.world_size
-                start_idx = self.rank * slice_size
+                slice_size = seq_len // self.local_world_size
+                start_idx = self.local_rank * slice_size
                 end_idx = (
-                    start_idx + slice_size if self.rank < self.world_size - 1 else seq_len
+                    start_idx + slice_size
+                    if self.local_rank < self.local_world_size - 1
+                    else seq_len
                 )
 
                 batch["position_ids"] = pos_ids[:, start_idx:end_idx]
 
                 # Adjust position_ids to be relative to the slice start
-                if self.rank > 0:
+                if self.local_rank > 0:
                     batch["position_ids"] = adjust_position_ids_for_slice(
                         batch["position_ids"], start_idx
                     )
 
             # if dist.get_rank() == 0:
             #     import ipdb; ipdb.set_trace()
             # dist.barrier()
 
         return batch