diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 53c039479..0a7c79434 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -200,8 +200,9 @@ class DataCollatorForSeq2Seq:
         Returns:
             Sliced batch dictionary.
         """
-        # Process keys that need to be sliced
-        for key in ["input_ids", "attention_mask", "labels"]:
+        keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
+
+        for key in keys_to_slice:
             if key in batch:
                 seq_len = batch[key].shape[1]
                 slice_size = seq_len // self.local_world_size
@@ -211,90 +212,11 @@ class DataCollatorForSeq2Seq:
                     if self.local_rank < self.local_world_size - 1
                     else seq_len
                 )
-
-                if key == "input_ids":
-                    # Before slicing
-                    non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
-                    logger.info(
-                        f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                        f"Total sequence length: {seq_len}, "
-                        f"Non-padding tokens: {non_pad_tokens_total}"
-                    )
-                    logger.info(
-                        f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                        f"GPU {self.rank} token IDs: {batch['input_ids']}"
-                    )
-                    logger.info(
-                        f"GPU {self.rank}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                        f"Slicing {key} from {seq_len} tokens to "
-                        f"indices {start_idx}:{end_idx}"
-                    )
 
                 batch[key] = batch[key][:, start_idx:end_idx]
-
-                # if key == "input_ids":
-                #     # After slicing
-                #     non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
-                #     logger.info(
-                #         f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                #         f"Slice {start_idx}:{end_idx}, "
-                #         f"Non-padding tokens in slice: {non_pad_tokens_slice}"
-                #     )
-                #     logger.info(
-                #         f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                #         f"GPU {self.rank} token IDs: {batch['input_ids']}"
-                #     )
-
-                # if key == "labels":
-                #     min_label = batch["labels"][batch["labels"] != -100].min().item() if (batch["labels"] != -100).any() else -100
-                #     max_label = batch["labels"][batch["labels"] != -100].max().item() if (batch["labels"] != -100).any() else -100
-                #     logger.info(f"GPU {self.rank}: Label range: {min_label} to {max_label}, Vocab size: {self.tokenizer.vocab_size}, labels: {batch['labels']}")
-
-                #     # Find any labels that are outside the valid vocabulary range (but not -100 which is the ignore index)
-                #     invalid_mask = (batch["labels"] >= self.tokenizer.vocab_size) & (batch["labels"] != -100)
-
-                #     if invalid_mask.any():
-                #         # Log this for debugging
-                #         num_invalid = invalid_mask.sum().item()
-                #         logger.warning(f"GPU {self.rank}: Found {num_invalid} invalid labels (>= vocab_size), setting to -100")
-
-                #         # Replace invalid labels with -100 (ignore index)
-                #         batch["labels"][invalid_mask] = -100
-
-                if key == "attention_mask":
-                    logger.info(
-                        f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                        f"Attention mask: {batch['attention_mask']}, "
-                    )
-
-        # Handle position_ids if present
-        if "position_ids" in batch:
-            pos_ids = batch["position_ids"]
-            seq_len = pos_ids.shape[1]
-            slice_size = seq_len // self.local_world_size
-            start_idx = self.local_rank * slice_size
-            end_idx = (
-                start_idx + slice_size
-                if self.local_rank < self.local_world_size - 1
-                else seq_len
-            )
-
-            batch["position_ids"] = pos_ids[:, start_idx:end_idx]
-
-            # Adjust position_ids to be relative to the slice start
-            if self.local_rank > 0:
-                batch["position_ids"] = adjust_position_ids_for_slice(
-                    batch["position_ids"], start_idx
-                )
-
-            logger.info(
-                f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                f"Position IDs: {batch['position_ids']}, "
-            )
-
-        # if dist.get_rank() == 0:
-        #     import ipdb; ipdb.set_trace()
-        #     dist.barrier()
+
+                # Special handling for position_ids
+                if key == "position_ids" and self.local_rank > 0:
+                    batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
 
         return batch