diff --git a/docs/config.qmd b/docs/config.qmd index ea7ea2293..73aada105 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -32,6 +32,9 @@ tokenizer_legacy: resize_token_embeddings_to_32x: # Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink. shrink_embeddings: +# Whether to load the model with randomly initialized weights. Useful for +# pre-training a model from scratch or debugging purposes. +random_init: # (Internal use only) # Used to identify which the model is based on diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 2e2883c89..4fe065671 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -871,10 +871,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator: Type[ Union[ V2BatchSamplerDataCollatorForSeq2Seq, + V2SequenceParallelPackedDataCollator, + SequenceParallelPackedDataCollator, BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, DataCollatorWithFlattening, RewardDataCollatorWithPadding, + SequenceParallelDataCollator, ] ] collator_args = [self.tokenizer] diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 2f61a6da4..a5215114a 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -412,8 +412,21 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): if self.args.curriculum_sampling: sampler = SequentialSampler(self.train_dataset) else: + generator = None + if self.args.sequence_parallel_size > 1: + generator = torch.Generator() + generator.manual_seed(self.args.getattr("seed", 0)) + sampler = RandomSampler(self.train_dataset) + # if dist.get_rank() == 0: + # import ipdb; ipdb.set_trace() + # dist.barrier() + + # if dist.get_rank() == 1: + # import ipdb; ipdb.set_trace() + # dist.barrier() + return MultipackBatchSampler( sampler, lengths=get_dataset_lengths(self.train_dataset), @@ -426,7 +439,14 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): ) if self.args.curriculum_sampling: return SequentialSampler(self.train_dataset) - return super()._get_train_sampler() + + sampler = super()._get_train_sampler() + if self.args.sequence_parallel_size > 1: + generator = torch.Generator() + generator.manual_seed(self.args.getattr("seed", 0)) + sampler.generator = generator + + return sampler def _get_eval_sampler( self, eval_dataset: Dataset @@ -478,6 +498,12 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["worker_init_fn"] = seed_worker + if dist.get_rank() == 0: + import ipdb + + ipdb.set_trace() + dist.barrier() + self.accelerator.even_batches = False return self.accelerator.prepare_data_loader( DataLoader(train_dataset, **dataloader_params) @@ -805,60 +831,36 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], - num_items_in_batch=None, + num_items_in_batch: int = None, ) -> torch.Tensor: """ Perform a training step on a batch of inputs. """ if self.args.sequence_parallel_size > 1: - # At this point, inputs should already be partitioned by the sequence parallel data collator - # We'll just log some information about the partitioned data + # At this point, inputs should already be partitioned by the sequence + # parallel data collator batch_size = inputs["input_ids"].shape[0] seq_len = inputs["input_ids"].shape[1] # Get rank and SP information sp_group = get_ring_attn_group() - rank = dist.get_rank() - sp_rank = dist.get_rank(group=sp_group) if sp_group else rank world_size = ( dist.get_world_size(group=sp_group) if sp_group else dist.get_world_size() ) - # Sample tokens from our slice to verify partitioning - sample_start = ( - inputs["input_ids"][0, :5].tolist() - if seq_len >= 5 - else inputs["input_ids"][0, :].tolist() - ) - sample_end = ( - inputs["input_ids"][0, -5:].tolist() - if seq_len >= 5 - else inputs["input_ids"][0, :].tolist() - ) - - LOG.info( - f"GPU {rank} (SP rank {sp_rank}) | Step {self.state.global_step} | " - f"Slice shape: batch_size={batch_size}, seq_len={seq_len} | " - f"Sample start: {sample_start}, end: {sample_end}" - ) - # Calculate the full sequence length across all GPUs in this SP group - full_seq_len = seq_len * world_size + total_seq_len = seq_len * world_size # Pass the partitioned sequence information to ring flash attention - self._update_ring_flash_attn_params([seq_len] * batch_size, full_seq_len) + self._update_ring_flash_attn_params( + packed_seq_lens=[seq_len] * batch_size, total_seq_len=total_seq_len + ) # Get the loss from the parent implementation loss = super().training_step(model, inputs, num_items_in_batch) - if self.args.sequence_parallel_size > 1: - rank = dist.get_rank() - LOG.info( - f"GPU {rank} | Step {self.state.global_step} | Loss: {loss.item()}" - ) - return loss def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len): diff --git a/src/axolotl/utils/collators/sequence_parallel.py b/src/axolotl/utils/collators/sequence_parallel.py index b41a07b60..1ed2957ca 100644 --- a/src/axolotl/utils/collators/sequence_parallel.py +++ b/src/axolotl/utils/collators/sequence_parallel.py @@ -1,12 +1,11 @@ """Module for sequence parallelism data collators.""" +import logging from dataclasses import dataclass import torch import torch.distributed as dist -from accelerate.logging import get_logger -from axolotl.logging_config import configure_logging from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group from axolotl.utils.collators.batching import ( BatchSamplerDataCollatorForSeq2Seq, @@ -14,32 +13,12 @@ from axolotl.utils.collators.batching import ( V2BatchSamplerDataCollatorForSeq2Seq, ) -configure_logging() -LOG = get_logger(__name__) +logger = logging.getLogger(__name__) -def find_sample_boundaries(position_ids): - """ - Find the boundaries between packed samples in a sequence by looking for - where position_ids decrease. - - Returns: - List of boundary indices for each sequence in the batch - """ - batch_boundaries = [] - - for i in range(position_ids.shape[0]): - seq = position_ids[i] - boundaries = [] - for j in range(1, len(seq)): - if seq[j] < seq[j - 1]: - boundaries.append(j) - batch_boundaries.append(boundaries) - - return batch_boundaries - - -def adjust_position_ids_for_slice(position_ids, start_idx): +def adjust_position_ids_for_slice( + position_ids: list | torch.Tensor, start_idx: int +) -> torch.Tensor: """ Adjust position IDs for a sliced sequence to maintain proper relative positions. This handles the case where position IDs might not be contiguous due to sample packing. @@ -64,370 +43,135 @@ def adjust_position_ids_for_slice(position_ids, start_idx): if seq[j] < seq[j - 1]: boundaries.append(j) - # Debug: log the found boundaries - LOG.debug(f"Sequence {i}: Found sample boundaries at positions {boundaries}") - # No need to adjust if there are no boundaries or this is a single sample if not boundaries: - old_values = seq[0:5].tolist() # Sample of original values adjusted_pos_ids[i] = seq - start_idx - new_values = adjusted_pos_ids[i, 0:5].tolist() # Sample of new values - LOG.debug( - f"Sequence {i}: No boundaries, subtracting {start_idx} uniformly. Example values before: {old_values}, after: {new_values}" - ) continue # Adjust each segment separately prev_boundary = 0 - for boundary_idx, boundary in enumerate(boundaries): - segment = seq[prev_boundary:boundary] - old_values = segment[ - 0 : min(5, len(segment)) - ].tolist() # Sample of original values + for boundary in boundaries: adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx - new_values = adjusted_pos_ids[ - i, prev_boundary : min(prev_boundary + 5, boundary) - ].tolist() # Sample of new values - LOG.debug( - f"Sequence {i}, Segment {boundary_idx}: Adjusting positions {prev_boundary}-{boundary-1}. Example values before: {old_values}, after: {new_values}" - ) prev_boundary = boundary # Last segment - segment = seq[prev_boundary:] - old_values = segment[ - 0 : min(5, len(segment)) - ].tolist() # Sample of original values adjusted_pos_ids[i, prev_boundary:] -= start_idx - new_values = adjusted_pos_ids[ - i, prev_boundary : min(prev_boundary + 5, len(seq)) - ].tolist() # Sample of new values - LOG.debug( - f"Sequence {i}, Last segment: Adjusting positions {prev_boundary}-end. Example values before: {old_values}, after: {new_values}" - ) return adjusted_pos_ids -def check_for_boundary_splits(boundaries, slice_start, slice_end): +class SequenceParallelMixin: """ - Check if any sample boundaries fall near the edge of a sequence slice. - These edge cases could cause issues with gradient computation. - - Args: - boundaries: List of indices where sample boundaries occur - slice_start: Start index of this GPU's slice - slice_end: End index of this GPU's slice - - Returns: - List of potentially problematic boundaries - """ - # Consider a boundary "near" an edge if it's within 5 tokens - buffer_size = 5 - problem_boundaries = [] - - for boundary in boundaries: - # Check if boundary is near the start of the slice - if slice_start <= boundary < slice_start + buffer_size: - problem_boundaries.append((boundary, "start", boundary - slice_start)) - # Check if boundary is near the end of the slice - elif slice_end - buffer_size <= boundary < slice_end: - problem_boundaries.append((boundary, "end", slice_end - boundary)) - - return problem_boundaries - - -@dataclass -class SequenceParallelPackedDataCollator(BatchSamplerDataCollatorForSeq2Seq): - """ - Data collator for sequence parallelism with sample packing. - Combines multiple samples into a packed sequence, then slices it for each GPU. + Mixin to add sequence parallelism slicing to data collators. """ - debug_level: str = "debug" # Can be "debug" for more verbose output - - def __call__(self, features, return_tensors=None): - # First, use the parent collator to handle sample packing and padding - batch = super().__call__(features, return_tensors=return_tensors) - - sp_group = get_ring_attn_group() - if sp_group is None: - return batch # Not using sequence parallelism - + def __post_init__(self): # Get information about our position in the SP group - rank = dist.get_rank(group=sp_group) - world_size = dist.get_world_size(group=sp_group) + sp_group = get_ring_attn_group() + self.rank = dist.get_rank(group=sp_group) + self.world_size = dist.get_world_size(group=sp_group) - # Enable debug level if requested - if self.debug_level == "debug": - original_shapes = { - k: v.shape if hasattr(v, "shape") else None for k, v in batch.items() - } - LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}") + def apply_sequence_parallelism( + self, batch: dict[str, torch.Tensor] + ) -> torch.Tensor: + """ + Apply sequence parallelism slicing to a batch. - if "position_ids" in batch: - # Find and log sample boundaries before slicing - boundaries = find_sample_boundaries(batch["position_ids"]) - for i, seq_boundaries in enumerate(boundaries): - LOG.info( - f"GPU {rank}: Sequence {i} has {len(seq_boundaries)} packed samples with boundaries at {seq_boundaries}" - ) + Args: + batch: Batch dictionary from parent collator. + Returns: + Sliced batch dictionary. + """ # Process keys that need to be sliced for key in ["input_ids", "attention_mask", "labels"]: if key in batch: seq_len = batch[key].shape[1] - slice_size = seq_len // world_size - start_idx = rank * slice_size - end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len - - LOG.info( - f"GPU {rank}: Slicing {key} from {start_idx} to {end_idx} (total len: {seq_len})" + slice_size = seq_len // self.world_size + start_idx = self.rank * slice_size + end_idx = ( + start_idx + slice_size + if self.rank < self.world_size - 1 + else seq_len ) - if self.debug_level == "debug" and key == "input_ids": - # Log portions of the input to verify correct slicing - for i in range( - min(2, batch[key].shape[0]) - ): # Look at up to 2 sequences - # Sample the beginning, middle and end of the sequence before slicing - start_sample = batch[key][i, 0:5].tolist() - mid_sample = batch[key][ - i, seq_len // 2 : seq_len // 2 + 5 - ].tolist() - end_sample = batch[key][i, -5:].tolist() - LOG.info( - f"GPU {rank}, Seq {i} before slicing: start={start_sample}, mid={mid_sample}, end={end_sample}" - ) - - batch[key] = batch[key][:, start_idx:end_idx] - - if self.debug_level == "debug" and key == "input_ids": - # Log after slicing to verify - for i in range(min(2, batch[key].shape[0])): - sliced_sample = batch[key][i, 0:5].tolist() - sliced_end = batch[key][i, -5:].tolist() - LOG.info( - f"GPU {rank}, Seq {i} after slicing: start={sliced_sample}, end={sliced_end}" - ) - - # Handle position_ids specially if present (important for packed sequences) - if "position_ids" in batch: - # For position_ids, we need to adjust them after slicing - # Each position_id should be relative to its slice - pos_ids = batch["position_ids"] - seq_len = pos_ids.shape[1] - slice_size = seq_len // world_size - start_idx = rank * slice_size - end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len - - # Find boundaries before slicing - if self.debug_level == "debug": - full_boundaries = find_sample_boundaries(pos_ids) - - # Check for boundaries that fall near slice edges - for i, boundaries in enumerate(full_boundaries): - problem_boundaries = check_for_boundary_splits( - boundaries, start_idx, end_idx + if key == "input_ids": + # Before slicing + non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item() + logger.info( + f"GPU {self.rank}: Total sequence length: {seq_len}, " + f"Non-padding tokens: {non_pad_tokens_total}" ) - if problem_boundaries: - LOG.warning( - f"GPU {rank}: Sequence {i} has sample boundaries near slice edges: {problem_boundaries}" - ) + logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}") - batch["position_ids"] = pos_ids[:, start_idx:end_idx] - - # Find boundaries after slicing to verify correct transfer - if self.debug_level == "debug": - sliced_boundaries = find_sample_boundaries(batch["position_ids"]) - for i, boundaries in enumerate(sliced_boundaries): - LOG.info( - f"GPU {rank}: After slicing, sequence {i} has boundaries at {boundaries}" + # After slicing + non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item() + logger.info( + f"GPU {self.rank}: Slice {start_idx}-{end_idx}, " + f"Non-padding tokens in slice: {non_pad_tokens_slice}" ) - # Adjust position_ids to be relative to the start of this slice - # Only subtract if not the first GPU in the group - if rank > 0: - # Find boundaries between samples in the position_ids - # This preserves the sample packing structure - old_pos_ids = batch["position_ids"].clone() - batch["position_ids"] = adjust_position_ids_for_slice( - batch["position_ids"], start_idx - ) + dist.barrier() - if self.debug_level == "debug": - # Compare before and after adjustment - for i in range(min(2, old_pos_ids.shape[0])): - before = old_pos_ids[i, 0:10].tolist() - after = batch["position_ids"][i, 0:10].tolist() - LOG.info( - f"GPU {rank}, Seq {i} position_ids adjustment: before={before}, after={after}" - ) - - # Add gradient norm tracking for debugging - if self.debug_level == "debug": - # Attach hook to track gradient norms during backward pass - def hook_fn(grad): - norm = grad.norm().item() - LOG.info(f"GPU {rank}: Gradient norm = {norm:.4f}") - # Record any abnormally high gradients - if norm > 10.0: - LOG.warning(f"GPU {rank}: High gradient norm detected: {norm:.4f}") - return grad - - # Apply hook to input_ids embeddings if it goes through backward pass - if "input_ids" in batch and batch["input_ids"].requires_grad: - batch["input_ids"].register_hook(hook_fn) - - return batch - - -@dataclass -class V2SequenceParallelPackedDataCollator(V2BatchSamplerDataCollatorForSeq2Seq): - """ - Data collator for sequence parallelism with V2 sample packing. - """ - - debug_level: str = "debug" # Can be "debug" for more verbose output - - def __call__(self, features, return_tensors=None): - # Implementation similar to SequenceParallelPackedDataCollator with V2 base - # First, use the parent collator to handle sample packing and padding - batch = super().__call__(features, return_tensors=return_tensors) - - sp_group = get_ring_attn_group() - if sp_group is None: - return batch # Not using sequence parallelism - - # Get information about our position in the SP group - rank = dist.get_rank(group=sp_group) - world_size = dist.get_world_size(group=sp_group) - - # Enable debug level if requested - if self.debug_level == "debug": - original_shapes = { - k: v.shape if hasattr(v, "shape") else None for k, v in batch.items() - } - LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}") - - if "position_ids" in batch: - # Find and log sample boundaries before slicing - boundaries = find_sample_boundaries(batch["position_ids"]) - for i, seq_boundaries in enumerate(boundaries): - LOG.info( - f"GPU {rank}: Sequence {i} has {len(seq_boundaries)} packed samples with boundaries at {seq_boundaries}" - ) - - # Process keys that need to be sliced - for key in ["input_ids", "attention_mask", "labels"]: - if key in batch: - seq_len = batch[key].shape[1] - slice_size = seq_len // world_size - start_idx = rank * slice_size - end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len - - if self.debug_level == "debug" and key == "input_ids": - # Log portions of the input to verify correct slicing - for i in range( - min(2, batch[key].shape[0]) - ): # Look at up to 2 sequences - # Sample the beginning, middle and end of the sequence before slicing - start_sample = batch[key][i, 0:5].tolist() - mid_sample = batch[key][ - i, seq_len // 2 : seq_len // 2 + 5 - ].tolist() - end_sample = batch[key][i, -5:].tolist() - LOG.info( - f"GPU {rank}, Seq {i} before slicing: start={start_sample}, mid={mid_sample}, end={end_sample}" - ) - - batch[key] = batch[key][:, start_idx:end_idx] - - # Handle position_ids specially (same as in SequenceParallelPackedDataCollator) - if "position_ids" in batch: - # Implementation identical to the one in SequenceParallelPackedDataCollator - pos_ids = batch["position_ids"] - seq_len = pos_ids.shape[1] - slice_size = seq_len // world_size - start_idx = rank * slice_size - end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len - - # Find boundaries before slicing - if self.debug_level == "debug": - full_boundaries = find_sample_boundaries(pos_ids) - - # Check for boundaries that fall near slice edges - for i, boundaries in enumerate(full_boundaries): - problem_boundaries = check_for_boundary_splits( - boundaries, start_idx, end_idx - ) - if problem_boundaries: - LOG.warning( - f"GPU {rank}: Sequence {i} has sample boundaries near slice edges: {problem_boundaries}" - ) - - batch["position_ids"] = pos_ids[:, start_idx:end_idx] - - # Adjust position_ids to be relative to the start of this slice - if rank > 0: - batch["position_ids"] = adjust_position_ids_for_slice( - batch["position_ids"], start_idx - ) - - return batch - - -@dataclass -class SequenceParallelDataCollator(DataCollatorForSeq2Seq): - """ - Data collator for sequence parallelism without sample packing. - """ - - debug_level: str = "debug" # Can be "debug" for more verbose output - - def __call__(self, features, return_tensors=None): - # First, use the parent collator to pad everything correctly - batch = super().__call__(features, return_tensors=return_tensors) - - sp_group = get_ring_attn_group() - if sp_group is None: - return batch # Not using sequence parallelism - - # Get information about our position in the SP group - rank = dist.get_rank(group=sp_group) - world_size = dist.get_world_size(group=sp_group) - - if self.debug_level == "debug": - original_shapes = { - k: v.shape if hasattr(v, "shape") else None for k, v in batch.items() - } - LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}") - - # Process keys that need to be sliced - for key in ["input_ids", "attention_mask", "labels"]: - if key in batch: - seq_len = batch[key].shape[1] - slice_size = seq_len // world_size - start_idx = rank * slice_size - end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len - - LOG.info( - f"GPU {rank}: Slicing {key} from {start_idx} to {end_idx} (total len: {seq_len})" - ) batch[key] = batch[key][:, start_idx:end_idx] # Handle position_ids if present if "position_ids" in batch: pos_ids = batch["position_ids"] seq_len = pos_ids.shape[1] - slice_size = seq_len // world_size - start_idx = rank * slice_size - end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len + slice_size = seq_len // self.world_size + start_idx = self.rank * slice_size + end_idx = ( + start_idx + slice_size if self.rank < self.world_size - 1 else seq_len + ) batch["position_ids"] = pos_ids[:, start_idx:end_idx] - # For non-packed sequences, we can simply subtract start_idx from all position_ids - if rank > 0: - batch["position_ids"] -= start_idx + # Adjust position_ids to be relative to the slice start + if self.rank > 0: + batch["position_ids"] = adjust_position_ids_for_slice( + batch["position_ids"], start_idx + ) return batch + + +@dataclass +class SequenceParallelPackedDataCollator( + SequenceParallelMixin, BatchSamplerDataCollatorForSeq2Seq +): + """ + Data collator for sequence parallelism with sample packing. Combines multiple + samples into a packed sequence, then slices it for each GPU. + """ + + def __call__(self, features, return_tensors=None): + # Use the parent collator to handle sample packing and padding + batch = super().__call__(features, return_tensors=return_tensors) + return self.apply_sequence_parallelism(batch) + + +@dataclass +class V2SequenceParallelPackedDataCollator( + SequenceParallelMixin, V2BatchSamplerDataCollatorForSeq2Seq +): + """ + Data collator for sequence parallelism with V2 sample packing. + """ + + def __call__(self, features, return_tensors=None): + # Use the parent collator to handle sample packing and padding + batch = super().__call__(features, return_tensors=return_tensors) + return self.apply_sequence_parallelism(batch) + + +@dataclass +class SequenceParallelDataCollator(SequenceParallelMixin, DataCollatorForSeq2Seq): + """ + Data collator for sequence parallelism without sample packing. + """ + + def __call__(self, features, return_tensors=None): + # Use the parent collator to pad everything correctly + batch = super().__call__(features, return_tensors=return_tensors) + return self.apply_sequence_parallelism(batch) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 3fc971ca6..e88d063e6 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -67,7 +67,12 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant -LOG = logging.getLogger("axolotl") +LOG = logging.getLogger(__name__) + +MULTIMODEL_AUTO_MODEL_MAPPING = { + "llava": LlavaForConditionalGeneration, + "mllama": MllamaForConditionalGeneration, +} # copied from accelerator.FullyShardedDataParallelPlugin @@ -476,7 +481,7 @@ class ModelLoader: else: self.text_model_config = self.model_config - self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name + self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name def apply_patches(self) -> None: # load any patches from plugins @@ -612,7 +617,7 @@ class ModelLoader: patch_self_attn_lora() - def patch_llama_derived_model(self) -> None: + def patch_llama_derived_model(self): """Modify all llama derived models in one block""" self.patch_loss_llama() @@ -662,25 +667,16 @@ class ModelLoader: "Shifted-sparse attention not currently implemented without flash attention." ) - def set_auto_model_loader(self) -> None: - """set self.AutoModelLoader - - default value: AutoModelForCausalLM (set at __init__) - - when using a multi modality model, self.AutoModelLoader should - be set according to model type of the model + def set_auto_model_loader(self): + """ + Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM` + (set at `__init__`). When using a multimodal model, `self.auto_model_loader` + should be set according to the type of the model. """ if self.cfg.is_multimodal: - if self.model_config.model_type == "llava": - self.AutoModelLoader = ( # pylint: disable=invalid-name - LlavaForConditionalGeneration - ) - elif self.model_config.model_type == "mllama": - self.AutoModelLoader = ( # pylint: disable=invalid-name - MllamaForConditionalGeneration - ) - else: - self.AutoModelLoader = ( - AutoModelForVision2Seq # pylint: disable=invalid-name - ) + self.auto_model_loader = MULTIMODEL_AUTO_MODEL_MAPPING.get( + self.model_config.model_type, AutoModelForVision2Seq + ) def set_device_map_config(self) -> None: device_map = self.cfg.device_map @@ -704,7 +700,7 @@ class ModelLoader: from accelerate import infer_auto_device_map with init_empty_weights(): - model_canvas = self.AutoModelLoader.from_config( + model_canvas = self.auto_model_loader.from_config( self.model_config, trust_remote_code=self.cfg.trust_remote_code or False, ) @@ -925,11 +921,27 @@ class ModelLoader: if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config - self.model = self.AutoModelLoader.from_pretrained( - self.base_model, - config=self.model_config, - **self.model_kwargs, - ) + + # Load model with random initialization if specified + if self.cfg.random_init: + # AutoModel classes support the from_config method + if self.auto_model_loader in [ + AutoModelForCausalLM, + AutoModelForVision2Seq, + ]: + self.model = self.auto_model_loader.from_config( + config=self.model_config, + ) + else: + self.model = self.auto_model_loader( + config=self.model_config, + ) + else: + self.model = self.auto_model_loader.from_pretrained( + self.base_model, + config=self.model_config, + **self.model_kwargs, + ) # TODO (MengqingCao) split these patches seperately if self.cfg.flash_attention and not self.inference: @@ -967,7 +979,7 @@ class ModelLoader: if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config if self.cfg.gptq: - self.model = self.AutoModelLoader.from_pretrained( + self.model = self.auto_model_loader.from_pretrained( self.base_model, config=self.model_config, trust_remote_code=self.cfg.trust_remote_code or False, @@ -1000,7 +1012,7 @@ class ModelLoader: if self.cfg.gptq: if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config - self.model = self.AutoModelLoader.from_pretrained( + self.model = self.auto_model_loader.from_pretrained( self.base_model, config=self.model_config, trust_remote_code=self.cfg.trust_remote_code or False, @@ -1020,7 +1032,7 @@ class ModelLoader: if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config - self.model = self.AutoModelLoader.from_pretrained( + self.model = self.auto_model_loader.from_pretrained( self.base_model, config=self.model_config, trust_remote_code=self.cfg.trust_remote_code or False, @@ -1316,7 +1328,7 @@ def load_model( """ Load a model for a given configuration and tokenizer. """ - loader = ModelLoader( + model_loader = ModelLoader( cfg, tokenizer, processor=processor, @@ -1324,7 +1336,7 @@ def load_model( reference_model=reference_model, **kwargs, ) - return loader.load_model() + return model_loader.load_model() def load_adapter(model, cfg, adapter, inference=False): diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index ca0d79a27..0e04039a0 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): - 1 ) * cfg.num_epochs + * cfg.sequence_parallel_size ) LOG.debug( f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}", @@ -473,7 +474,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True) # FIXME: is there a bug here somewhere? the total num steps depends # on the agreed on value for sample_packing_eff_est - total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs)) + total_num_steps = int( + math.floor( + data_loader_len * cfg.num_epochs * cfg.sequence_parallel_size + ) + ) def calc_sample_packing_eff_est(estimates: List[float]): LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}") @@ -494,7 +499,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): ) else: total_num_steps = int( - math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + math.ceil( + len(train_dataset) + * cfg.num_epochs + * cfg.sequence_parallel_size + / cfg.batch_size + ) ) LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True) return total_num_steps diff --git a/tests/e2e/patched/test_sequence_parallelism.py b/tests/e2e/patched/test_sequence_parallelism.py index 6d7c64305..a3483122b 100644 --- a/tests/e2e/patched/test_sequence_parallelism.py +++ b/tests/e2e/patched/test_sequence_parallelism.py @@ -14,7 +14,6 @@ with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}): from axolotl.utils.collators.sequence_parallel import ( adjust_position_ids_for_slice, check_for_boundary_splits, - find_sample_boundaries, ) @@ -30,24 +29,6 @@ def partial_state(): class TestSequenceParallelHelpers: """Test helper functions used in sequence parallelism.""" - def test_find_sample_boundaries(self): - """Test detection of boundaries in position_ids.""" - # Create sample position_ids with multiple sequences - position_ids = torch.tensor( - [ - # First sequence with 2 samples (boundary at index 5) - [0, 1, 2, 3, 4, 0, 1, 2, 3], - # Second sequence with 3 samples (boundaries at 3 and 7) - [0, 1, 2, 0, 1, 2, 3, 0, 1], - ] - ) - - boundaries = find_sample_boundaries(position_ids) - - assert len(boundaries) == 2 - assert boundaries[0] == [5] # First sequence has boundary at index 5 - assert boundaries[1] == [3, 7] # Second sequence has boundaries at 3 and 7 - def test_adjust_position_ids_for_slice(self, partial_state): """Test position_ids adjustment for sequence slices.""" # Create sample position_ids with multiple sequences