helper utils
This commit is contained in:
49
src/axolotl/utils/tensors.py
Normal file
49
src/axolotl/utils/tensors.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
def split_and_pad_packed(tensor, cu_seqlens, max_seqlen, keep_fn=None):
|
||||||
|
split_tensors = []
|
||||||
|
|
||||||
|
counts = count_nonzero_sequences(cu_seqlens)
|
||||||
|
# Iterate over each batch
|
||||||
|
for i in range(tensor.size(0)):
|
||||||
|
seq_lens = cu_seqlens[i]
|
||||||
|
start_idx = 0
|
||||||
|
|
||||||
|
# Iterate over the cumulative sequence lengths
|
||||||
|
for j, end_idx in enumerate(seq_lens[1:]):
|
||||||
|
if end_idx == start_idx:
|
||||||
|
break
|
||||||
|
# Extract and pad the current sequence
|
||||||
|
current_seq = tensor[i, start_idx:end_idx]
|
||||||
|
keep = True
|
||||||
|
if keep_fn:
|
||||||
|
keep = keep_fn(current_seq, index=j, nonzero_total=counts[i])
|
||||||
|
if not keep:
|
||||||
|
continue
|
||||||
|
padding_size = max_seqlen - current_seq.size(0)
|
||||||
|
padded_seq = F.pad(current_seq, (0, 0) * (current_seq.dim() - 2) + (0, padding_size))
|
||||||
|
|
||||||
|
# Append the padded sequence to the list
|
||||||
|
split_tensors.append(padded_seq)
|
||||||
|
|
||||||
|
# Update start index for the next sequence
|
||||||
|
start_idx = end_idx
|
||||||
|
|
||||||
|
# Stack the padded tensors
|
||||||
|
return torch.stack(split_tensors, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def count_nonzero_sequences(cu_seqlens: torch.Tensor) -> torch.LongTensor:
|
||||||
|
diffs = torch.diff(cu_seqlens, dim=1, prepend=torch.zeros(cu_seqlens.shape[0], 1, dtype=cu_seqlens.dtype))
|
||||||
|
valid_lengths = diffs != 0
|
||||||
|
counts = valid_lengths.sum(dim=1).long()
|
||||||
|
|
||||||
|
return counts
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
# Example tensor with dimensions [batch_size, seq_len, other_dimensions...]
|
||||||
|
# example_tensor = torch.randn(batch_size, seq_len, other_dimensions...)
|
||||||
|
# cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
|
||||||
|
# split_padded_tensor = split_and_pad_packed(example_tensor, cu_seqlens, max_seqlen)
|
||||||
Reference in New Issue
Block a user