working on masking fix
This commit is contained in:
@@ -22,6 +22,60 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_cu_seqlens(position_ids: torch.Tensor, total_seq_len: int) -> torch.Tensor:
    """Build cumulative sequence boundaries for a packed row of position ids.

    The input is flattened, so it is only meaningful for a single packed row
    (batch size 1); rows of a larger batch would simply be concatenated.

    A token opens a new sequence when it is the first token of the row, or its
    position id is 0, AND the token right after it has a non-zero position id.
    Consequences of that rule:

    * runs of zeros (padding) never open a sequence, so trailing padding is
      folded into the sequence that precedes it;
    * a one-token sequence cannot be told apart from padding and is folded
      into its neighbour as well;
    * the first token always counts as a candidate start, even when its
      position id is not 0 (e.g. a row that begins mid-sequence).

    Args:
        position_ids: Integer tensor of per-token position ids.
        total_seq_len: Value appended as the final boundary of the result.

    Returns:
        1-D ``torch.int32`` tensor ``[0, end_1, ..., end_k, total_seq_len]``
        where ``end_k`` equals the number of tokens in ``position_ids``. The
        tensor lives on the current CUDA device when CUDA is available and on
        ``position_ids.device`` otherwise.
    """
    flat_ids = position_ids.flatten()
    num_tokens = flat_ids.numel()

    if num_tokens > 1:
        is_zero = flat_ids == 0
        # Candidate starts among tokens that still have a successor.
        opens_sequence = is_zero[:-1].clone()
        opens_sequence[0] = True  # the first token always starts a sequence
        # Keep a candidate only when its successor is non-zero; this discards
        # padding runs without a Python-level loop over tensor elements.
        start_indices = torch.nonzero(opens_sequence & ~is_zero[1:]).flatten()
    else:
        # Zero or one token: there is nothing to split.
        start_indices = torch.empty(0, dtype=torch.long, device=flat_ids.device)

    # Every start after the first one is the end offset of the sequence before
    # it. Anything ahead of the first detected start is folded into the first
    # sequence, so the offsets always begin at 0 and reach ``num_tokens``.
    cu_seqlens = torch.cat(
        [
            start_indices.new_zeros(1),
            start_indices[1:],
            start_indices.new_tensor([num_tokens, total_seq_len]),
        ]
    ).to(dtype=torch.int32)

    # Keep the original placement on CUDA hosts, but stay usable on CPU.
    if torch.cuda.is_available():
        cu_seqlens = cu_seqlens.to(torch.device("cuda", torch.cuda.current_device()))

    # Lazy %-formatting: tensors are only rendered when debug logging is on.
    LOG.debug("start_indices: %s", start_indices)
    LOG.debug("cu_seqlens with padding: %s", cu_seqlens)

    return cu_seqlens
|
||||||
|
|
||||||
|
|
||||||
class SequenceParallelMixin:
|
class SequenceParallelMixin:
|
||||||
"""
|
"""
|
||||||
Mixin class for sequence parallelism support in trainers.
|
Mixin class for sequence parallelism support in trainers.
|
||||||
@@ -118,17 +172,42 @@ class SequenceParallelMixin:
|
|||||||
# Calculate the full sequence length across all GPUs in this SP group
|
# Calculate the full sequence length across all GPUs in this SP group
|
||||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||||
|
|
||||||
cu_seqlens = torch.cumsum(
|
# cu_seqlens = torch.cumsum(
|
||||||
torch.tensor(
|
# torch.tensor(
|
||||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
# packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||||
),
|
# ),
|
||||||
dim=-1,
|
# dim=-1,
|
||||||
dtype=torch.int32,
|
# dtype=torch.int32,
|
||||||
)
|
# )
|
||||||
cu_seqlens = F.pad(
|
# cu_seqlens = F.pad(
|
||||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
# F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||||
)
|
# )
|
||||||
|
|
||||||
|
# packed_seq_lens = []
|
||||||
|
# current_len = 1 # Start counting the first token
|
||||||
|
|
||||||
|
# # Iterate through position IDs starting from the second element
|
||||||
|
# for i in range(1, len(inputs["position_ids"])):
|
||||||
|
# # If current position is less than previous, it's a new sequence
|
||||||
|
# if inputs["position_ids"][i] < inputs["position_ids"][i - 1]:
|
||||||
|
# packed_seq_lens.append(current_len)
|
||||||
|
# current_len = 1
|
||||||
|
# else:
|
||||||
|
# current_len += 1
|
||||||
|
|
||||||
|
# # Add the last sequence length
|
||||||
|
# packed_seq_lens.append(current_len)
|
||||||
|
# LOG.info(f"{packed_seq_lens}: packed_seq_lens")
|
||||||
|
|
||||||
|
# cu_seqlens = torch.cumsum(
|
||||||
|
# torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
|
||||||
|
# dim=-1,
|
||||||
|
# dtype=torch.int32,
|
||||||
|
# )
|
||||||
|
# cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
|
||||||
|
# LOG.info(f"{cu_seqlens}: cu_seqlens")
|
||||||
|
|
||||||
|
cu_seqlens = calculate_cu_seqlens(inputs["position_ids"], total_seq_len)
|
||||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||||
|
|
||||||
def training_step(
|
def training_step(
|
||||||
|
|||||||
@@ -211,8 +211,8 @@ class DataCollatorForSeq2Seq:
|
|||||||
batch[key] = batch[key][:, start_idx:end_idx]
|
batch[key] = batch[key][:, start_idx:end_idx]
|
||||||
|
|
||||||
# Special handling for position_ids
|
# Special handling for position_ids
|
||||||
if key == "position_ids" and self.local_rank > 0:
|
# if key == "position_ids" and self.local_rank > 0:
|
||||||
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
|
# batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|||||||
@@ -1155,6 +1155,12 @@ class AxolotlInputConfig(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not info.data["micro_batch_size"] == 1:
|
||||||
|
raise ValueError(
|
||||||
|
"micro_batch_size must be set to 1 "
|
||||||
|
"due to a `ring-flash-attn` requirement"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||||
|
|||||||
Reference in New Issue
Block a user