working multi-group SP
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
"""Module for customized trainers."""
|
"""Module for customized trainers"""
|
||||||
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@@ -361,9 +361,7 @@ class OptimizerMixin(Trainer):
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||||
"""
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
Extend the base Trainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
tag_names = ["axolotl"]
|
||||||
@@ -380,7 +378,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
self.eval_data_collator = eval_data_collator
|
self.eval_data_collator = eval_data_collator
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self._signature_columns = None # workaround for pylint
|
self._signature_columns = None # workaround for pylint
|
||||||
|
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
|
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
@@ -425,6 +425,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
bin_size=self.args.sample_packing_bin_size,
|
bin_size=self.args.sample_packing_bin_size,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.curriculum_sampling:
|
if self.args.curriculum_sampling:
|
||||||
return SequentialSampler(self.train_dataset)
|
return SequentialSampler(self.train_dataset)
|
||||||
|
|
||||||
@@ -455,9 +456,28 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return super()._get_eval_sampler(eval_dataset)
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
|
"""Get dataloader for training"""
|
||||||
train_dataset = self.train_dataset
|
train_dataset = self.train_dataset
|
||||||
data_collator = self.data_collator
|
data_collator = self.data_collator
|
||||||
|
|
||||||
|
# Handle dataset preprocessing
|
||||||
|
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||||
|
if (
|
||||||
|
self.args.sample_packing
|
||||||
|
and not self.args.pretraining
|
||||||
|
and "length" in train_dataset.features
|
||||||
|
):
|
||||||
|
train_dataset = train_dataset.remove_columns(["length"])
|
||||||
|
if not (self.args.sample_packing and not self.args.pretraining):
|
||||||
|
train_dataset = self._remove_unused_columns(
|
||||||
|
train_dataset, description="training"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data_collator = self._get_collator_with_removed_columns(
|
||||||
|
data_collator, description="training"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build common dataloader parameters
|
||||||
dataloader_params = {
|
dataloader_params = {
|
||||||
"batch_size": self._train_batch_size,
|
"batch_size": self._train_batch_size,
|
||||||
"collate_fn": data_collator,
|
"collate_fn": data_collator,
|
||||||
@@ -466,50 +486,67 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
"persistent_workers": self.args.dataloader_persistent_workers,
|
"persistent_workers": self.args.dataloader_persistent_workers,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if self.args.dataloader_prefetch_factor:
|
||||||
if "length" in train_dataset.features.keys():
|
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||||
train_dataset = train_dataset.remove_columns(["length"])
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
|
||||||
dataloader_params["prefetch_factor"] = (
|
|
||||||
self.args.dataloader_prefetch_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Handle sequence parallelism (takes precedence over other sampling methods)
|
||||||
|
if self.args.sequence_parallel_size > 1:
|
||||||
|
return self._get_sequence_parallel_dataloader(
|
||||||
|
train_dataset, dataloader_params
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle other sampler cases
|
||||||
|
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||||
sampler = self._get_train_sampler()
|
sampler = self._get_train_sampler()
|
||||||
|
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
dataloader_params["batch_sampler"] = sampler
|
dataloader_params["batch_sampler"] = sampler
|
||||||
del dataloader_params["batch_size"]
|
# batch_size and batch_sampler are mutually exclusive
|
||||||
|
if "batch_size" in dataloader_params:
|
||||||
|
del dataloader_params["batch_size"]
|
||||||
else:
|
else:
|
||||||
dataloader_params["sampler"] = sampler
|
dataloader_params["sampler"] = sampler
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
|
# Create and possibly prepare the dataloader
|
||||||
|
dataloader = DataLoader(train_dataset, **dataloader_params)
|
||||||
|
|
||||||
|
# Sample packing with accelerator preparation
|
||||||
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
if self.args.sequence_parallel_size > 1:
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
return DataLoader(train_dataset, **dataloader_params)
|
|
||||||
|
|
||||||
return self.accelerator.prepare_data_loader(
|
return self.accelerator.prepare(dataloader)
|
||||||
DataLoader(train_dataset, **dataloader_params)
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
def _get_sequence_parallel_dataloader(self, dataset, dataloader_params):
|
||||||
train_dataset = self._remove_unused_columns(
|
"""Create a dataloader with sequence-parallel-aware sampling."""
|
||||||
train_dataset, description="training"
|
# Calculate SP group information
|
||||||
)
|
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
|
||||||
|
global_rank = dist.get_rank()
|
||||||
|
sp_group_id = global_rank // self.args.sequence_parallel_size
|
||||||
|
|
||||||
|
# Create SP-group-aware sampler
|
||||||
|
if isinstance(dataset, torch.utils.data.IterableDataset):
|
||||||
|
sampler = None
|
||||||
else:
|
else:
|
||||||
data_collator = self._get_collator_with_removed_columns(
|
# All SP groups share the same seed; each group gets a different shard of the data via rank=sp_group_id
|
||||||
data_collator, description="training"
|
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||||
|
dataset,
|
||||||
|
num_replicas=num_sp_groups,
|
||||||
|
rank=sp_group_id,
|
||||||
|
seed=self.args.seed,
|
||||||
|
drop_last=self.args.dataloader_drop_last,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
# Create dataloader without accelerator preparation
|
||||||
dataloader_params["sampler"] = self._get_train_sampler()
|
return DataLoader(
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataset,
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
sampler=sampler,
|
||||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
worker_init_fn=seed_worker,
|
||||||
|
**dataloader_params,
|
||||||
if self.args.sequence_parallel_size > 1:
|
)
|
||||||
return DataLoader(train_dataset, **dataloader_params)
|
|
||||||
|
|
||||||
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||||
@@ -584,6 +621,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return DataLoader(bench_dataset, **dataloader_params)
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||||
|
|
||||||
|
@override
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||||
):
|
):
|
||||||
@@ -600,6 +638,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return_outputs=return_outputs,
|
return_outputs=return_outputs,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().compute_loss(
|
return super().compute_loss(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -843,16 +882,8 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
batch_size = inputs["input_ids"].shape[0]
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
seq_len = inputs["input_ids"].shape[1]
|
seq_len = inputs["input_ids"].shape[1]
|
||||||
|
|
||||||
# Get rank and SP information
|
|
||||||
sp_group = get_ring_attn_group()
|
|
||||||
world_size = (
|
|
||||||
dist.get_world_size(group=sp_group)
|
|
||||||
if sp_group
|
|
||||||
else dist.get_world_size()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate the full sequence length across all GPUs in this SP group
|
# Calculate the full sequence length across all GPUs in this SP group
|
||||||
total_seq_len = seq_len * world_size
|
total_seq_len = seq_len * self.args.sequence_parallel_size
|
||||||
|
|
||||||
# Pass the partitioned sequence information to ring flash attention
|
# Pass the partitioned sequence information to ring flash attention
|
||||||
self._update_ring_flash_attn_params(
|
self._update_ring_flash_attn_params(
|
||||||
|
|||||||
@@ -60,11 +60,8 @@ def register_ring_attn(sequence_parallel_size: int):
|
|||||||
|
|
||||||
if rank in ring_attn_ranks:
|
if rank in ring_attn_ranks:
|
||||||
set_ring_attn_group(group)
|
set_ring_attn_group(group)
|
||||||
LOG.info(
|
|
||||||
f"GPU {rank} assigned to sequence parallel group {i} with ranks {ring_attn_ranks}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log the full group assignment structure
|
# Log the GPU group assignments
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
@@ -165,6 +166,18 @@ def setup_signal_handler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def train_context_manager(enable=False) -> contextlib.AbstractContextManager:
|
||||||
|
"""Configure CUDA SDP kernel settings if enabled."""
|
||||||
|
if enable:
|
||||||
|
return torch.backends.cuda.sdp_kernel(
|
||||||
|
enable_flash=True,
|
||||||
|
enable_math=True,
|
||||||
|
enable_mem_efficient=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
|
||||||
def execute_training(
|
def execute_training(
|
||||||
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||||
):
|
):
|
||||||
@@ -177,18 +190,8 @@ def execute_training(
|
|||||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||||
"""
|
"""
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
if cfg.group_by_length:
|
context_manager = train_context_manager(cfg.flash_optimum)
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
with context_manager:
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
|
||||||
with torch.backends.cuda.sdp_kernel(
|
|
||||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
|
||||||
enable_flash=True,
|
|
||||||
enable_math=True,
|
|
||||||
enable_mem_efficient=True,
|
|
||||||
):
|
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
|
||||||
else:
|
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -113,8 +113,10 @@ class DataCollatorForSeq2Seq:
|
|||||||
if self.sequence_parallel_size > 1:
|
if self.sequence_parallel_size > 1:
|
||||||
# Get information about our position in the SP group
|
# Get information about our position in the SP group
|
||||||
sp_group = get_ring_attn_group()
|
sp_group = get_ring_attn_group()
|
||||||
self.rank = dist.get_rank(group=sp_group)
|
self.rank = dist.get_rank()
|
||||||
self.world_size = dist.get_world_size(group=sp_group)
|
self.local_rank = dist.get_rank(group=sp_group)
|
||||||
|
self.world_size = dist.get_world_size()
|
||||||
|
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
labels = None
|
labels = None
|
||||||
@@ -202,11 +204,11 @@ class DataCollatorForSeq2Seq:
|
|||||||
for key in ["input_ids", "attention_mask", "labels"]:
|
for key in ["input_ids", "attention_mask", "labels"]:
|
||||||
if key in batch:
|
if key in batch:
|
||||||
seq_len = batch[key].shape[1]
|
seq_len = batch[key].shape[1]
|
||||||
slice_size = seq_len // self.world_size
|
slice_size = seq_len // self.local_world_size
|
||||||
start_idx = self.rank * slice_size
|
start_idx = self.local_rank * slice_size
|
||||||
end_idx = (
|
end_idx = (
|
||||||
start_idx + slice_size
|
start_idx + slice_size
|
||||||
if self.rank < self.world_size - 1
|
if self.local_rank < self.local_world_size - 1
|
||||||
else seq_len
|
else seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -214,45 +216,75 @@ class DataCollatorForSeq2Seq:
|
|||||||
# Before slicing
|
# Before slicing
|
||||||
non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
|
non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"GPU {self.rank}: Total sequence length: {seq_len}, "
|
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
||||||
|
f"Total sequence length: {seq_len}, "
|
||||||
f"Non-padding tokens: {non_pad_tokens_total}"
|
f"Non-padding tokens: {non_pad_tokens_total}"
|
||||||
)
|
)
|
||||||
logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}"
|
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
||||||
|
f"GPU {self.rank} token IDs: {batch['input_ids']}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"GPU {self.rank}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
||||||
|
f"Slicing {key} from {seq_len} tokens to "
|
||||||
|
f"indices {start_idx}:{end_idx}"
|
||||||
)
|
)
|
||||||
|
|
||||||
batch[key] = batch[key][:, start_idx:end_idx]
|
batch[key] = batch[key][:, start_idx:end_idx]
|
||||||
|
|
||||||
if key == "input_ids":
|
# if key == "input_ids":
|
||||||
# After slicing
|
# # After slicing
|
||||||
non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
|
# non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
|
||||||
logger.info(
|
# logger.info(
|
||||||
f"GPU {self.rank}: Slice {start_idx}-{end_idx}, "
|
# f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
||||||
f"Non-padding tokens in slice: {non_pad_tokens_slice}"
|
# f"Slice {start_idx}:{end_idx}, "
|
||||||
)
|
# f"Non-padding tokens in slice: {non_pad_tokens_slice}"
|
||||||
logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
|
# )
|
||||||
|
# logger.info(
|
||||||
|
# f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
||||||
|
# f"GPU {self.rank} token IDs: {batch['input_ids']}"
|
||||||
|
# )
|
||||||
|
|
||||||
dist.barrier()
|
# if key == "labels":
|
||||||
|
# min_label = batch["labels"][batch["labels"] != -100].min().item() if (batch["labels"] != -100).any() else -100
|
||||||
|
# max_label = batch["labels"][batch["labels"] != -100].max().item() if (batch["labels"] != -100).any() else -100
|
||||||
|
# logger.info(f"GPU {self.rank}: Label range: {min_label} to {max_label}, Vocab size: {self.tokenizer.vocab_size}, labels: {batch['labels']}")
|
||||||
|
|
||||||
|
# # Find any labels that are outside the valid vocabulary range (but not -100 which is the ignore index)
|
||||||
|
# invalid_mask = (batch["labels"] >= self.tokenizer.vocab_size) & (batch["labels"] != -100)
|
||||||
|
|
||||||
|
# if invalid_mask.any():
|
||||||
|
# # Log this for debugging
|
||||||
|
# num_invalid = invalid_mask.sum().item()
|
||||||
|
# logger.warning(f"GPU {self.rank}: Found {num_invalid} invalid labels (>= vocab_size), setting to -100")
|
||||||
|
|
||||||
|
# # Replace invalid labels with -100 (ignore index)
|
||||||
|
# batch["labels"][invalid_mask] = -100
|
||||||
|
|
||||||
# Handle position_ids if present
|
# Handle position_ids if present
|
||||||
if "position_ids" in batch:
|
if "position_ids" in batch:
|
||||||
pos_ids = batch["position_ids"]
|
pos_ids = batch["position_ids"]
|
||||||
seq_len = pos_ids.shape[1]
|
seq_len = pos_ids.shape[1]
|
||||||
slice_size = seq_len // self.world_size
|
slice_size = seq_len // self.local_world_size
|
||||||
start_idx = self.rank * slice_size
|
start_idx = self.local_rank * slice_size
|
||||||
end_idx = (
|
end_idx = (
|
||||||
start_idx + slice_size if self.rank < self.world_size - 1 else seq_len
|
start_idx + slice_size
|
||||||
|
if self.local_rank < self.local_world_size - 1
|
||||||
|
else seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
|
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
|
||||||
|
|
||||||
# Adjust position_ids to be relative to the slice start
|
# Adjust position_ids to be relative to the slice start
|
||||||
if self.rank > 0:
|
if self.local_rank > 0:
|
||||||
batch["position_ids"] = adjust_position_ids_for_slice(
|
batch["position_ids"] = adjust_position_ids_for_slice(
|
||||||
batch["position_ids"], start_idx
|
batch["position_ids"], start_idx
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# if dist.get_rank() == 0:
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
# dist.barrier()
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user