fixes
This commit is contained in:
@@ -80,10 +80,7 @@ from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
MambaDataCollator,
|
||||
SequenceParallelDataCollator,
|
||||
SequenceParallelPackedDataCollator,
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
V2SequenceParallelPackedDataCollator,
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.models import ensure_dtype
|
||||
@@ -871,13 +868,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
collator: Type[
|
||||
Union[
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
V2SequenceParallelPackedDataCollator,
|
||||
SequenceParallelPackedDataCollator,
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
DataCollatorWithFlattening,
|
||||
RewardDataCollatorWithPadding,
|
||||
SequenceParallelDataCollator,
|
||||
]
|
||||
]
|
||||
collator_args = [self.tokenizer]
|
||||
@@ -890,15 +884,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.model_config_type in ["llama"]
|
||||
and self.cfg.flash_attention is not True
|
||||
):
|
||||
if self.cfg.sequence_parallel_size > 1:
|
||||
collator = V2SequenceParallelPackedDataCollator
|
||||
else:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
if self.cfg.sequence_parallel_size > 1:
|
||||
collator = SequenceParallelPackedDataCollator
|
||||
else:
|
||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
if self.cfg.processor_type and self.processor:
|
||||
collator = MultiModalChatDataCollator
|
||||
@@ -920,12 +908,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
collator = DataCollatorForKD
|
||||
else:
|
||||
if self.cfg.sequence_parallel_size > 1:
|
||||
collator = SequenceParallelDataCollator
|
||||
else:
|
||||
collator = DataCollatorForSeq2Seq
|
||||
collator = DataCollatorForSeq2Seq
|
||||
|
||||
kwargs["return_tensors"] = "pt"
|
||||
kwargs["sequence_parallel_size"] = training_args.sequence_parallel_size
|
||||
|
||||
return collator(
|
||||
*collator_args,
|
||||
|
||||
@@ -9,6 +9,7 @@ from collections import defaultdict
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
@@ -20,7 +21,7 @@ from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import Trainer
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from transformers.utils import is_datasets_available, is_sagemaker_mp_enabled
|
||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
@@ -415,17 +416,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
generator = None
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(self.args.getattr("seed", 0))
|
||||
|
||||
sampler = RandomSampler(self.train_dataset)
|
||||
|
||||
# if dist.get_rank() == 0:
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# dist.barrier()
|
||||
|
||||
# if dist.get_rank() == 1:
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# dist.barrier()
|
||||
generator.manual_seed(self.args.seed)
|
||||
sampler = RandomSampler(
|
||||
self.train_dataset, generator=generator
|
||||
)
|
||||
|
||||
return MultipackBatchSampler(
|
||||
sampler,
|
||||
@@ -443,7 +437,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
sampler = super()._get_train_sampler()
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(self.args.getattr("seed", 0))
|
||||
generator.manual_seed(self.args.seed)
|
||||
sampler.generator = generator
|
||||
|
||||
return sampler
|
||||
@@ -473,17 +467,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
train_dataset = self.train_dataset
|
||||
data_collator = self.data_collator
|
||||
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
"persistent_workers": self.args.dataloader_persistent_workers,
|
||||
}
|
||||
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
train_dataset = self.train_dataset
|
||||
if "length" in train_dataset.features.keys():
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = (
|
||||
self.args.dataloader_prefetch_factor
|
||||
@@ -498,17 +495,31 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
dist.barrier()
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
return DataLoader(train_dataset, **dataloader_params)
|
||||
else:
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
else:
|
||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
||||
|
||||
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_train_sampler()
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
return DataLoader(train_dataset, **dataloader_params)
|
||||
else:
|
||||
return self.accelerator.prepare(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||
|
||||
@@ -9,8 +9,3 @@ from .batching import ( # noqa: F401
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from .mamba import MambaDataCollator # noqa: F401
|
||||
from .sequence_parallel import ( # noqa: F401
|
||||
SequenceParallelDataCollator,
|
||||
SequenceParallelPackedDataCollator,
|
||||
V2SequenceParallelPackedDataCollator,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,66 @@
|
||||
"""
|
||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
|
||||
includes logic for handling sequence parallelism collation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def adjust_position_ids_for_slice(
|
||||
position_ids: list | torch.Tensor, start_idx: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
||||
This handles the case where position IDs might not be contiguous due to sample packing.
|
||||
"""
|
||||
# Convert to tensor if not already
|
||||
if not isinstance(position_ids, torch.Tensor):
|
||||
position_ids = torch.tensor(
|
||||
position_ids,
|
||||
device=position_ids.device if hasattr(position_ids, "device") else None,
|
||||
)
|
||||
|
||||
# Find the boundaries between samples (where position_ids reset)
|
||||
adjusted_pos_ids = position_ids.clone()
|
||||
|
||||
# Process each sequence in the batch
|
||||
for i in range(position_ids.shape[0]):
|
||||
seq = position_ids[i]
|
||||
|
||||
# Find sample boundaries
|
||||
boundaries = []
|
||||
for j in range(1, len(seq)):
|
||||
if seq[j] < seq[j - 1]:
|
||||
boundaries.append(j)
|
||||
|
||||
# No need to adjust if there are no boundaries or this is a single sample
|
||||
if not boundaries:
|
||||
adjusted_pos_ids[i] = seq - start_idx
|
||||
continue
|
||||
|
||||
# Adjust each segment separately
|
||||
prev_boundary = 0
|
||||
for boundary in boundaries:
|
||||
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
||||
prev_boundary = boundary
|
||||
|
||||
# Last segment
|
||||
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
||||
|
||||
return adjusted_pos_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSeq2Seq:
|
||||
@@ -43,6 +95,8 @@ class DataCollatorForSeq2Seq:
|
||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||
return_tensors (`str`):
|
||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||
sequence_parallel_size (`int`):
|
||||
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
@@ -53,6 +107,14 @@ class DataCollatorForSeq2Seq:
|
||||
label_pad_token_id: int = -100
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
sequence_parallel_size: int = 1
|
||||
|
||||
def __post_init__(self):
    """
    Record this process's rank and the group size within the sequence-parallel
    group. Nothing is looked up when `sequence_parallel_size` is 1 or less, so
    `rank` and `world_size` are only set when sequence parallelism is enabled.
    """
    if self.sequence_parallel_size <= 1:
        return

    # Rank and size are taken relative to the ring-attention process group
    ring_group = get_ring_attn_group()
    self.rank = dist.get_rank(group=ring_group)
    self.world_size = dist.get_world_size(group=ring_group)
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
labels = None
|
||||
@@ -119,8 +181,78 @@ class DataCollatorForSeq2Seq:
|
||||
)
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
if self.sequence_parallel_size > 1:
|
||||
features = self.apply_sequence_parallelism(features)
|
||||
|
||||
return features
|
||||
|
||||
def apply_sequence_parallelism(
    self, batch: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """
    Apply sequence parallelism slicing to a batch.

    Each of `input_ids`, `attention_mask`, `labels` and `position_ids` that is
    present is cut along dimension 1 into `self.world_size` pieces of
    `seq_len // self.world_size` tokens, and only the piece belonging to
    `self.rank` is kept. The last rank additionally keeps any remainder.

    Args:
        batch: Batch dictionary from parent collator. It is updated in place.

    Returns:
        The same dictionary, holding this rank's slice of every handled key.
    """
    # No collective calls (e.g. `dist.barrier()`) and no per-batch token dumps
    # here: a collate function can run inside DataLoader worker processes when
    # `num_workers > 0`, and it is executed for every single batch.
    for key in ("input_ids", "attention_mask", "labels", "position_ids"):
        if key not in batch:
            continue

        seq_len = batch[key].shape[1]
        slice_size = seq_len // self.world_size
        start_idx = self.rank * slice_size
        # The last rank absorbs the tokens left over by the integer division
        if self.rank < self.world_size - 1:
            end_idx = start_idx + slice_size
        else:
            end_idx = seq_len

        batch[key] = batch[key][:, start_idx:end_idx]

        # Adjust position_ids to be relative to the slice start
        if key == "position_ids" and self.rank > 0:
            batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)

    return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
@@ -148,6 +280,7 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@@ -177,6 +310,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
"""Module for sequence parallelism data collators."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
from axolotl.utils.collators.batching import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def adjust_position_ids_for_slice(
|
||||
position_ids: list | torch.Tensor, start_idx: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
||||
This handles the case where position IDs might not be contiguous due to sample packing.
|
||||
"""
|
||||
# Convert to tensor if not already
|
||||
if not isinstance(position_ids, torch.Tensor):
|
||||
position_ids = torch.tensor(
|
||||
position_ids,
|
||||
device=position_ids.device if hasattr(position_ids, "device") else None,
|
||||
)
|
||||
|
||||
# Find the boundaries between samples (where position_ids reset)
|
||||
adjusted_pos_ids = position_ids.clone()
|
||||
|
||||
# Process each sequence in the batch
|
||||
for i in range(position_ids.shape[0]):
|
||||
seq = position_ids[i]
|
||||
|
||||
# Find sample boundaries
|
||||
boundaries = []
|
||||
for j in range(1, len(seq)):
|
||||
if seq[j] < seq[j - 1]:
|
||||
boundaries.append(j)
|
||||
|
||||
# No need to adjust if there are no boundaries or this is a single sample
|
||||
if not boundaries:
|
||||
adjusted_pos_ids[i] = seq - start_idx
|
||||
continue
|
||||
|
||||
# Adjust each segment separately
|
||||
prev_boundary = 0
|
||||
for boundary in boundaries:
|
||||
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
||||
prev_boundary = boundary
|
||||
|
||||
# Last segment
|
||||
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
||||
|
||||
return adjusted_pos_ids
|
||||
|
||||
|
||||
class SequenceParallelMixin:
    """
    Mixin to add sequence parallelism slicing to data collators.

    `__post_init__` stores the process's rank and the size of the
    sequence-parallel group; `apply_sequence_parallelism` uses them to keep only
    this rank's share of each batch along the sequence dimension.
    """

    def __post_init__(self):
        # Get information about our position in the SP group
        sp_group = get_ring_attn_group()
        self.rank = dist.get_rank(group=sp_group)
        self.world_size = dist.get_world_size(group=sp_group)

    def apply_sequence_parallelism(
        self, batch: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """
        Apply sequence parallelism slicing to a batch.

        Each of `input_ids`, `attention_mask`, `labels` and `position_ids` that
        is present is cut along dimension 1 into `self.world_size` pieces of
        `seq_len // self.world_size` tokens, and only the piece belonging to
        `self.rank` is kept. The last rank additionally keeps any remainder.

        Args:
            batch: Batch dictionary from parent collator. It is updated in place.

        Returns:
            The same dictionary, holding this rank's slice of every handled key.
        """
        # No collective calls (e.g. `dist.barrier()`) and no per-batch token
        # statistics here: a collate function can run inside DataLoader worker
        # processes when `num_workers > 0`, and it runs for every batch.
        for key in ("input_ids", "attention_mask", "labels", "position_ids"):
            if key not in batch:
                continue

            seq_len = batch[key].shape[1]
            slice_size = seq_len // self.world_size
            start_idx = self.rank * slice_size
            # The last rank absorbs the tokens left over by the integer division
            if self.rank < self.world_size - 1:
                end_idx = start_idx + slice_size
            else:
                end_idx = seq_len

            batch[key] = batch[key][:, start_idx:end_idx]

            # Adjust position_ids to be relative to the slice start
            if key == "position_ids" and self.rank > 0:
                batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)

        return batch
|
||||
|
||||
|
||||
@dataclass
class SequenceParallelPackedDataCollator(
    SequenceParallelMixin, BatchSamplerDataCollatorForSeq2Seq
):
    """
    Sequence-parallel collator for sample-packed batches. Collation is delegated
    to the next collator in the MRO; the collated batch is then reduced to the
    slice owned by the current rank.
    """

    def __call__(self, features, return_tensors=None):
        # Let the parent collator build the packed, padded batch first
        packed_batch = super().__call__(features, return_tensors=return_tensors)

        # Keep only this rank's portion of the sequence dimension
        return self.apply_sequence_parallelism(packed_batch)
|
||||
|
||||
|
||||
@dataclass
class V2SequenceParallelPackedDataCollator(
    SequenceParallelMixin, V2BatchSamplerDataCollatorForSeq2Seq
):
    """
    Sequence-parallel collator built on the V2 sample-packing collator. The
    parent produces the collated batch, which is then reduced to the slice owned
    by the current rank.
    """

    def __call__(self, features, return_tensors=None):
        # Let the V2 parent collator build the packed, padded batch first
        packed_batch = super().__call__(features, return_tensors=return_tensors)

        # Keep only this rank's portion of the sequence dimension
        return self.apply_sequence_parallelism(packed_batch)
|
||||
|
||||
|
||||
@dataclass
class SequenceParallelDataCollator(SequenceParallelMixin, DataCollatorForSeq2Seq):
    """
    Sequence-parallel collator for batches without sample packing. The parent
    collator produces the padded batch, which is then reduced to the slice owned
    by the current rank.
    """

    def __call__(self, features, return_tensors=None):
        # Padding is handled entirely by the parent collator
        padded_batch = super().__call__(features, return_tensors=return_tensors)

        # Keep only this rank's portion of the sequence dimension
        return self.apply_sequence_parallelism(padded_batch)
|
||||
Reference in New Issue
Block a user