From 4188700b7b8700dc679c782c7f8b0c3b6456e614 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 4 Apr 2025 20:24:18 +0000 Subject: [PATCH] working on masking fix --- .../core/trainers/mixins/sequence_parallel.py | 99 +++++++++++++++++-- src/axolotl/utils/collators/batching.py | 4 +- src/axolotl/utils/schemas/config.py | 6 ++ 3 files changed, 97 insertions(+), 12 deletions(-) diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py index 9bcd5db57..0709e2620 100644 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -22,6 +22,60 @@ except ImportError: pass +def calculate_cu_seqlens(position_ids: torch.Tensor, total_seq_len: int) -> torch.Tensor: + # Must be batch size 1 + position_ids = position_ids.flatten() + LOG.info(f"position_ids: {position_ids}") + + # Find where the position resets to 0 (indicating a new sequence) + # We add position_ids.new_ones(1) to mark the start of the first sequence + sequence_starts = torch.cat([position_ids.new_ones(1), (position_ids[1:] == 0).to(torch.int)]) + + # Get all indices where sequence_starts + potential_indices = torch.nonzero(sequence_starts).flatten() + + # Filter out indices where the next index also has a zero + valid_indices = [] + for i in range(len(potential_indices)): + # Get current index position in the original tensor + current_pos = potential_indices[i] + + # Check if this is the last index or if the next element is not a zero + if i == len(potential_indices) - 1: + continue + elif potential_indices[i + 1] != current_pos + 1: + valid_indices.append(current_pos) + + start_indices = torch.tensor(valid_indices, device=potential_indices.device) + LOG.info(f"start_indices: {start_indices}") + + # Calculate individual sequence lengths + if len(start_indices) > 1: + sequence_lengths = torch.diff(start_indices, append=torch.tensor([len(position_ids)])) + else: + sequence_lengths = torch.tensor([len(position_ids)]) + + LOG.info(f"sequence_lengths: {sequence_lengths}") + + # Calculate cumulative sequence lengths + cu_seqlens = torch.cumsum( + sequence_lengths.to(torch.cuda.current_device()), + dim=0, + dtype=torch.int32, + ) + LOG.info(f"cu_seqlens: {cu_seqlens}") + + cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len) + LOG.info(f"cu_seqlens with padding: {cu_seqlens}") + + import torch.distributed as dist + if dist.get_rank() == 1: + import ipdb; ipdb.set_trace() + dist.barrier() + + return cu_seqlens + + class SequenceParallelMixin: """ Mixin class for sequence parallelism support in trainers. @@ -118,17 +172,42 @@ class SequenceParallelMixin: # Calculate the full sequence length across all GPUs in this SP group total_seq_len = seq_len * self.args.sequence_parallel_degree - cu_seqlens = torch.cumsum( - torch.tensor( - packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32 - ), - dim=-1, - dtype=torch.int32, - ) - cu_seqlens = F.pad( - F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len - ) + # cu_seqlens = torch.cumsum( + # torch.tensor( + # packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32 + # ), + # dim=-1, + # dtype=torch.int32, + # ) + # cu_seqlens = F.pad( + # F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len + # ) + # packed_seq_lens = [] + # current_len = 1 # Start counting the first token + + # # Iterate through position IDs starting from the second element + # for i in range(1, len(inputs["position_ids"])): + # # If current position is less than previous, it's a new sequence + # if inputs["position_ids"][i] < inputs["position_ids"][i - 1]: + # packed_seq_lens.append(current_len) + # current_len = 1 + # else: + # current_len += 1 + + # # Add the last sequence length + # packed_seq_lens.append(current_len) + # LOG.info(f"{packed_seq_lens}: packed_seq_lens") + + # cu_seqlens = torch.cumsum( + # torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32), + # dim=-1, + # dtype=torch.int32, + # ) + # cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len) + # LOG.info(f"{cu_seqlens}: cu_seqlens") + + cu_seqlens = calculate_cu_seqlens(inputs["position_ids"], total_seq_len) update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group) def training_step( diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 33bb4b4cc..df0876021 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -211,8 +211,8 @@ class DataCollatorForSeq2Seq: batch[key] = batch[key][:, start_idx:end_idx] # Special handling for position_ids - if key == "position_ids" and self.local_rank > 0: - batch[key] = adjust_position_ids_for_slice(batch[key], start_idx) + # if key == "position_ids" and self.local_rank > 0: + # batch[key] = adjust_position_ids_for_slice(batch[key], start_idx) return batch diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 0f9a3a1f9..b044b280e 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1155,6 +1155,12 @@ class AxolotlInputConfig( raise ValueError( "flash_attention: true must be set with sequence_parallel_degree > 1" ) + + if not info.data["micro_batch_size"] == 1: + raise ValueError( + "micro_batch_size must be set to 1 " + "due to a `ring-flash-attn` requirement" + ) try: import ring_flash_attn # noqa: F401 # pylint:disable=unused-import