updates
This commit is contained in:
@@ -32,6 +32,9 @@ tokenizer_legacy:
|
|||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||||
shrink_embeddings:
|
shrink_embeddings:
|
||||||
|
# Whether to load the model with randomly initialized weights. Useful for
|
||||||
|
# pre-training a model from scratch or debugging purposes.
|
||||||
|
random_init:
|
||||||
|
|
||||||
# (Internal use only)
|
# (Internal use only)
|
||||||
# Used to identify which the model is based on
|
# Used to identify which the model is based on
|
||||||
|
|||||||
@@ -871,10 +871,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator: Type[
|
collator: Type[
|
||||||
Union[
|
Union[
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
V2SequenceParallelPackedDataCollator,
|
||||||
|
SequenceParallelPackedDataCollator,
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
DataCollatorWithFlattening,
|
DataCollatorWithFlattening,
|
||||||
RewardDataCollatorWithPadding,
|
RewardDataCollatorWithPadding,
|
||||||
|
SequenceParallelDataCollator,
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
collator_args = [self.tokenizer]
|
collator_args = [self.tokenizer]
|
||||||
|
|||||||
@@ -412,8 +412,21 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
if self.args.curriculum_sampling:
|
if self.args.curriculum_sampling:
|
||||||
sampler = SequentialSampler(self.train_dataset)
|
sampler = SequentialSampler(self.train_dataset)
|
||||||
else:
|
else:
|
||||||
|
generator = None
|
||||||
|
if self.args.sequence_parallel_size > 1:
|
||||||
|
generator = torch.Generator()
|
||||||
|
generator.manual_seed(self.args.getattr("seed", 0))
|
||||||
|
|
||||||
sampler = RandomSampler(self.train_dataset)
|
sampler = RandomSampler(self.train_dataset)
|
||||||
|
|
||||||
|
# if dist.get_rank() == 0:
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
# dist.barrier()
|
||||||
|
|
||||||
|
# if dist.get_rank() == 1:
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
# dist.barrier()
|
||||||
|
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
sampler,
|
sampler,
|
||||||
lengths=get_dataset_lengths(self.train_dataset),
|
lengths=get_dataset_lengths(self.train_dataset),
|
||||||
@@ -426,7 +439,14 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
)
|
)
|
||||||
if self.args.curriculum_sampling:
|
if self.args.curriculum_sampling:
|
||||||
return SequentialSampler(self.train_dataset)
|
return SequentialSampler(self.train_dataset)
|
||||||
return super()._get_train_sampler()
|
|
||||||
|
sampler = super()._get_train_sampler()
|
||||||
|
if self.args.sequence_parallel_size > 1:
|
||||||
|
generator = torch.Generator()
|
||||||
|
generator.manual_seed(self.args.getattr("seed", 0))
|
||||||
|
sampler.generator = generator
|
||||||
|
|
||||||
|
return sampler
|
||||||
|
|
||||||
def _get_eval_sampler(
|
def _get_eval_sampler(
|
||||||
self, eval_dataset: Dataset
|
self, eval_dataset: Dataset
|
||||||
@@ -478,6 +498,12 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
import ipdb
|
||||||
|
|
||||||
|
ipdb.set_trace()
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
return self.accelerator.prepare_data_loader(
|
return self.accelerator.prepare_data_loader(
|
||||||
DataLoader(train_dataset, **dataloader_params)
|
DataLoader(train_dataset, **dataloader_params)
|
||||||
@@ -805,60 +831,36 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
inputs: dict[str, torch.Tensor | Any],
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
num_items_in_batch=None,
|
num_items_in_batch: int = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Perform a training step on a batch of inputs.
|
Perform a training step on a batch of inputs.
|
||||||
"""
|
"""
|
||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_size > 1:
|
||||||
# At this point, inputs should already be partitioned by the sequence parallel data collator
|
# At this point, inputs should already be partitioned by the sequence
|
||||||
# We'll just log some information about the partitioned data
|
# parallel data collator
|
||||||
batch_size = inputs["input_ids"].shape[0]
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
seq_len = inputs["input_ids"].shape[1]
|
seq_len = inputs["input_ids"].shape[1]
|
||||||
|
|
||||||
# Get rank and SP information
|
# Get rank and SP information
|
||||||
sp_group = get_ring_attn_group()
|
sp_group = get_ring_attn_group()
|
||||||
rank = dist.get_rank()
|
|
||||||
sp_rank = dist.get_rank(group=sp_group) if sp_group else rank
|
|
||||||
world_size = (
|
world_size = (
|
||||||
dist.get_world_size(group=sp_group)
|
dist.get_world_size(group=sp_group)
|
||||||
if sp_group
|
if sp_group
|
||||||
else dist.get_world_size()
|
else dist.get_world_size()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sample tokens from our slice to verify partitioning
|
|
||||||
sample_start = (
|
|
||||||
inputs["input_ids"][0, :5].tolist()
|
|
||||||
if seq_len >= 5
|
|
||||||
else inputs["input_ids"][0, :].tolist()
|
|
||||||
)
|
|
||||||
sample_end = (
|
|
||||||
inputs["input_ids"][0, -5:].tolist()
|
|
||||||
if seq_len >= 5
|
|
||||||
else inputs["input_ids"][0, :].tolist()
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
f"GPU {rank} (SP rank {sp_rank}) | Step {self.state.global_step} | "
|
|
||||||
f"Slice shape: batch_size={batch_size}, seq_len={seq_len} | "
|
|
||||||
f"Sample start: {sample_start}, end: {sample_end}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate the full sequence length across all GPUs in this SP group
|
# Calculate the full sequence length across all GPUs in this SP group
|
||||||
full_seq_len = seq_len * world_size
|
total_seq_len = seq_len * world_size
|
||||||
|
|
||||||
# Pass the partitioned sequence information to ring flash attention
|
# Pass the partitioned sequence information to ring flash attention
|
||||||
self._update_ring_flash_attn_params([seq_len] * batch_size, full_seq_len)
|
self._update_ring_flash_attn_params(
|
||||||
|
packed_seq_lens=[seq_len] * batch_size, total_seq_len=total_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
# Get the loss from the parent implementation
|
# Get the loss from the parent implementation
|
||||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
|
||||||
if self.args.sequence_parallel_size > 1:
|
|
||||||
rank = dist.get_rank()
|
|
||||||
LOG.info(
|
|
||||||
f"GPU {rank} | Step {self.state.global_step} | Loss: {loss.item()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
|
def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
"""Module for sequence parallelism data collators."""
|
"""Module for sequence parallelism data collators."""
|
||||||
|
|
||||||
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate.logging import get_logger
|
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
from axolotl.utils.collators.batching import (
|
from axolotl.utils.collators.batching import (
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
@@ -14,32 +13,12 @@ from axolotl.utils.collators.batching import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
|
|
||||||
configure_logging()
|
logger = logging.getLogger(__name__)
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def find_sample_boundaries(position_ids):
|
def adjust_position_ids_for_slice(
|
||||||
"""
|
position_ids: list | torch.Tensor, start_idx: int
|
||||||
Find the boundaries between packed samples in a sequence by looking for
|
) -> torch.Tensor:
|
||||||
where position_ids decrease.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of boundary indices for each sequence in the batch
|
|
||||||
"""
|
|
||||||
batch_boundaries = []
|
|
||||||
|
|
||||||
for i in range(position_ids.shape[0]):
|
|
||||||
seq = position_ids[i]
|
|
||||||
boundaries = []
|
|
||||||
for j in range(1, len(seq)):
|
|
||||||
if seq[j] < seq[j - 1]:
|
|
||||||
boundaries.append(j)
|
|
||||||
batch_boundaries.append(boundaries)
|
|
||||||
|
|
||||||
return batch_boundaries
|
|
||||||
|
|
||||||
|
|
||||||
def adjust_position_ids_for_slice(position_ids, start_idx):
|
|
||||||
"""
|
"""
|
||||||
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
||||||
This handles the case where position IDs might not be contiguous due to sample packing.
|
This handles the case where position IDs might not be contiguous due to sample packing.
|
||||||
@@ -64,370 +43,135 @@ def adjust_position_ids_for_slice(position_ids, start_idx):
|
|||||||
if seq[j] < seq[j - 1]:
|
if seq[j] < seq[j - 1]:
|
||||||
boundaries.append(j)
|
boundaries.append(j)
|
||||||
|
|
||||||
# Debug: log the found boundaries
|
|
||||||
LOG.debug(f"Sequence {i}: Found sample boundaries at positions {boundaries}")
|
|
||||||
|
|
||||||
# No need to adjust if there are no boundaries or this is a single sample
|
# No need to adjust if there are no boundaries or this is a single sample
|
||||||
if not boundaries:
|
if not boundaries:
|
||||||
old_values = seq[0:5].tolist() # Sample of original values
|
|
||||||
adjusted_pos_ids[i] = seq - start_idx
|
adjusted_pos_ids[i] = seq - start_idx
|
||||||
new_values = adjusted_pos_ids[i, 0:5].tolist() # Sample of new values
|
|
||||||
LOG.debug(
|
|
||||||
f"Sequence {i}: No boundaries, subtracting {start_idx} uniformly. Example values before: {old_values}, after: {new_values}"
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Adjust each segment separately
|
# Adjust each segment separately
|
||||||
prev_boundary = 0
|
prev_boundary = 0
|
||||||
for boundary_idx, boundary in enumerate(boundaries):
|
for boundary in boundaries:
|
||||||
segment = seq[prev_boundary:boundary]
|
|
||||||
old_values = segment[
|
|
||||||
0 : min(5, len(segment))
|
|
||||||
].tolist() # Sample of original values
|
|
||||||
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
||||||
new_values = adjusted_pos_ids[
|
|
||||||
i, prev_boundary : min(prev_boundary + 5, boundary)
|
|
||||||
].tolist() # Sample of new values
|
|
||||||
LOG.debug(
|
|
||||||
f"Sequence {i}, Segment {boundary_idx}: Adjusting positions {prev_boundary}-{boundary-1}. Example values before: {old_values}, after: {new_values}"
|
|
||||||
)
|
|
||||||
prev_boundary = boundary
|
prev_boundary = boundary
|
||||||
|
|
||||||
# Last segment
|
# Last segment
|
||||||
segment = seq[prev_boundary:]
|
|
||||||
old_values = segment[
|
|
||||||
0 : min(5, len(segment))
|
|
||||||
].tolist() # Sample of original values
|
|
||||||
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
||||||
new_values = adjusted_pos_ids[
|
|
||||||
i, prev_boundary : min(prev_boundary + 5, len(seq))
|
|
||||||
].tolist() # Sample of new values
|
|
||||||
LOG.debug(
|
|
||||||
f"Sequence {i}, Last segment: Adjusting positions {prev_boundary}-end. Example values before: {old_values}, after: {new_values}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return adjusted_pos_ids
|
return adjusted_pos_ids
|
||||||
|
|
||||||
|
|
||||||
def check_for_boundary_splits(boundaries, slice_start, slice_end):
|
class SequenceParallelMixin:
|
||||||
"""
|
"""
|
||||||
Check if any sample boundaries fall near the edge of a sequence slice.
|
Mixin to add sequence parallelism slicing to data collators.
|
||||||
These edge cases could cause issues with gradient computation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
boundaries: List of indices where sample boundaries occur
|
|
||||||
slice_start: Start index of this GPU's slice
|
|
||||||
slice_end: End index of this GPU's slice
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of potentially problematic boundaries
|
|
||||||
"""
|
|
||||||
# Consider a boundary "near" an edge if it's within 5 tokens
|
|
||||||
buffer_size = 5
|
|
||||||
problem_boundaries = []
|
|
||||||
|
|
||||||
for boundary in boundaries:
|
|
||||||
# Check if boundary is near the start of the slice
|
|
||||||
if slice_start <= boundary < slice_start + buffer_size:
|
|
||||||
problem_boundaries.append((boundary, "start", boundary - slice_start))
|
|
||||||
# Check if boundary is near the end of the slice
|
|
||||||
elif slice_end - buffer_size <= boundary < slice_end:
|
|
||||||
problem_boundaries.append((boundary, "end", slice_end - boundary))
|
|
||||||
|
|
||||||
return problem_boundaries
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SequenceParallelPackedDataCollator(BatchSamplerDataCollatorForSeq2Seq):
|
|
||||||
"""
|
|
||||||
Data collator for sequence parallelism with sample packing.
|
|
||||||
Combines multiple samples into a packed sequence, then slices it for each GPU.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
debug_level: str = "debug" # Can be "debug" for more verbose output
|
def __post_init__(self):
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
|
||||||
# First, use the parent collator to handle sample packing and padding
|
|
||||||
batch = super().__call__(features, return_tensors=return_tensors)
|
|
||||||
|
|
||||||
sp_group = get_ring_attn_group()
|
|
||||||
if sp_group is None:
|
|
||||||
return batch # Not using sequence parallelism
|
|
||||||
|
|
||||||
# Get information about our position in the SP group
|
# Get information about our position in the SP group
|
||||||
rank = dist.get_rank(group=sp_group)
|
sp_group = get_ring_attn_group()
|
||||||
world_size = dist.get_world_size(group=sp_group)
|
self.rank = dist.get_rank(group=sp_group)
|
||||||
|
self.world_size = dist.get_world_size(group=sp_group)
|
||||||
|
|
||||||
# Enable debug level if requested
|
def apply_sequence_parallelism(
|
||||||
if self.debug_level == "debug":
|
self, batch: dict[str, torch.Tensor]
|
||||||
original_shapes = {
|
) -> torch.Tensor:
|
||||||
k: v.shape if hasattr(v, "shape") else None for k, v in batch.items()
|
"""
|
||||||
}
|
Apply sequence parallelism slicing to a batch.
|
||||||
LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}")
|
|
||||||
|
|
||||||
if "position_ids" in batch:
|
Args:
|
||||||
# Find and log sample boundaries before slicing
|
batch: Batch dictionary from parent collator.
|
||||||
boundaries = find_sample_boundaries(batch["position_ids"])
|
|
||||||
for i, seq_boundaries in enumerate(boundaries):
|
|
||||||
LOG.info(
|
|
||||||
f"GPU {rank}: Sequence {i} has {len(seq_boundaries)} packed samples with boundaries at {seq_boundaries}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sliced batch dictionary.
|
||||||
|
"""
|
||||||
# Process keys that need to be sliced
|
# Process keys that need to be sliced
|
||||||
for key in ["input_ids", "attention_mask", "labels"]:
|
for key in ["input_ids", "attention_mask", "labels"]:
|
||||||
if key in batch:
|
if key in batch:
|
||||||
seq_len = batch[key].shape[1]
|
seq_len = batch[key].shape[1]
|
||||||
slice_size = seq_len // world_size
|
slice_size = seq_len // self.world_size
|
||||||
start_idx = rank * slice_size
|
start_idx = self.rank * slice_size
|
||||||
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
|
end_idx = (
|
||||||
|
start_idx + slice_size
|
||||||
LOG.info(
|
if self.rank < self.world_size - 1
|
||||||
f"GPU {rank}: Slicing {key} from {start_idx} to {end_idx} (total len: {seq_len})"
|
else seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.debug_level == "debug" and key == "input_ids":
|
if key == "input_ids":
|
||||||
# Log portions of the input to verify correct slicing
|
# Before slicing
|
||||||
for i in range(
|
non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
|
||||||
min(2, batch[key].shape[0])
|
logger.info(
|
||||||
): # Look at up to 2 sequences
|
f"GPU {self.rank}: Total sequence length: {seq_len}, "
|
||||||
# Sample the beginning, middle and end of the sequence before slicing
|
f"Non-padding tokens: {non_pad_tokens_total}"
|
||||||
start_sample = batch[key][i, 0:5].tolist()
|
|
||||||
mid_sample = batch[key][
|
|
||||||
i, seq_len // 2 : seq_len // 2 + 5
|
|
||||||
].tolist()
|
|
||||||
end_sample = batch[key][i, -5:].tolist()
|
|
||||||
LOG.info(
|
|
||||||
f"GPU {rank}, Seq {i} before slicing: start={start_sample}, mid={mid_sample}, end={end_sample}"
|
|
||||||
)
|
|
||||||
|
|
||||||
batch[key] = batch[key][:, start_idx:end_idx]
|
|
||||||
|
|
||||||
if self.debug_level == "debug" and key == "input_ids":
|
|
||||||
# Log after slicing to verify
|
|
||||||
for i in range(min(2, batch[key].shape[0])):
|
|
||||||
sliced_sample = batch[key][i, 0:5].tolist()
|
|
||||||
sliced_end = batch[key][i, -5:].tolist()
|
|
||||||
LOG.info(
|
|
||||||
f"GPU {rank}, Seq {i} after slicing: start={sliced_sample}, end={sliced_end}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle position_ids specially if present (important for packed sequences)
|
|
||||||
if "position_ids" in batch:
|
|
||||||
# For position_ids, we need to adjust them after slicing
|
|
||||||
# Each position_id should be relative to its slice
|
|
||||||
pos_ids = batch["position_ids"]
|
|
||||||
seq_len = pos_ids.shape[1]
|
|
||||||
slice_size = seq_len // world_size
|
|
||||||
start_idx = rank * slice_size
|
|
||||||
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
|
|
||||||
|
|
||||||
# Find boundaries before slicing
|
|
||||||
if self.debug_level == "debug":
|
|
||||||
full_boundaries = find_sample_boundaries(pos_ids)
|
|
||||||
|
|
||||||
# Check for boundaries that fall near slice edges
|
|
||||||
for i, boundaries in enumerate(full_boundaries):
|
|
||||||
problem_boundaries = check_for_boundary_splits(
|
|
||||||
boundaries, start_idx, end_idx
|
|
||||||
)
|
)
|
||||||
if problem_boundaries:
|
logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
|
||||||
LOG.warning(
|
|
||||||
f"GPU {rank}: Sequence {i} has sample boundaries near slice edges: {problem_boundaries}"
|
|
||||||
)
|
|
||||||
|
|
||||||
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
|
# After slicing
|
||||||
|
non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
|
||||||
# Find boundaries after slicing to verify correct transfer
|
logger.info(
|
||||||
if self.debug_level == "debug":
|
f"GPU {self.rank}: Slice {start_idx}-{end_idx}, "
|
||||||
sliced_boundaries = find_sample_boundaries(batch["position_ids"])
|
f"Non-padding tokens in slice: {non_pad_tokens_slice}"
|
||||||
for i, boundaries in enumerate(sliced_boundaries):
|
|
||||||
LOG.info(
|
|
||||||
f"GPU {rank}: After slicing, sequence {i} has boundaries at {boundaries}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Adjust position_ids to be relative to the start of this slice
|
dist.barrier()
|
||||||
# Only subtract if not the first GPU in the group
|
|
||||||
if rank > 0:
|
|
||||||
# Find boundaries between samples in the position_ids
|
|
||||||
# This preserves the sample packing structure
|
|
||||||
old_pos_ids = batch["position_ids"].clone()
|
|
||||||
batch["position_ids"] = adjust_position_ids_for_slice(
|
|
||||||
batch["position_ids"], start_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.debug_level == "debug":
|
|
||||||
# Compare before and after adjustment
|
|
||||||
for i in range(min(2, old_pos_ids.shape[0])):
|
|
||||||
before = old_pos_ids[i, 0:10].tolist()
|
|
||||||
after = batch["position_ids"][i, 0:10].tolist()
|
|
||||||
LOG.info(
|
|
||||||
f"GPU {rank}, Seq {i} position_ids adjustment: before={before}, after={after}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add gradient norm tracking for debugging
|
|
||||||
if self.debug_level == "debug":
|
|
||||||
# Attach hook to track gradient norms during backward pass
|
|
||||||
def hook_fn(grad):
|
|
||||||
norm = grad.norm().item()
|
|
||||||
LOG.info(f"GPU {rank}: Gradient norm = {norm:.4f}")
|
|
||||||
# Record any abnormally high gradients
|
|
||||||
if norm > 10.0:
|
|
||||||
LOG.warning(f"GPU {rank}: High gradient norm detected: {norm:.4f}")
|
|
||||||
return grad
|
|
||||||
|
|
||||||
# Apply hook to input_ids embeddings if it goes through backward pass
|
|
||||||
if "input_ids" in batch and batch["input_ids"].requires_grad:
|
|
||||||
batch["input_ids"].register_hook(hook_fn)
|
|
||||||
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class V2SequenceParallelPackedDataCollator(V2BatchSamplerDataCollatorForSeq2Seq):
|
|
||||||
"""
|
|
||||||
Data collator for sequence parallelism with V2 sample packing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
debug_level: str = "debug" # Can be "debug" for more verbose output
|
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
|
||||||
# Implementation similar to SequenceParallelPackedDataCollator with V2 base
|
|
||||||
# First, use the parent collator to handle sample packing and padding
|
|
||||||
batch = super().__call__(features, return_tensors=return_tensors)
|
|
||||||
|
|
||||||
sp_group = get_ring_attn_group()
|
|
||||||
if sp_group is None:
|
|
||||||
return batch # Not using sequence parallelism
|
|
||||||
|
|
||||||
# Get information about our position in the SP group
|
|
||||||
rank = dist.get_rank(group=sp_group)
|
|
||||||
world_size = dist.get_world_size(group=sp_group)
|
|
||||||
|
|
||||||
# Enable debug level if requested
|
|
||||||
if self.debug_level == "debug":
|
|
||||||
original_shapes = {
|
|
||||||
k: v.shape if hasattr(v, "shape") else None for k, v in batch.items()
|
|
||||||
}
|
|
||||||
LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}")
|
|
||||||
|
|
||||||
if "position_ids" in batch:
|
|
||||||
# Find and log sample boundaries before slicing
|
|
||||||
boundaries = find_sample_boundaries(batch["position_ids"])
|
|
||||||
for i, seq_boundaries in enumerate(boundaries):
|
|
||||||
LOG.info(
|
|
||||||
f"GPU {rank}: Sequence {i} has {len(seq_boundaries)} packed samples with boundaries at {seq_boundaries}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process keys that need to be sliced
|
|
||||||
for key in ["input_ids", "attention_mask", "labels"]:
|
|
||||||
if key in batch:
|
|
||||||
seq_len = batch[key].shape[1]
|
|
||||||
slice_size = seq_len // world_size
|
|
||||||
start_idx = rank * slice_size
|
|
||||||
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
|
|
||||||
|
|
||||||
if self.debug_level == "debug" and key == "input_ids":
|
|
||||||
# Log portions of the input to verify correct slicing
|
|
||||||
for i in range(
|
|
||||||
min(2, batch[key].shape[0])
|
|
||||||
): # Look at up to 2 sequences
|
|
||||||
# Sample the beginning, middle and end of the sequence before slicing
|
|
||||||
start_sample = batch[key][i, 0:5].tolist()
|
|
||||||
mid_sample = batch[key][
|
|
||||||
i, seq_len // 2 : seq_len // 2 + 5
|
|
||||||
].tolist()
|
|
||||||
end_sample = batch[key][i, -5:].tolist()
|
|
||||||
LOG.info(
|
|
||||||
f"GPU {rank}, Seq {i} before slicing: start={start_sample}, mid={mid_sample}, end={end_sample}"
|
|
||||||
)
|
|
||||||
|
|
||||||
batch[key] = batch[key][:, start_idx:end_idx]
|
|
||||||
|
|
||||||
# Handle position_ids specially (same as in SequenceParallelPackedDataCollator)
|
|
||||||
if "position_ids" in batch:
|
|
||||||
# Implementation identical to the one in SequenceParallelPackedDataCollator
|
|
||||||
pos_ids = batch["position_ids"]
|
|
||||||
seq_len = pos_ids.shape[1]
|
|
||||||
slice_size = seq_len // world_size
|
|
||||||
start_idx = rank * slice_size
|
|
||||||
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
|
|
||||||
|
|
||||||
# Find boundaries before slicing
|
|
||||||
if self.debug_level == "debug":
|
|
||||||
full_boundaries = find_sample_boundaries(pos_ids)
|
|
||||||
|
|
||||||
# Check for boundaries that fall near slice edges
|
|
||||||
for i, boundaries in enumerate(full_boundaries):
|
|
||||||
problem_boundaries = check_for_boundary_splits(
|
|
||||||
boundaries, start_idx, end_idx
|
|
||||||
)
|
|
||||||
if problem_boundaries:
|
|
||||||
LOG.warning(
|
|
||||||
f"GPU {rank}: Sequence {i} has sample boundaries near slice edges: {problem_boundaries}"
|
|
||||||
)
|
|
||||||
|
|
||||||
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
|
|
||||||
|
|
||||||
# Adjust position_ids to be relative to the start of this slice
|
|
||||||
if rank > 0:
|
|
||||||
batch["position_ids"] = adjust_position_ids_for_slice(
|
|
||||||
batch["position_ids"], start_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SequenceParallelDataCollator(DataCollatorForSeq2Seq):
|
|
||||||
"""
|
|
||||||
Data collator for sequence parallelism without sample packing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
debug_level: str = "debug" # Can be "debug" for more verbose output
|
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
|
||||||
# First, use the parent collator to pad everything correctly
|
|
||||||
batch = super().__call__(features, return_tensors=return_tensors)
|
|
||||||
|
|
||||||
sp_group = get_ring_attn_group()
|
|
||||||
if sp_group is None:
|
|
||||||
return batch # Not using sequence parallelism
|
|
||||||
|
|
||||||
# Get information about our position in the SP group
|
|
||||||
rank = dist.get_rank(group=sp_group)
|
|
||||||
world_size = dist.get_world_size(group=sp_group)
|
|
||||||
|
|
||||||
if self.debug_level == "debug":
|
|
||||||
original_shapes = {
|
|
||||||
k: v.shape if hasattr(v, "shape") else None for k, v in batch.items()
|
|
||||||
}
|
|
||||||
LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}")
|
|
||||||
|
|
||||||
# Process keys that need to be sliced
|
|
||||||
for key in ["input_ids", "attention_mask", "labels"]:
|
|
||||||
if key in batch:
|
|
||||||
seq_len = batch[key].shape[1]
|
|
||||||
slice_size = seq_len // world_size
|
|
||||||
start_idx = rank * slice_size
|
|
||||||
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
f"GPU {rank}: Slicing {key} from {start_idx} to {end_idx} (total len: {seq_len})"
|
|
||||||
)
|
|
||||||
batch[key] = batch[key][:, start_idx:end_idx]
|
batch[key] = batch[key][:, start_idx:end_idx]
|
||||||
|
|
||||||
# Handle position_ids if present
|
# Handle position_ids if present
|
||||||
if "position_ids" in batch:
|
if "position_ids" in batch:
|
||||||
pos_ids = batch["position_ids"]
|
pos_ids = batch["position_ids"]
|
||||||
seq_len = pos_ids.shape[1]
|
seq_len = pos_ids.shape[1]
|
||||||
slice_size = seq_len // world_size
|
slice_size = seq_len // self.world_size
|
||||||
start_idx = rank * slice_size
|
start_idx = self.rank * slice_size
|
||||||
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
|
end_idx = (
|
||||||
|
start_idx + slice_size if self.rank < self.world_size - 1 else seq_len
|
||||||
|
)
|
||||||
|
|
||||||
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
|
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
|
||||||
|
|
||||||
# For non-packed sequences, we can simply subtract start_idx from all position_ids
|
# Adjust position_ids to be relative to the slice start
|
||||||
if rank > 0:
|
if self.rank > 0:
|
||||||
batch["position_ids"] -= start_idx
|
batch["position_ids"] = adjust_position_ids_for_slice(
|
||||||
|
batch["position_ids"], start_idx
|
||||||
|
)
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SequenceParallelPackedDataCollator(
|
||||||
|
SequenceParallelMixin, BatchSamplerDataCollatorForSeq2Seq
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Data collator for sequence parallelism with sample packing. Combines multiple
|
||||||
|
samples into a packed sequence, then slices it for each GPU.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, features, return_tensors=None):
|
||||||
|
# Use the parent collator to handle sample packing and padding
|
||||||
|
batch = super().__call__(features, return_tensors=return_tensors)
|
||||||
|
return self.apply_sequence_parallelism(batch)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class V2SequenceParallelPackedDataCollator(
|
||||||
|
SequenceParallelMixin, V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Data collator for sequence parallelism with V2 sample packing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, features, return_tensors=None):
|
||||||
|
# Use the parent collator to handle sample packing and padding
|
||||||
|
batch = super().__call__(features, return_tensors=return_tensors)
|
||||||
|
return self.apply_sequence_parallelism(batch)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SequenceParallelDataCollator(SequenceParallelMixin, DataCollatorForSeq2Seq):
|
||||||
|
"""
|
||||||
|
Data collator for sequence parallelism without sample packing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, features, return_tensors=None):
|
||||||
|
# Use the parent collator to pad everything correctly
|
||||||
|
batch = super().__call__(features, return_tensors=return_tensors)
|
||||||
|
return self.apply_sequence_parallelism(batch)
|
||||||
|
|||||||
@@ -67,7 +67,12 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap
|
|||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MULTIMODEL_AUTO_MODEL_MAPPING = {
|
||||||
|
"llava": LlavaForConditionalGeneration,
|
||||||
|
"mllama": MllamaForConditionalGeneration,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# copied from accelerator.FullyShardedDataParallelPlugin
|
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||||
@@ -476,7 +481,7 @@ class ModelLoader:
|
|||||||
else:
|
else:
|
||||||
self.text_model_config = self.model_config
|
self.text_model_config = self.model_config
|
||||||
|
|
||||||
self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
|
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
|
||||||
def apply_patches(self) -> None:
|
def apply_patches(self) -> None:
|
||||||
# load any patches from plugins
|
# load any patches from plugins
|
||||||
@@ -612,7 +617,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora()
|
patch_self_attn_lora()
|
||||||
|
|
||||||
def patch_llama_derived_model(self) -> None:
|
def patch_llama_derived_model(self):
|
||||||
"""Modify all llama derived models in one block"""
|
"""Modify all llama derived models in one block"""
|
||||||
self.patch_loss_llama()
|
self.patch_loss_llama()
|
||||||
|
|
||||||
@@ -662,25 +667,16 @@ class ModelLoader:
|
|||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_auto_model_loader(self) -> None:
|
def set_auto_model_loader(self):
|
||||||
"""set self.AutoModelLoader
|
"""
|
||||||
- default value: AutoModelForCausalLM (set at __init__)
|
Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
|
||||||
- when using a multi modality model, self.AutoModelLoader should
|
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
||||||
be set according to model type of the model
|
should be set according to the type of the model.
|
||||||
"""
|
"""
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
if self.model_config.model_type == "llava":
|
self.auto_model_loader = MULTIMODEL_AUTO_MODEL_MAPPING.get(
|
||||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
self.model_config.model_type, AutoModelForVision2Seq
|
||||||
LlavaForConditionalGeneration
|
)
|
||||||
)
|
|
||||||
elif self.model_config.model_type == "mllama":
|
|
||||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
|
||||||
MllamaForConditionalGeneration
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.AutoModelLoader = (
|
|
||||||
AutoModelForVision2Seq # pylint: disable=invalid-name
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_device_map_config(self) -> None:
|
def set_device_map_config(self) -> None:
|
||||||
device_map = self.cfg.device_map
|
device_map = self.cfg.device_map
|
||||||
@@ -704,7 +700,7 @@ class ModelLoader:
|
|||||||
from accelerate import infer_auto_device_map
|
from accelerate import infer_auto_device_map
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model_canvas = self.AutoModelLoader.from_config(
|
model_canvas = self.auto_model_loader.from_config(
|
||||||
self.model_config,
|
self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
)
|
)
|
||||||
@@ -925,11 +921,27 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
|
||||||
self.base_model,
|
# Load model with random initialization if specified
|
||||||
config=self.model_config,
|
if self.cfg.random_init:
|
||||||
**self.model_kwargs,
|
# AutoModel classes support the from_config method
|
||||||
)
|
if self.auto_model_loader in [
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForVision2Seq,
|
||||||
|
]:
|
||||||
|
self.model = self.auto_model_loader.from_config(
|
||||||
|
config=self.model_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = self.auto_model_loader(
|
||||||
|
config=self.model_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
|
self.base_model,
|
||||||
|
config=self.model_config,
|
||||||
|
**self.model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO (MengqingCao) split these patches seperately
|
# TODO (MengqingCao) split these patches seperately
|
||||||
if self.cfg.flash_attention and not self.inference:
|
if self.cfg.flash_attention and not self.inference:
|
||||||
@@ -967,7 +979,7 @@ class ModelLoader:
|
|||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1000,7 +1012,7 @@ class ModelLoader:
|
|||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1020,7 +1032,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1316,7 +1328,7 @@ def load_model(
|
|||||||
"""
|
"""
|
||||||
Load a model for a given configuration and tokenizer.
|
Load a model for a given configuration and tokenizer.
|
||||||
"""
|
"""
|
||||||
loader = ModelLoader(
|
model_loader = ModelLoader(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
@@ -1324,7 +1336,7 @@ def load_model(
|
|||||||
reference_model=reference_model,
|
reference_model=reference_model,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return loader.load_model()
|
return model_loader.load_model()
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(model, cfg, adapter, inference=False):
|
||||||
|
|||||||
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
|
* cfg.sequence_parallel_size
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
||||||
@@ -473,7 +474,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||||
# on the agreed on value for sample_packing_eff_est
|
# on the agreed on value for sample_packing_eff_est
|
||||||
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
total_num_steps = int(
|
||||||
|
math.floor(
|
||||||
|
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_size
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||||
@@ -494,7 +499,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(
|
||||||
|
len(train_dataset)
|
||||||
|
* cfg.num_epochs
|
||||||
|
* cfg.sequence_parallel_size
|
||||||
|
/ cfg.batch_size
|
||||||
|
)
|
||||||
)
|
)
|
||||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
||||||
return total_num_steps
|
return total_num_steps
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
|||||||
from axolotl.utils.collators.sequence_parallel import (
|
from axolotl.utils.collators.sequence_parallel import (
|
||||||
adjust_position_ids_for_slice,
|
adjust_position_ids_for_slice,
|
||||||
check_for_boundary_splits,
|
check_for_boundary_splits,
|
||||||
find_sample_boundaries,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -30,24 +29,6 @@ def partial_state():
|
|||||||
class TestSequenceParallelHelpers:
|
class TestSequenceParallelHelpers:
|
||||||
"""Test helper functions used in sequence parallelism."""
|
"""Test helper functions used in sequence parallelism."""
|
||||||
|
|
||||||
def test_find_sample_boundaries(self):
|
|
||||||
"""Test detection of boundaries in position_ids."""
|
|
||||||
# Create sample position_ids with multiple sequences
|
|
||||||
position_ids = torch.tensor(
|
|
||||||
[
|
|
||||||
# First sequence with 2 samples (boundary at index 5)
|
|
||||||
[0, 1, 2, 3, 4, 0, 1, 2, 3],
|
|
||||||
# Second sequence with 3 samples (boundaries at 3 and 7)
|
|
||||||
[0, 1, 2, 0, 1, 2, 3, 0, 1],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
boundaries = find_sample_boundaries(position_ids)
|
|
||||||
|
|
||||||
assert len(boundaries) == 2
|
|
||||||
assert boundaries[0] == [5] # First sequence has boundary at index 5
|
|
||||||
assert boundaries[1] == [3, 7] # Second sequence has boundaries at 3 and 7
|
|
||||||
|
|
||||||
def test_adjust_position_ids_for_slice(self, partial_state):
|
def test_adjust_position_ids_for_slice(self, partial_state):
|
||||||
"""Test position_ids adjustment for sequence slices."""
|
"""Test position_ids adjustment for sequence slices."""
|
||||||
# Create sample position_ids with multiple sequences
|
# Create sample position_ids with multiple sequences
|
||||||
|
|||||||
Reference in New Issue
Block a user