fixes
This commit is contained in:
@@ -80,10 +80,7 @@ from axolotl.utils.collators import (
|
|||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
SequenceParallelDataCollator,
|
|
||||||
SequenceParallelPackedDataCollator,
|
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
V2SequenceParallelPackedDataCollator,
|
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
from axolotl.utils.models import ensure_dtype
|
from axolotl.utils.models import ensure_dtype
|
||||||
@@ -871,13 +868,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator: Type[
|
collator: Type[
|
||||||
Union[
|
Union[
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
V2SequenceParallelPackedDataCollator,
|
|
||||||
SequenceParallelPackedDataCollator,
|
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
DataCollatorWithFlattening,
|
DataCollatorWithFlattening,
|
||||||
RewardDataCollatorWithPadding,
|
RewardDataCollatorWithPadding,
|
||||||
SequenceParallelDataCollator,
|
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
collator_args = [self.tokenizer]
|
collator_args = [self.tokenizer]
|
||||||
@@ -890,15 +884,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
and self.cfg.flash_attention is not True
|
and self.cfg.flash_attention is not True
|
||||||
):
|
):
|
||||||
if self.cfg.sequence_parallel_size > 1:
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
collator = V2SequenceParallelPackedDataCollator
|
|
||||||
else:
|
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
|
||||||
else:
|
else:
|
||||||
if self.cfg.sequence_parallel_size > 1:
|
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||||
collator = SequenceParallelPackedDataCollator
|
|
||||||
else:
|
|
||||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
|
||||||
else:
|
else:
|
||||||
if self.cfg.processor_type and self.processor:
|
if self.cfg.processor_type and self.processor:
|
||||||
collator = MultiModalChatDataCollator
|
collator = MultiModalChatDataCollator
|
||||||
@@ -920,12 +908,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
collator = DataCollatorForKD
|
collator = DataCollatorForKD
|
||||||
else:
|
else:
|
||||||
if self.cfg.sequence_parallel_size > 1:
|
collator = DataCollatorForSeq2Seq
|
||||||
collator = SequenceParallelDataCollator
|
|
||||||
else:
|
|
||||||
collator = DataCollatorForSeq2Seq
|
|
||||||
|
|
||||||
kwargs["return_tensors"] = "pt"
|
kwargs["return_tensors"] = "pt"
|
||||||
|
kwargs["sequence_parallel_size"] = training_args.sequence_parallel_size
|
||||||
|
|
||||||
return collator(
|
return collator(
|
||||||
*collator_args,
|
*collator_args,
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from collections import defaultdict
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Literal, Optional
|
from typing import Any, Dict, Literal, Optional
|
||||||
|
|
||||||
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -20,7 +21,7 @@ from torch.optim.lr_scheduler import OneCycleLR
|
|||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_datasets_available, is_sagemaker_mp_enabled
|
||||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
@@ -415,17 +416,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
generator = None
|
generator = None
|
||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_size > 1:
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
generator.manual_seed(self.args.getattr("seed", 0))
|
generator.manual_seed(self.args.seed)
|
||||||
|
sampler = RandomSampler(
|
||||||
sampler = RandomSampler(self.train_dataset)
|
self.train_dataset, generator=generator
|
||||||
|
)
|
||||||
# if dist.get_rank() == 0:
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
# dist.barrier()
|
|
||||||
|
|
||||||
# if dist.get_rank() == 1:
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
# dist.barrier()
|
|
||||||
|
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
sampler,
|
sampler,
|
||||||
@@ -443,7 +437,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
sampler = super()._get_train_sampler()
|
sampler = super()._get_train_sampler()
|
||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_size > 1:
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
generator.manual_seed(self.args.getattr("seed", 0))
|
generator.manual_seed(self.args.seed)
|
||||||
sampler.generator = generator
|
sampler.generator = generator
|
||||||
|
|
||||||
return sampler
|
return sampler
|
||||||
@@ -473,17 +467,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return super()._get_eval_sampler(eval_dataset)
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
data_collator = self.data_collator
|
||||||
|
|
||||||
|
dataloader_params = {
|
||||||
|
"batch_size": self._train_batch_size,
|
||||||
|
"collate_fn": data_collator,
|
||||||
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
|
"persistent_workers": self.args.dataloader_persistent_workers,
|
||||||
|
}
|
||||||
|
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
train_dataset = self.train_dataset
|
|
||||||
if "length" in train_dataset.features.keys():
|
if "length" in train_dataset.features.keys():
|
||||||
train_dataset = train_dataset.remove_columns(["length"])
|
train_dataset = train_dataset.remove_columns(["length"])
|
||||||
data_collator = self.data_collator
|
|
||||||
dataloader_params = {
|
|
||||||
"batch_size": self._train_batch_size,
|
|
||||||
"collate_fn": data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
|
||||||
}
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
if self.args.dataloader_prefetch_factor:
|
||||||
dataloader_params["prefetch_factor"] = (
|
dataloader_params["prefetch_factor"] = (
|
||||||
self.args.dataloader_prefetch_factor
|
self.args.dataloader_prefetch_factor
|
||||||
@@ -498,17 +495,31 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
import ipdb
|
|
||||||
|
|
||||||
ipdb.set_trace()
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
return self.accelerator.prepare_data_loader(
|
if self.args.sequence_parallel_size > 1:
|
||||||
DataLoader(train_dataset, **dataloader_params)
|
return DataLoader(train_dataset, **dataloader_params)
|
||||||
)
|
else:
|
||||||
return super().get_train_dataloader()
|
return self.accelerator.prepare_data_loader(
|
||||||
|
DataLoader(train_dataset, **dataloader_params)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||||
|
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
||||||
|
else:
|
||||||
|
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
||||||
|
|
||||||
|
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||||
|
dataloader_params["sampler"] = self._get_train_sampler()
|
||||||
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||||
|
|
||||||
|
if self.args.sequence_parallel_size > 1:
|
||||||
|
return DataLoader(train_dataset, **dataloader_params)
|
||||||
|
else:
|
||||||
|
return self.accelerator.prepare(
|
||||||
|
DataLoader(train_dataset, **dataloader_params)
|
||||||
|
)
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||||
|
|||||||
@@ -9,8 +9,3 @@ from .batching import ( # noqa: F401
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from .mamba import MambaDataCollator # noqa: F401
|
from .mamba import MambaDataCollator # noqa: F401
|
||||||
from .sequence_parallel import ( # noqa: F401
|
|
||||||
SequenceParallelDataCollator,
|
|
||||||
SequenceParallelPackedDataCollator,
|
|
||||||
V2SequenceParallelPackedDataCollator,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,14 +1,66 @@
|
|||||||
"""
|
"""
|
||||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
|
||||||
|
includes logic for handling sequence parallelism collation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_position_ids_for_slice(
|
||||||
|
position_ids: list | torch.Tensor, start_idx: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
||||||
|
This handles the case where position IDs might not be contiguous due to sample packing.
|
||||||
|
"""
|
||||||
|
# Convert to tensor if not already
|
||||||
|
if not isinstance(position_ids, torch.Tensor):
|
||||||
|
position_ids = torch.tensor(
|
||||||
|
position_ids,
|
||||||
|
device=position_ids.device if hasattr(position_ids, "device") else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find the boundaries between samples (where position_ids reset)
|
||||||
|
adjusted_pos_ids = position_ids.clone()
|
||||||
|
|
||||||
|
# Process each sequence in the batch
|
||||||
|
for i in range(position_ids.shape[0]):
|
||||||
|
seq = position_ids[i]
|
||||||
|
|
||||||
|
# Find sample boundaries
|
||||||
|
boundaries = []
|
||||||
|
for j in range(1, len(seq)):
|
||||||
|
if seq[j] < seq[j - 1]:
|
||||||
|
boundaries.append(j)
|
||||||
|
|
||||||
|
# No need to adjust if there are no boundaries or this is a single sample
|
||||||
|
if not boundaries:
|
||||||
|
adjusted_pos_ids[i] = seq - start_idx
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Adjust each segment separately
|
||||||
|
prev_boundary = 0
|
||||||
|
for boundary in boundaries:
|
||||||
|
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
||||||
|
prev_boundary = boundary
|
||||||
|
|
||||||
|
# Last segment
|
||||||
|
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
||||||
|
|
||||||
|
return adjusted_pos_ids
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -43,6 +95,8 @@ class DataCollatorForSeq2Seq:
|
|||||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||||
return_tensors (`str`):
|
return_tensors (`str`):
|
||||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||||
|
sequence_parallel_size (`int`):
|
||||||
|
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
@@ -53,6 +107,14 @@ class DataCollatorForSeq2Seq:
|
|||||||
label_pad_token_id: int = -100
|
label_pad_token_id: int = -100
|
||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
|
sequence_parallel_size: int = 1
|
||||||
|
|
||||||
|
def __post_init__(self):
    """
    Record this process's rank and the group size inside the sequence-parallel group.

    Nothing is done unless `sequence_parallel_size` is greater than 1, so `rank`
    and `world_size` are only set on the instance when sequence parallelism is on.
    """
    if self.sequence_parallel_size <= 1:
        return

    # Both values are looked up relative to the ring-attention process group.
    ring_group = get_ring_attn_group()
    self.world_size = dist.get_world_size(group=ring_group)
    self.rank = dist.get_rank(group=ring_group)
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
labels = None
|
labels = None
|
||||||
@@ -119,8 +181,78 @@ class DataCollatorForSeq2Seq:
|
|||||||
)
|
)
|
||||||
features["decoder_input_ids"] = decoder_input_ids
|
features["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
|
if self.sequence_parallel_size > 1:
|
||||||
|
features = self.apply_sequence_parallelism(features)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
def apply_sequence_parallelism(
    self, batch: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """
    Apply sequence parallelism slicing to a batch.

    Each of `input_ids`, `attention_mask`, `labels` and `position_ids` that is
    present is cut along dimension 1 into `self.world_size` contiguous slices and
    replaced by the slice belonging to `self.rank`. The last rank also receives
    the remainder when the length is not divisible by the group size. For ranks
    above 0 the sliced `position_ids` are shifted by the slice start.

    Args:
        batch: Batch dictionary from parent collator. It is modified in place.

    Returns:
        The same batch dictionary, holding this rank's slices.
    """
    debug_enabled = logger.isEnabledFor(logging.DEBUG)
    # Use the tokenizer's own pad id for the debug statistics instead of a
    # model-specific constant; it may be None when no pad token is configured.
    pad_token_id = self.tokenizer.pad_token_id

    for key in ("input_ids", "attention_mask", "labels", "position_ids"):
        if key not in batch:
            continue

        full = batch[key]
        seq_len = full.shape[1]
        slice_size = seq_len // self.world_size
        start_idx = self.rank * slice_size
        # The last rank takes everything up to the end of the sequence.
        end_idx = (
            start_idx + slice_size if self.rank < self.world_size - 1 else seq_len
        )
        batch[key] = full[:, start_idx:end_idx]

        # Diagnostics are emitted at DEBUG level only and never dump whole token
        # tensors. No collective call (such as a barrier) is issued here.
        # NOTE(review): a barrier in the collate path would presumably break when
        # collation runs in DataLoader worker processes - confirm it was only
        # debugging residue.
        if debug_enabled and key == "input_ids":
            logger.debug(
                "SP rank %s: input_ids slice %s:%s of %s",
                self.rank,
                start_idx,
                end_idx,
                seq_len,
            )
            if pad_token_id is not None:
                logger.debug(
                    "SP rank %s: non-padding tokens: %s total, %s in slice",
                    self.rank,
                    (full != pad_token_id).sum().item(),
                    (batch[key] != pad_token_id).sum().item(),
                )

        # Make position_ids relative to the slice start on all but the first rank.
        if key == "position_ids" and self.rank > 0:
            batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)

    return batch
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
@@ -148,6 +280,7 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
@@ -177,6 +310,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,177 +0,0 @@
|
|||||||
"""Module for sequence parallelism data collators."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
|
||||||
from axolotl.utils.collators.batching import (
|
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
|
||||||
DataCollatorForSeq2Seq,
|
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def adjust_position_ids_for_slice(
|
|
||||||
position_ids: list | torch.Tensor, start_idx: int
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
|
||||||
This handles the case where position IDs might not be contiguous due to sample packing.
|
|
||||||
"""
|
|
||||||
# Convert to tensor if not already
|
|
||||||
if not isinstance(position_ids, torch.Tensor):
|
|
||||||
position_ids = torch.tensor(
|
|
||||||
position_ids,
|
|
||||||
device=position_ids.device if hasattr(position_ids, "device") else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find the boundaries between samples (where position_ids reset)
|
|
||||||
adjusted_pos_ids = position_ids.clone()
|
|
||||||
|
|
||||||
# Process each sequence in the batch
|
|
||||||
for i in range(position_ids.shape[0]):
|
|
||||||
seq = position_ids[i]
|
|
||||||
|
|
||||||
# Find sample boundaries
|
|
||||||
boundaries = []
|
|
||||||
for j in range(1, len(seq)):
|
|
||||||
if seq[j] < seq[j - 1]:
|
|
||||||
boundaries.append(j)
|
|
||||||
|
|
||||||
# No need to adjust if there are no boundaries or this is a single sample
|
|
||||||
if not boundaries:
|
|
||||||
adjusted_pos_ids[i] = seq - start_idx
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Adjust each segment separately
|
|
||||||
prev_boundary = 0
|
|
||||||
for boundary in boundaries:
|
|
||||||
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
|
||||||
prev_boundary = boundary
|
|
||||||
|
|
||||||
# Last segment
|
|
||||||
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
|
||||||
|
|
||||||
return adjusted_pos_ids
|
|
||||||
|
|
||||||
|
|
||||||
class SequenceParallelMixin:
    """
    Mixin to add sequence parallelism slicing to data collators.

    It records the rank and size of the ring-attention process group and offers
    `apply_sequence_parallelism` to cut a collated batch down to this rank's slice.
    """

    def __post_init__(self):
        # Get information about our position in the SP group
        sp_group = get_ring_attn_group()
        self.rank = dist.get_rank(group=sp_group)
        self.world_size = dist.get_world_size(group=sp_group)

    def apply_sequence_parallelism(
        self, batch: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """
        Apply sequence parallelism slicing to a batch.

        Each of `input_ids`, `attention_mask`, `labels` and `position_ids` that
        is present is cut along dimension 1 into `self.world_size` contiguous
        slices and replaced by the slice belonging to `self.rank`. The last rank
        also receives the remainder when the length is not divisible by the
        group size. For ranks above 0 the sliced `position_ids` are shifted by
        the slice start.

        Args:
            batch: Batch dictionary from parent collator. It is modified in place.

        Returns:
            The same batch dictionary, holding this rank's slices.
        """
        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        # The mixin does not declare a tokenizer itself, so look it up
        # defensively; the pad id is only needed for debug statistics.
        pad_token_id = getattr(
            getattr(self, "tokenizer", None), "pad_token_id", None
        )

        for key in ("input_ids", "attention_mask", "labels", "position_ids"):
            if key not in batch:
                continue

            full = batch[key]
            seq_len = full.shape[1]
            slice_size = seq_len // self.world_size
            start_idx = self.rank * slice_size
            # The last rank takes everything up to the end of the sequence.
            end_idx = (
                start_idx + slice_size
                if self.rank < self.world_size - 1
                else seq_len
            )
            batch[key] = full[:, start_idx:end_idx]

            # Statistics are gathered after the slice has been taken, so the
            # "in slice" count really describes the slice. They are emitted at
            # DEBUG level only, and no collective call (such as a barrier) is
            # issued from the collate path.
            if debug_enabled and key == "input_ids":
                logger.debug(
                    "SP rank %s: input_ids slice %s:%s of %s",
                    self.rank,
                    start_idx,
                    end_idx,
                    seq_len,
                )
                if pad_token_id is not None:
                    logger.debug(
                        "SP rank %s: non-padding tokens: %s total, %s in slice",
                        self.rank,
                        (full != pad_token_id).sum().item(),
                        (batch[key] != pad_token_id).sum().item(),
                    )

            # Make position_ids relative to the slice start on all but the
            # first rank.
            if key == "position_ids" and self.rank > 0:
                batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)

        return batch
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class SequenceParallelPackedDataCollator(
    SequenceParallelMixin, BatchSamplerDataCollatorForSeq2Seq
):
    """
    Sample-packing collator with sequence parallelism: the parent collator packs
    and pads the samples, and the resulting batch is then sliced for this rank.
    """

    def __call__(self, features, return_tensors=None):
        # Packing and padding are delegated to the parent collator; only its
        # output is sliced.
        return self.apply_sequence_parallelism(
            super().__call__(features, return_tensors=return_tensors)
        )
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class V2SequenceParallelPackedDataCollator(
    SequenceParallelMixin, V2BatchSamplerDataCollatorForSeq2Seq
):
    """
    V2 sample-packing collator with sequence parallelism: the parent collator
    builds the packed batch, which is then sliced for this rank.
    """

    def __call__(self, features, return_tensors=None):
        # Packing and padding are delegated to the parent collator; only its
        # output is sliced.
        return self.apply_sequence_parallelism(
            super().__call__(features, return_tensors=return_tensors)
        )
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class SequenceParallelDataCollator(SequenceParallelMixin, DataCollatorForSeq2Seq):
    """
    Collator for sequence parallelism without sample packing: the parent collator
    pads the batch, which is then sliced for this rank.
    """

    def __call__(self, features, return_tensors=None):
        # Padding is delegated to the parent collator; only its output is sliced.
        return self.apply_sequence_parallelism(
            super().__call__(features, return_tensors=return_tensors)
        )
|
|
||||||
Reference in New Issue
Block a user