Dan Saunders
2025-03-11 02:52:44 +00:00
parent 4190ad0647
commit 1d339e4007
5 changed files with 181 additions and 232 deletions

View File

@@ -80,10 +80,7 @@ from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
SequenceParallelDataCollator,
SequenceParallelPackedDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
V2SequenceParallelPackedDataCollator,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
@@ -871,13 +868,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator: Type[
Union[
V2BatchSamplerDataCollatorForSeq2Seq,
V2SequenceParallelPackedDataCollator,
SequenceParallelPackedDataCollator,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
SequenceParallelDataCollator,
]
]
collator_args = [self.tokenizer]
@@ -890,15 +884,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
):
if self.cfg.sequence_parallel_size > 1:
collator = V2SequenceParallelPackedDataCollator
else:
collator = V2BatchSamplerDataCollatorForSeq2Seq
collator = V2BatchSamplerDataCollatorForSeq2Seq
else:
if self.cfg.sequence_parallel_size > 1:
collator = SequenceParallelPackedDataCollator
else:
collator = BatchSamplerDataCollatorForSeq2Seq
collator = BatchSamplerDataCollatorForSeq2Seq
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
@@ -920,12 +908,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else:
collator = DataCollatorForKD
else:
if self.cfg.sequence_parallel_size > 1:
collator = SequenceParallelDataCollator
else:
collator = DataCollatorForSeq2Seq
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
kwargs["sequence_parallel_size"] = training_args.sequence_parallel_size
return collator(
*collator_args,

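The net effect of the hunks above: the builder no longer switches between dedicated sequence-parallel collator classes, it always selects one of the base collators and forwards the sequence-parallel degree as a keyword argument. A minimal sketch of the resulting selection logic; `pick_collator` and the `cfg.sample_packing` guard are illustrative stand-ins for the builder's real structure, not part of this commit:

```python
from axolotl.utils.collators import (
    BatchSamplerDataCollatorForSeq2Seq,
    DataCollatorForSeq2Seq,
    V2BatchSamplerDataCollatorForSeq2Seq,
)


def pick_collator(cfg, tokenizer, training_args, **kwargs):
    # Illustrative outer guard; the real builder branches on more conditions.
    if cfg.sample_packing:
        if cfg.model_config_type in ["llama"] and cfg.flash_attention is not True:
            collator = V2BatchSamplerDataCollatorForSeq2Seq
        else:
            collator = BatchSamplerDataCollatorForSeq2Seq
    else:
        collator = DataCollatorForSeq2Seq

    # Sequence parallelism is now handled inside the collator itself,
    # so the degree is simply passed through.
    kwargs["return_tensors"] = "pt"
    kwargs["sequence_parallel_size"] = training_args.sequence_parallel_size
    return collator(tokenizer, **kwargs)
```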
View File

@@ -9,6 +9,7 @@ from collections import defaultdict
from functools import wraps
from typing import Any, Dict, Literal, Optional
import datasets
import torch
import torch.distributed as dist
import torch.nn.functional as F
@@ -20,7 +21,7 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from transformers.utils import is_datasets_available, is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from typing_extensions import override
@@ -415,17 +416,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
generator = None
if self.args.sequence_parallel_size > 1:
generator = torch.Generator()
generator.manual_seed(self.args.getattr("seed", 0))
sampler = RandomSampler(self.train_dataset)
# if dist.get_rank() == 0:
# import ipdb; ipdb.set_trace()
# dist.barrier()
# if dist.get_rank() == 1:
# import ipdb; ipdb.set_trace()
# dist.barrier()
generator.manual_seed(self.args.seed)
sampler = RandomSampler(
self.train_dataset, generator=generator
)
return MultipackBatchSampler(
sampler,
@@ -443,7 +437,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
sampler = super()._get_train_sampler()
if self.args.sequence_parallel_size > 1:
generator = torch.Generator()
generator.manual_seed(self.args.getattr("seed", 0))
generator.manual_seed(self.args.seed)
sampler.generator = generator
return sampler
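The `self.args.getattr("seed", 0)` call removed in these two hunks appears to have been a bug, since the arguments object does not define a `getattr` method; the replacement reads `self.args.seed` directly. The seeded generator matters because every rank in a sequence-parallel group must draw the identical sample order before the collator slices each batch along the sequence dimension. A standalone check of that property (plain PyTorch, illustrative values only):

```python
import torch
from torch.utils.data import RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(16))


def sample_order(seed: int) -> list[int]:
    # Each sequence-parallel rank builds its own generator from the shared seed.
    generator = torch.Generator()
    generator.manual_seed(seed)
    return list(RandomSampler(dataset, generator=generator))


# Ranks sharing a seed see the same shuffle, so they receive the same batches
# and can safely take complementary slices of each sequence.
assert sample_order(42) == sample_order(42)
assert sample_order(42) != sample_order(43)
```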
@@ -473,17 +467,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
train_dataset = self.train_dataset
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
@@ -498,17 +495,31 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
if dist.get_rank() == 0:
import ipdb
ipdb.set_trace()
dist.barrier()
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
if self.args.sequence_parallel_size > 1:
return DataLoader(train_dataset, **dataloader_params)
else:
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
else:
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if self.args.sequence_parallel_size > 1:
return DataLoader(train_dataset, **dataloader_params)
else:
return self.accelerator.prepare(
DataLoader(train_dataset, **dataloader_params)
)
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is False:

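The rewritten `get_train_dataloader` builds the dataloader parameters once and only hands the loader to Accelerate when sequence parallelism is off: a prepared dataloader shards batches across processes, whereas ranks inside a sequence-parallel group need to receive the same full-length batch and let the collator do the slicing. A condensed sketch of that branching, assuming an `args`/`accelerator` pair with the attributes used in the diff:

```python
from torch.utils.data import DataLoader


def build_train_dataloader(train_dataset, data_collator, args, accelerator):
    # Shared parameters, mirroring the dict assembled in the diff.
    dataloader_params = {
        "batch_size": args.train_batch_size,
        "collate_fn": data_collator,
        "num_workers": args.dataloader_num_workers,
        "pin_memory": args.dataloader_pin_memory,
        "persistent_workers": args.dataloader_persistent_workers,
    }
    loader = DataLoader(train_dataset, **dataloader_params)

    if args.sequence_parallel_size > 1:
        # Every rank in a sequence-parallel group consumes the same batches;
        # slicing along the sequence dimension happens inside the collator.
        return loader
    # Otherwise let Accelerate shard batches across data-parallel ranks as usual.
    return accelerator.prepare_data_loader(loader)
```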
View File

@@ -9,8 +9,3 @@ from .batching import ( # noqa: F401
V2BatchSamplerDataCollatorForSeq2Seq,
)
from .mamba import MambaDataCollator # noqa: F401
from .sequence_parallel import ( # noqa: F401
SequenceParallelDataCollator,
SequenceParallelPackedDataCollator,
V2SequenceParallelPackedDataCollator,
)

View File

@@ -1,14 +1,66 @@
"""
DataCollator for axolotl to pad labels and position_ids for packed sequences
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
includes logic for handling sequence parallelism collation.
"""
import logging
from dataclasses import dataclass
from typing import Any, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
logger = logging.getLogger(__name__)
def adjust_position_ids_for_slice(
position_ids: list | torch.Tensor, start_idx: int
) -> torch.Tensor:
"""
Adjust position IDs for a sliced sequence to maintain proper relative positions.
This handles the case where position IDs might not be contiguous due to sample packing.
"""
# Convert to tensor if not already
if not isinstance(position_ids, torch.Tensor):
position_ids = torch.tensor(
position_ids,
device=position_ids.device if hasattr(position_ids, "device") else None,
)
# Find the boundaries between samples (where position_ids reset)
adjusted_pos_ids = position_ids.clone()
# Process each sequence in the batch
for i in range(position_ids.shape[0]):
seq = position_ids[i]
# Find sample boundaries
boundaries = []
for j in range(1, len(seq)):
if seq[j] < seq[j - 1]:
boundaries.append(j)
# No need to adjust if there are no boundaries or this is a single sample
if not boundaries:
adjusted_pos_ids[i] = seq - start_idx
continue
# Adjust each segment separately
prev_boundary = 0
for boundary in boundaries:
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
prev_boundary = boundary
# Last segment
adjusted_pos_ids[i, prev_boundary:] -= start_idx
return adjusted_pos_ids
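A small worked example of the rebasing `adjust_position_ids_for_slice` performs, with made-up values and assuming the function is importable from the batching module shown here: a rank whose slice starts at global index 4 of a single packed sample gets its position IDs rebased to start at 0.

```python
import torch

from axolotl.utils.collators.batching import adjust_position_ids_for_slice

# Rank 1's slice of a packed sample whose global positions ran 0..7;
# the slice begins at global index 4.
position_ids = torch.tensor([[4, 5, 6, 7]])

adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
print(adjusted)  # tensor([[0, 1, 2, 3]]) -- positions relative to the slice start
```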
@dataclass
class DataCollatorForSeq2Seq:
@@ -43,6 +95,8 @@ class DataCollatorForSeq2Seq:
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
sequence_parallel_size (`int`):
The degree of sequence parallelism. Defaults to 1 (no sequence parallelism).
"""
tokenizer: PreTrainedTokenizerBase
@@ -53,6 +107,14 @@ class DataCollatorForSeq2Seq:
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
sequence_parallel_size: int = 1
def __post_init__(self):
if self.sequence_parallel_size > 1:
# Get information about our position in the SP group
sp_group = get_ring_attn_group()
self.rank = dist.get_rank(group=sp_group)
self.world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
labels = None
@@ -119,8 +181,78 @@ class DataCollatorForSeq2Seq:
)
features["decoder_input_ids"] = decoder_input_ids
if self.sequence_parallel_size > 1:
features = self.apply_sequence_parallelism(features)
return features
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary from parent collator.
Returns:
Sliced batch dictionary.
"""
# Process keys that need to be sliced
for key in ["input_ids", "attention_mask", "labels"]:
if key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // self.world_size
start_idx = self.rank * slice_size
end_idx = (
start_idx + slice_size
if self.rank < self.world_size - 1
else seq_len
)
if key == "input_ids":
# Before slicing
non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
logger.info(
f"GPU {self.rank}: Total sequence length: {seq_len}, "
f"Non-padding tokens: {non_pad_tokens_total}"
)
logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
logger.info(f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}")
batch[key] = batch[key][:, start_idx:end_idx]
if key == "input_ids":
# After slicing
non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
logger.info(
f"GPU {self.rank}: Slice {start_idx}-{end_idx}, "
f"Non-padding tokens in slice: {non_pad_tokens_slice}"
)
logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
dist.barrier()
# Handle position_ids if present
if "position_ids" in batch:
pos_ids = batch["position_ids"]
seq_len = pos_ids.shape[1]
slice_size = seq_len // self.world_size
start_idx = self.rank * slice_size
end_idx = (
start_idx + slice_size if self.rank < self.world_size - 1 else seq_len
)
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
# Adjust position_ids to be relative to the slice start
if self.rank > 0:
batch["position_ids"] = adjust_position_ids_for_slice(
batch["position_ids"], start_idx
)
return batch
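The slicing arithmetic above gives every rank `seq_len // world_size` tokens and folds any remainder into the last rank. A standalone check with illustrative numbers (a length-10 sequence over a 4-way group):

```python
seq_len, world_size = 10, 4
slice_size = seq_len // world_size  # 2

slices = []
for rank in range(world_size):
    start_idx = rank * slice_size
    end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
    slices.append((start_idx, end_idx))

# Ranks 0-2 take two tokens each; the last rank absorbs the remainder.
assert slices == [(0, 2), (2, 4), (4, 6), (6, 10)]
```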
@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@@ -148,6 +280,7 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)
@@ -177,6 +310,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)
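The concatenation step shown in these last two hunks is the core of packed collation: each bucket of samples is flattened into one long example before padding. An illustrative trace with toy values (not from the diff):

```python
import numpy as np

features_ = [{"input_ids": [1, 2]}, {"input_ids": [3, 4, 5]}]
arrays = [np.array(item["input_ids"]) for item in features_ if "input_ids" in item]
print(np.concatenate(arrays))  # [1 2 3 4 5] -- one packed sequence per bucket
```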

View File

@@ -1,177 +0,0 @@
"""Module for sequence parallelism data collators."""
import logging
from dataclasses import dataclass
import torch
import torch.distributed as dist
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
from axolotl.utils.collators.batching import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)
logger = logging.getLogger(__name__)
def adjust_position_ids_for_slice(
position_ids: list | torch.Tensor, start_idx: int
) -> torch.Tensor:
"""
Adjust position IDs for a sliced sequence to maintain proper relative positions.
This handles the case where position IDs might not be contiguous due to sample packing.
"""
# Convert to tensor if not already
if not isinstance(position_ids, torch.Tensor):
position_ids = torch.tensor(
position_ids,
device=position_ids.device if hasattr(position_ids, "device") else None,
)
# Find the boundaries between samples (where position_ids reset)
adjusted_pos_ids = position_ids.clone()
# Process each sequence in the batch
for i in range(position_ids.shape[0]):
seq = position_ids[i]
# Find sample boundaries
boundaries = []
for j in range(1, len(seq)):
if seq[j] < seq[j - 1]:
boundaries.append(j)
# No need to adjust if there are no boundaries or this is a single sample
if not boundaries:
adjusted_pos_ids[i] = seq - start_idx
continue
# Adjust each segment separately
prev_boundary = 0
for boundary in boundaries:
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
prev_boundary = boundary
# Last segment
adjusted_pos_ids[i, prev_boundary:] -= start_idx
return adjusted_pos_ids
class SequenceParallelMixin:
"""
Mixin to add sequence parallelism slicing to data collators.
"""
def __post_init__(self):
# Get information about our position in the SP group
sp_group = get_ring_attn_group()
self.rank = dist.get_rank(group=sp_group)
self.world_size = dist.get_world_size(group=sp_group)
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
) -> torch.Tensor:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary from parent collator.
Returns:
Sliced batch dictionary.
"""
# Process keys that need to be sliced
for key in ["input_ids", "attention_mask", "labels"]:
if key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // self.world_size
start_idx = self.rank * slice_size
end_idx = (
start_idx + slice_size
if self.rank < self.world_size - 1
else seq_len
)
if key == "input_ids":
# Before slicing
non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
logger.info(
f"GPU {self.rank}: Total sequence length: {seq_len}, "
f"Non-padding tokens: {non_pad_tokens_total}"
)
logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
# After slicing
non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
logger.info(
f"GPU {self.rank}: Slice {start_idx}-{end_idx}, "
f"Non-padding tokens in slice: {non_pad_tokens_slice}"
)
dist.barrier()
batch[key] = batch[key][:, start_idx:end_idx]
# Handle position_ids if present
if "position_ids" in batch:
pos_ids = batch["position_ids"]
seq_len = pos_ids.shape[1]
slice_size = seq_len // self.world_size
start_idx = self.rank * slice_size
end_idx = (
start_idx + slice_size if self.rank < self.world_size - 1 else seq_len
)
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
# Adjust position_ids to be relative to the slice start
if self.rank > 0:
batch["position_ids"] = adjust_position_ids_for_slice(
batch["position_ids"], start_idx
)
return batch
@dataclass
class SequenceParallelPackedDataCollator(
SequenceParallelMixin, BatchSamplerDataCollatorForSeq2Seq
):
"""
Data collator for sequence parallelism with sample packing. Combines multiple
samples into a packed sequence, then slices it for each GPU.
"""
def __call__(self, features, return_tensors=None):
# Use the parent collator to handle sample packing and padding
batch = super().__call__(features, return_tensors=return_tensors)
return self.apply_sequence_parallelism(batch)
@dataclass
class V2SequenceParallelPackedDataCollator(
SequenceParallelMixin, V2BatchSamplerDataCollatorForSeq2Seq
):
"""
Data collator for sequence parallelism with V2 sample packing.
"""
def __call__(self, features, return_tensors=None):
# Use the parent collator to handle sample packing and padding
batch = super().__call__(features, return_tensors=return_tensors)
return self.apply_sequence_parallelism(batch)
@dataclass
class SequenceParallelDataCollator(SequenceParallelMixin, DataCollatorForSeq2Seq):
"""
Data collator for sequence parallelism without sample packing.
"""
def __call__(self, features, return_tensors=None):
# Use the parent collator to pad everything correctly
batch = super().__call__(features, return_tensors=return_tensors)
return self.apply_sequence_parallelism(batch)