This commit is contained in:
Dan Saunders
2025-03-10 21:18:04 +00:00
parent b44a207248
commit 4190ad0647
7 changed files with 187 additions and 432 deletions

View File

@@ -32,6 +32,9 @@ tokenizer_legacy:
resize_token_embeddings_to_32x: resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink. # Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings: shrink_embeddings:
# Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes.
random_init:
# (Internal use only) # (Internal use only)
# Used to identify which architecture the model is based on # Used to identify which architecture the model is based on

View File

@@ -871,10 +871,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator: Type[ collator: Type[
Union[ Union[
V2BatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq,
V2SequenceParallelPackedDataCollator,
SequenceParallelPackedDataCollator,
BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,
DataCollatorWithFlattening, DataCollatorWithFlattening,
RewardDataCollatorWithPadding, RewardDataCollatorWithPadding,
SequenceParallelDataCollator,
] ]
] ]
collator_args = [self.tokenizer] collator_args = [self.tokenizer]

View File

@@ -412,8 +412,21 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
if self.args.curriculum_sampling: if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset) sampler = SequentialSampler(self.train_dataset)
else: else:
generator = None
if self.args.sequence_parallel_size > 1:
generator = torch.Generator()
generator.manual_seed(self.args.getattr("seed", 0))
sampler = RandomSampler(self.train_dataset) sampler = RandomSampler(self.train_dataset)
# if dist.get_rank() == 0:
# import ipdb; ipdb.set_trace()
# dist.barrier()
# if dist.get_rank() == 1:
# import ipdb; ipdb.set_trace()
# dist.barrier()
return MultipackBatchSampler( return MultipackBatchSampler(
sampler, sampler,
lengths=get_dataset_lengths(self.train_dataset), lengths=get_dataset_lengths(self.train_dataset),
@@ -426,7 +439,14 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
) )
if self.args.curriculum_sampling: if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset) return SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
sampler = super()._get_train_sampler()
if self.args.sequence_parallel_size > 1:
generator = torch.Generator()
generator.manual_seed(self.args.getattr("seed", 0))
sampler.generator = generator
return sampler
def _get_eval_sampler( def _get_eval_sampler(
self, eval_dataset: Dataset self, eval_dataset: Dataset
@@ -478,6 +498,12 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker dataloader_params["worker_init_fn"] = seed_worker
if dist.get_rank() == 0:
import ipdb
ipdb.set_trace()
dist.barrier()
self.accelerator.even_batches = False self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader( return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params) DataLoader(train_dataset, **dataloader_params)
@@ -805,60 +831,36 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
self, self,
model: nn.Module, model: nn.Module,
inputs: dict[str, torch.Tensor | Any], inputs: dict[str, torch.Tensor | Any],
num_items_in_batch=None, num_items_in_batch: int = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Perform a training step on a batch of inputs. Perform a training step on a batch of inputs.
""" """
if self.args.sequence_parallel_size > 1: if self.args.sequence_parallel_size > 1:
# At this point, inputs should already be partitioned by the sequence parallel data collator # At this point, inputs should already be partitioned by the sequence
# We'll just log some information about the partitioned data # parallel data collator
batch_size = inputs["input_ids"].shape[0] batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1] seq_len = inputs["input_ids"].shape[1]
# Get rank and SP information # Get rank and SP information
sp_group = get_ring_attn_group() sp_group = get_ring_attn_group()
rank = dist.get_rank()
sp_rank = dist.get_rank(group=sp_group) if sp_group else rank
world_size = ( world_size = (
dist.get_world_size(group=sp_group) dist.get_world_size(group=sp_group)
if sp_group if sp_group
else dist.get_world_size() else dist.get_world_size()
) )
# Sample tokens from our slice to verify partitioning
sample_start = (
inputs["input_ids"][0, :5].tolist()
if seq_len >= 5
else inputs["input_ids"][0, :].tolist()
)
sample_end = (
inputs["input_ids"][0, -5:].tolist()
if seq_len >= 5
else inputs["input_ids"][0, :].tolist()
)
LOG.info(
f"GPU {rank} (SP rank {sp_rank}) | Step {self.state.global_step} | "
f"Slice shape: batch_size={batch_size}, seq_len={seq_len} | "
f"Sample start: {sample_start}, end: {sample_end}"
)
# Calculate the full sequence length across all GPUs in this SP group # Calculate the full sequence length across all GPUs in this SP group
full_seq_len = seq_len * world_size total_seq_len = seq_len * world_size
# Pass the partitioned sequence information to ring flash attention # Pass the partitioned sequence information to ring flash attention
self._update_ring_flash_attn_params([seq_len] * batch_size, full_seq_len) self._update_ring_flash_attn_params(
packed_seq_lens=[seq_len] * batch_size, total_seq_len=total_seq_len
)
# Get the loss from the parent implementation # Get the loss from the parent implementation
loss = super().training_step(model, inputs, num_items_in_batch) loss = super().training_step(model, inputs, num_items_in_batch)
if self.args.sequence_parallel_size > 1:
rank = dist.get_rank()
LOG.info(
f"GPU {rank} | Step {self.state.global_step} | Loss: {loss.item()}"
)
return loss return loss
def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len): def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):

View File

@@ -1,12 +1,11 @@
"""Module for sequence parallelism data collators.""" """Module for sequence parallelism data collators."""
import logging
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
from axolotl.utils.collators.batching import ( from axolotl.utils.collators.batching import (
BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq,
@@ -14,32 +13,12 @@ from axolotl.utils.collators.batching import (
V2BatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq,
) )
configure_logging() logger = logging.getLogger(__name__)
LOG = get_logger(__name__)
def find_sample_boundaries(position_ids): def adjust_position_ids_for_slice(
""" position_ids: list | torch.Tensor, start_idx: int
Find the boundaries between packed samples in a sequence by looking for ) -> torch.Tensor:
where position_ids decrease.
Returns:
List of boundary indices for each sequence in the batch
"""
batch_boundaries = []
for i in range(position_ids.shape[0]):
seq = position_ids[i]
boundaries = []
for j in range(1, len(seq)):
if seq[j] < seq[j - 1]:
boundaries.append(j)
batch_boundaries.append(boundaries)
return batch_boundaries
def adjust_position_ids_for_slice(position_ids, start_idx):
""" """
Adjust position IDs for a sliced sequence to maintain proper relative positions. Adjust position IDs for a sliced sequence to maintain proper relative positions.
This handles the case where position IDs might not be contiguous due to sample packing. This handles the case where position IDs might not be contiguous due to sample packing.
@@ -64,370 +43,135 @@ def adjust_position_ids_for_slice(position_ids, start_idx):
if seq[j] < seq[j - 1]: if seq[j] < seq[j - 1]:
boundaries.append(j) boundaries.append(j)
# Debug: log the found boundaries
LOG.debug(f"Sequence {i}: Found sample boundaries at positions {boundaries}")
# No need to adjust if there are no boundaries or this is a single sample # No need to adjust if there are no boundaries or this is a single sample
if not boundaries: if not boundaries:
old_values = seq[0:5].tolist() # Sample of original values
adjusted_pos_ids[i] = seq - start_idx adjusted_pos_ids[i] = seq - start_idx
new_values = adjusted_pos_ids[i, 0:5].tolist() # Sample of new values
LOG.debug(
f"Sequence {i}: No boundaries, subtracting {start_idx} uniformly. Example values before: {old_values}, after: {new_values}"
)
continue continue
# Adjust each segment separately # Adjust each segment separately
prev_boundary = 0 prev_boundary = 0
for boundary_idx, boundary in enumerate(boundaries): for boundary in boundaries:
segment = seq[prev_boundary:boundary]
old_values = segment[
0 : min(5, len(segment))
].tolist() # Sample of original values
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
new_values = adjusted_pos_ids[
i, prev_boundary : min(prev_boundary + 5, boundary)
].tolist() # Sample of new values
LOG.debug(
f"Sequence {i}, Segment {boundary_idx}: Adjusting positions {prev_boundary}-{boundary-1}. Example values before: {old_values}, after: {new_values}"
)
prev_boundary = boundary prev_boundary = boundary
# Last segment # Last segment
segment = seq[prev_boundary:]
old_values = segment[
0 : min(5, len(segment))
].tolist() # Sample of original values
adjusted_pos_ids[i, prev_boundary:] -= start_idx adjusted_pos_ids[i, prev_boundary:] -= start_idx
new_values = adjusted_pos_ids[
i, prev_boundary : min(prev_boundary + 5, len(seq))
].tolist() # Sample of new values
LOG.debug(
f"Sequence {i}, Last segment: Adjusting positions {prev_boundary}-end. Example values before: {old_values}, after: {new_values}"
)
return adjusted_pos_ids return adjusted_pos_ids
def check_for_boundary_splits(boundaries, slice_start, slice_end): class SequenceParallelMixin:
""" """
Check if any sample boundaries fall near the edge of a sequence slice. Mixin to add sequence parallelism slicing to data collators.
These edge cases could cause issues with gradient computation.
Args:
boundaries: List of indices where sample boundaries occur
slice_start: Start index of this GPU's slice
slice_end: End index of this GPU's slice
Returns:
List of potentially problematic boundaries
"""
# Consider a boundary "near" an edge if it's within 5 tokens
buffer_size = 5
problem_boundaries = []
for boundary in boundaries:
# Check if boundary is near the start of the slice
if slice_start <= boundary < slice_start + buffer_size:
problem_boundaries.append((boundary, "start", boundary - slice_start))
# Check if boundary is near the end of the slice
elif slice_end - buffer_size <= boundary < slice_end:
problem_boundaries.append((boundary, "end", slice_end - boundary))
return problem_boundaries
@dataclass
class SequenceParallelPackedDataCollator(BatchSamplerDataCollatorForSeq2Seq):
"""
Data collator for sequence parallelism with sample packing.
Combines multiple samples into a packed sequence, then slices it for each GPU.
""" """
debug_level: str = "debug" # Can be "debug" for more verbose output def __post_init__(self):
def __call__(self, features, return_tensors=None):
# First, use the parent collator to handle sample packing and padding
batch = super().__call__(features, return_tensors=return_tensors)
sp_group = get_ring_attn_group()
if sp_group is None:
return batch # Not using sequence parallelism
# Get information about our position in the SP group # Get information about our position in the SP group
rank = dist.get_rank(group=sp_group) sp_group = get_ring_attn_group()
world_size = dist.get_world_size(group=sp_group) self.rank = dist.get_rank(group=sp_group)
self.world_size = dist.get_world_size(group=sp_group)
# Enable debug level if requested def apply_sequence_parallelism(
if self.debug_level == "debug": self, batch: dict[str, torch.Tensor]
original_shapes = { ) -> torch.Tensor:
k: v.shape if hasattr(v, "shape") else None for k, v in batch.items() """
} Apply sequence parallelism slicing to a batch.
LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}")
if "position_ids" in batch: Args:
# Find and log sample boundaries before slicing batch: Batch dictionary from parent collator.
boundaries = find_sample_boundaries(batch["position_ids"])
for i, seq_boundaries in enumerate(boundaries):
LOG.info(
f"GPU {rank}: Sequence {i} has {len(seq_boundaries)} packed samples with boundaries at {seq_boundaries}"
)
Returns:
Sliced batch dictionary.
"""
# Process keys that need to be sliced # Process keys that need to be sliced
for key in ["input_ids", "attention_mask", "labels"]: for key in ["input_ids", "attention_mask", "labels"]:
if key in batch: if key in batch:
seq_len = batch[key].shape[1] seq_len = batch[key].shape[1]
slice_size = seq_len // world_size slice_size = seq_len // self.world_size
start_idx = rank * slice_size start_idx = self.rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len end_idx = (
start_idx + slice_size
LOG.info( if self.rank < self.world_size - 1
f"GPU {rank}: Slicing {key} from {start_idx} to {end_idx} (total len: {seq_len})" else seq_len
) )
if self.debug_level == "debug" and key == "input_ids": if key == "input_ids":
# Log portions of the input to verify correct slicing # Before slicing
for i in range( non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
min(2, batch[key].shape[0]) logger.info(
): # Look at up to 2 sequences f"GPU {self.rank}: Total sequence length: {seq_len}, "
# Sample the beginning, middle and end of the sequence before slicing f"Non-padding tokens: {non_pad_tokens_total}"
start_sample = batch[key][i, 0:5].tolist()
mid_sample = batch[key][
i, seq_len // 2 : seq_len // 2 + 5
].tolist()
end_sample = batch[key][i, -5:].tolist()
LOG.info(
f"GPU {rank}, Seq {i} before slicing: start={start_sample}, mid={mid_sample}, end={end_sample}"
)
batch[key] = batch[key][:, start_idx:end_idx]
if self.debug_level == "debug" and key == "input_ids":
# Log after slicing to verify
for i in range(min(2, batch[key].shape[0])):
sliced_sample = batch[key][i, 0:5].tolist()
sliced_end = batch[key][i, -5:].tolist()
LOG.info(
f"GPU {rank}, Seq {i} after slicing: start={sliced_sample}, end={sliced_end}"
)
# Handle position_ids specially if present (important for packed sequences)
if "position_ids" in batch:
# For position_ids, we need to adjust them after slicing
# Each position_id should be relative to its slice
pos_ids = batch["position_ids"]
seq_len = pos_ids.shape[1]
slice_size = seq_len // world_size
start_idx = rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
# Find boundaries before slicing
if self.debug_level == "debug":
full_boundaries = find_sample_boundaries(pos_ids)
# Check for boundaries that fall near slice edges
for i, boundaries in enumerate(full_boundaries):
problem_boundaries = check_for_boundary_splits(
boundaries, start_idx, end_idx
) )
if problem_boundaries: logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
LOG.warning(
f"GPU {rank}: Sequence {i} has sample boundaries near slice edges: {problem_boundaries}"
)
batch["position_ids"] = pos_ids[:, start_idx:end_idx] # After slicing
non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
# Find boundaries after slicing to verify correct transfer logger.info(
if self.debug_level == "debug": f"GPU {self.rank}: Slice {start_idx}-{end_idx}, "
sliced_boundaries = find_sample_boundaries(batch["position_ids"]) f"Non-padding tokens in slice: {non_pad_tokens_slice}"
for i, boundaries in enumerate(sliced_boundaries):
LOG.info(
f"GPU {rank}: After slicing, sequence {i} has boundaries at {boundaries}"
) )
# Adjust position_ids to be relative to the start of this slice dist.barrier()
# Only subtract if not the first GPU in the group
if rank > 0:
# Find boundaries between samples in the position_ids
# This preserves the sample packing structure
old_pos_ids = batch["position_ids"].clone()
batch["position_ids"] = adjust_position_ids_for_slice(
batch["position_ids"], start_idx
)
if self.debug_level == "debug":
# Compare before and after adjustment
for i in range(min(2, old_pos_ids.shape[0])):
before = old_pos_ids[i, 0:10].tolist()
after = batch["position_ids"][i, 0:10].tolist()
LOG.info(
f"GPU {rank}, Seq {i} position_ids adjustment: before={before}, after={after}"
)
# Add gradient norm tracking for debugging
if self.debug_level == "debug":
# Attach hook to track gradient norms during backward pass
def hook_fn(grad):
norm = grad.norm().item()
LOG.info(f"GPU {rank}: Gradient norm = {norm:.4f}")
# Record any abnormally high gradients
if norm > 10.0:
LOG.warning(f"GPU {rank}: High gradient norm detected: {norm:.4f}")
return grad
# Apply hook to input_ids embeddings if it goes through backward pass
if "input_ids" in batch and batch["input_ids"].requires_grad:
batch["input_ids"].register_hook(hook_fn)
return batch
@dataclass
class V2SequenceParallelPackedDataCollator(V2BatchSamplerDataCollatorForSeq2Seq):
"""
Data collator for sequence parallelism with V2 sample packing.
"""
debug_level: str = "debug" # Can be "debug" for more verbose output
def __call__(self, features, return_tensors=None):
# Implementation similar to SequenceParallelPackedDataCollator with V2 base
# First, use the parent collator to handle sample packing and padding
batch = super().__call__(features, return_tensors=return_tensors)
sp_group = get_ring_attn_group()
if sp_group is None:
return batch # Not using sequence parallelism
# Get information about our position in the SP group
rank = dist.get_rank(group=sp_group)
world_size = dist.get_world_size(group=sp_group)
# Enable debug level if requested
if self.debug_level == "debug":
original_shapes = {
k: v.shape if hasattr(v, "shape") else None for k, v in batch.items()
}
LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}")
if "position_ids" in batch:
# Find and log sample boundaries before slicing
boundaries = find_sample_boundaries(batch["position_ids"])
for i, seq_boundaries in enumerate(boundaries):
LOG.info(
f"GPU {rank}: Sequence {i} has {len(seq_boundaries)} packed samples with boundaries at {seq_boundaries}"
)
# Process keys that need to be sliced
for key in ["input_ids", "attention_mask", "labels"]:
if key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // world_size
start_idx = rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
if self.debug_level == "debug" and key == "input_ids":
# Log portions of the input to verify correct slicing
for i in range(
min(2, batch[key].shape[0])
): # Look at up to 2 sequences
# Sample the beginning, middle and end of the sequence before slicing
start_sample = batch[key][i, 0:5].tolist()
mid_sample = batch[key][
i, seq_len // 2 : seq_len // 2 + 5
].tolist()
end_sample = batch[key][i, -5:].tolist()
LOG.info(
f"GPU {rank}, Seq {i} before slicing: start={start_sample}, mid={mid_sample}, end={end_sample}"
)
batch[key] = batch[key][:, start_idx:end_idx]
# Handle position_ids specially (same as in SequenceParallelPackedDataCollator)
if "position_ids" in batch:
# Implementation identical to the one in SequenceParallelPackedDataCollator
pos_ids = batch["position_ids"]
seq_len = pos_ids.shape[1]
slice_size = seq_len // world_size
start_idx = rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
# Find boundaries before slicing
if self.debug_level == "debug":
full_boundaries = find_sample_boundaries(pos_ids)
# Check for boundaries that fall near slice edges
for i, boundaries in enumerate(full_boundaries):
problem_boundaries = check_for_boundary_splits(
boundaries, start_idx, end_idx
)
if problem_boundaries:
LOG.warning(
f"GPU {rank}: Sequence {i} has sample boundaries near slice edges: {problem_boundaries}"
)
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
# Adjust position_ids to be relative to the start of this slice
if rank > 0:
batch["position_ids"] = adjust_position_ids_for_slice(
batch["position_ids"], start_idx
)
return batch
@dataclass
class SequenceParallelDataCollator(DataCollatorForSeq2Seq):
"""
Data collator for sequence parallelism without sample packing.
"""
debug_level: str = "debug" # Can be "debug" for more verbose output
def __call__(self, features, return_tensors=None):
# First, use the parent collator to pad everything correctly
batch = super().__call__(features, return_tensors=return_tensors)
sp_group = get_ring_attn_group()
if sp_group is None:
return batch # Not using sequence parallelism
# Get information about our position in the SP group
rank = dist.get_rank(group=sp_group)
world_size = dist.get_world_size(group=sp_group)
if self.debug_level == "debug":
original_shapes = {
k: v.shape if hasattr(v, "shape") else None for k, v in batch.items()
}
LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}")
# Process keys that need to be sliced
for key in ["input_ids", "attention_mask", "labels"]:
if key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // world_size
start_idx = rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
LOG.info(
f"GPU {rank}: Slicing {key} from {start_idx} to {end_idx} (total len: {seq_len})"
)
batch[key] = batch[key][:, start_idx:end_idx] batch[key] = batch[key][:, start_idx:end_idx]
# Handle position_ids if present # Handle position_ids if present
if "position_ids" in batch: if "position_ids" in batch:
pos_ids = batch["position_ids"] pos_ids = batch["position_ids"]
seq_len = pos_ids.shape[1] seq_len = pos_ids.shape[1]
slice_size = seq_len // world_size slice_size = seq_len // self.world_size
start_idx = rank * slice_size start_idx = self.rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len end_idx = (
start_idx + slice_size if self.rank < self.world_size - 1 else seq_len
)
batch["position_ids"] = pos_ids[:, start_idx:end_idx] batch["position_ids"] = pos_ids[:, start_idx:end_idx]
# For non-packed sequences, we can simply subtract start_idx from all position_ids # Adjust position_ids to be relative to the slice start
if rank > 0: if self.rank > 0:
batch["position_ids"] -= start_idx batch["position_ids"] = adjust_position_ids_for_slice(
batch["position_ids"], start_idx
)
return batch return batch
@dataclass
class SequenceParallelPackedDataCollator(
    SequenceParallelMixin, BatchSamplerDataCollatorForSeq2Seq
):
    """
    Sequence-parallel collator for sample-packed batches.

    The parent collator packs and pads the incoming samples; the resulting
    batch is then passed through `apply_sequence_parallelism` and the sliced
    result is returned.
    """

    def __call__(self, features, return_tensors=None):
        # Sample packing and padding are delegated to the parent collator.
        packed_batch = super().__call__(features, return_tensors=return_tensors)

        # Reduce the packed batch to the portion produced by the mixin.
        sliced_batch = self.apply_sequence_parallelism(packed_batch)
        return sliced_batch
@dataclass
class V2SequenceParallelPackedDataCollator(
    SequenceParallelMixin, V2BatchSamplerDataCollatorForSeq2Seq
):
    """
    Sequence-parallel collator built on the V2 sample-packing collator.
    """

    def __call__(self, features, return_tensors=None):
        # The V2 parent collator packs and pads; the mixin then slices the
        # collated batch before it is returned.
        return self.apply_sequence_parallelism(
            super().__call__(features, return_tensors=return_tensors)
        )
@dataclass
class SequenceParallelDataCollator(SequenceParallelMixin, DataCollatorForSeq2Seq):
    """
    Sequence-parallel collator for batches that are not sample-packed.
    """

    def __call__(self, features, return_tensors=None):
        # Padding is handled entirely by the parent collator.
        padded_batch = super().__call__(features, return_tensors=return_tensors)

        # Hand the padded batch to the mixin for slicing.
        result = self.apply_sequence_parallelism(padded_batch)
        return result

View File

@@ -67,7 +67,12 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap
from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
LOG = logging.getLogger("axolotl") LOG = logging.getLogger(__name__)
MULTIMODEL_AUTO_MODEL_MAPPING = {
"llava": LlavaForConditionalGeneration,
"mllama": MllamaForConditionalGeneration,
}
# copied from accelerator.FullyShardedDataParallelPlugin # copied from accelerator.FullyShardedDataParallelPlugin
@@ -476,7 +481,7 @@ class ModelLoader:
else: else:
self.text_model_config = self.model_config self.text_model_config = self.model_config
self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None: def apply_patches(self) -> None:
# load any patches from plugins # load any patches from plugins
@@ -612,7 +617,7 @@ class ModelLoader:
patch_self_attn_lora() patch_self_attn_lora()
def patch_llama_derived_model(self) -> None: def patch_llama_derived_model(self):
"""Modify all llama derived models in one block""" """Modify all llama derived models in one block"""
self.patch_loss_llama() self.patch_loss_llama()
@@ -662,25 +667,16 @@ class ModelLoader:
"Shifted-sparse attention not currently implemented without flash attention." "Shifted-sparse attention not currently implemented without flash attention."
) )
def set_auto_model_loader(self) -> None: def set_auto_model_loader(self):
"""set self.AutoModelLoader """
- default value: AutoModelForCausalLM (set at __init__) Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
- when using a multi modality model, self.AutoModelLoader should (set at `__init__`). When using a multimodal model, `self.auto_model_loader`
be set according to model type of the model should be set according to the type of the model.
""" """
if self.cfg.is_multimodal: if self.cfg.is_multimodal:
if self.model_config.model_type == "llava": self.auto_model_loader = MULTIMODEL_AUTO_MODEL_MAPPING.get(
self.AutoModelLoader = ( # pylint: disable=invalid-name self.model_config.model_type, AutoModelForVision2Seq
LlavaForConditionalGeneration )
)
elif self.model_config.model_type == "mllama":
self.AutoModelLoader = ( # pylint: disable=invalid-name
MllamaForConditionalGeneration
)
else:
self.AutoModelLoader = (
AutoModelForVision2Seq # pylint: disable=invalid-name
)
def set_device_map_config(self) -> None: def set_device_map_config(self) -> None:
device_map = self.cfg.device_map device_map = self.cfg.device_map
@@ -704,7 +700,7 @@ class ModelLoader:
from accelerate import infer_auto_device_map from accelerate import infer_auto_device_map
with init_empty_weights(): with init_empty_weights():
model_canvas = self.AutoModelLoader.from_config( model_canvas = self.auto_model_loader.from_config(
self.model_config, self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False, trust_remote_code=self.cfg.trust_remote_code or False,
) )
@@ -925,11 +921,27 @@ class ModelLoader:
if self.cfg.is_multimodal: if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
self.base_model, # Load model with random initialization if specified
config=self.model_config, if self.cfg.random_init:
**self.model_kwargs, # AutoModel classes support the from_config method
) if self.auto_model_loader in [
AutoModelForCausalLM,
AutoModelForVision2Seq,
]:
self.model = self.auto_model_loader.from_config(
config=self.model_config,
)
else:
self.model = self.auto_model_loader(
config=self.model_config,
)
else:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
**self.model_kwargs,
)
# TODO (MengqingCao) split these patches separately # TODO (MengqingCao) split these patches separately
if self.cfg.flash_attention and not self.inference: if self.cfg.flash_attention and not self.inference:
@@ -967,7 +979,7 @@ class ModelLoader:
if self.cfg.is_multimodal: if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config self.model_config.text_config = self.text_model_config
if self.cfg.gptq: if self.cfg.gptq:
self.model = self.AutoModelLoader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False, trust_remote_code=self.cfg.trust_remote_code or False,
@@ -1000,7 +1012,7 @@ class ModelLoader:
if self.cfg.gptq: if self.cfg.gptq:
if self.cfg.is_multimodal: if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False, trust_remote_code=self.cfg.trust_remote_code or False,
@@ -1020,7 +1032,7 @@ class ModelLoader:
if self.cfg.is_multimodal: if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False, trust_remote_code=self.cfg.trust_remote_code or False,
@@ -1316,7 +1328,7 @@ def load_model(
""" """
Load a model for a given configuration and tokenizer. Load a model for a given configuration and tokenizer.
""" """
loader = ModelLoader( model_loader = ModelLoader(
cfg, cfg,
tokenizer, tokenizer,
processor=processor, processor=processor,
@@ -1324,7 +1336,7 @@ def load_model(
reference_model=reference_model, reference_model=reference_model,
**kwargs, **kwargs,
) )
return loader.load_model() return model_loader.load_model()
def load_adapter(model, cfg, adapter, inference=False): def load_adapter(model, cfg, adapter, inference=False):

View File

@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1 - 1
) )
* cfg.num_epochs * cfg.num_epochs
* cfg.sequence_parallel_size
) )
LOG.debug( LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}", f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
@@ -473,7 +474,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True) LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends # FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est # on the agreed on value for sample_packing_eff_est
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs)) total_num_steps = int(
math.floor(
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_size
)
)
def calc_sample_packing_eff_est(estimates: List[float]): def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}") LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -494,7 +499,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
) )
else: else:
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(
len(train_dataset)
* cfg.num_epochs
* cfg.sequence_parallel_size
/ cfg.batch_size
)
) )
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True) LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
return total_num_steps return total_num_steps

View File

@@ -14,7 +14,6 @@ with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
from axolotl.utils.collators.sequence_parallel import ( from axolotl.utils.collators.sequence_parallel import (
adjust_position_ids_for_slice, adjust_position_ids_for_slice,
check_for_boundary_splits, check_for_boundary_splits,
find_sample_boundaries,
) )
@@ -30,24 +29,6 @@ def partial_state():
class TestSequenceParallelHelpers: class TestSequenceParallelHelpers:
"""Test helper functions used in sequence parallelism.""" """Test helper functions used in sequence parallelism."""
def test_find_sample_boundaries(self):
"""Test detection of boundaries in position_ids."""
# Create sample position_ids with multiple sequences
position_ids = torch.tensor(
[
# First sequence with 2 samples (boundary at index 5)
[0, 1, 2, 3, 4, 0, 1, 2, 3],
# Second sequence with 3 samples (boundaries at 3 and 7)
[0, 1, 2, 0, 1, 2, 3, 0, 1],
]
)
boundaries = find_sample_boundaries(position_ids)
assert len(boundaries) == 2
assert boundaries[0] == [5] # First sequence has boundary at index 5
assert boundaries[1] == [3, 7] # Second sequence has boundaries at 3 and 7
def test_adjust_position_ids_for_slice(self, partial_state): def test_adjust_position_ids_for_slice(self, partial_state):
"""Test position_ids adjustment for sequence slices.""" """Test position_ids adjustment for sequence slices."""
# Create sample position_ids with multiple sequences # Create sample position_ids with multiple sequences