fix to remove attention_mask

This commit is contained in:
Wing Lian
2025-06-01 08:15:56 -04:00
parent a8e2bddd19
commit 2302b14a84
2 changed files with 68 additions and 1 deletions

View File

@@ -63,6 +63,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
Subclass and override for custom behavior. Subclass and override for custom behavior.
""" """
if (
self.args.sample_packing
and hasattr(inputs, "attention_mask")
and hasattr(inputs, "position_ids")
):
del inputs["attention_mask"]
if self.model_accepts_loss_kwargs: if self.model_accepts_loss_kwargs:
loss_kwargs = {} loss_kwargs = {}

View File

@@ -1,7 +1,11 @@
"""Helper KD utils""" """Helper KD utils"""
import math
from typing import List, Union
import numpy as np
import torch import torch
from torch import FloatTensor from torch import FloatTensor, Tensor
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor: def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
@@ -37,3 +41,60 @@ def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
final_logprobs_tensor = torch.log(teacher_probs_t_online) final_logprobs_tensor = torch.log(teacher_probs_t_online)
return final_logprobs_tensor return final_logprobs_tensor
def strided_chunk_views(
    tensor: Union[np.ndarray, torch.Tensor],
    chunks: int,
    dim: int = 0,
    stride: int = 1,
    chunk_size: Union[int, None] = None,
) -> List[Union[np.ndarray, torch.Tensor]]:
    """
    Split a tensor into chunks along a dimension with striding, prioritizing views over copies.

    Chunk ``i`` covers indices ``[i * stride, i * stride + chunk_size)`` along ``dim``,
    clipped to the end of that dimension. Chunks whose start index falls at or beyond
    the end of the dimension are not produced, so fewer than ``chunks`` items may be
    returned. When ``stride < chunk_size`` consecutive chunks overlap.

    Args:
        tensor: Input tensor (numpy array or torch tensor)
        chunks: Maximum number of chunks to create. A value <= 0 yields an empty list.
        dim: Dimension along which to chunk (default: 0)
        stride: Non-negative stride between chunk starting positions (default: 1)
        chunk_size: Size of each chunk. If None, ``ceil(dim_size / chunks)`` is used (default: None)

    Returns:
        List of tensor chunks, each produced by basic slicing of ``tensor``.

    Raises:
        ValueError: If ``stride`` is negative.
    """
    # Nothing to produce; this also avoids dividing by zero in the automatic
    # chunk-size computation below.
    if chunks <= 0:
        return []
    # A negative stride would make start indices negative, which slicing would
    # silently interpret as offsets from the end of the dimension.
    if stride < 0:
        raise ValueError(f"stride must be non-negative, got {stride}")

    # Get the size of the specified dimension
    dim_size = tensor.shape[dim]

    # Calculate chunk size if not provided
    if chunk_size is None:
        chunk_size = (dim_size + chunks - 1) // chunks  # Ceiling division

    # Full-range slices for every dimension; only the entry for `dim` is
    # rewritten per chunk. A fresh tuple is built for each indexing operation.
    index = [slice(None)] * tensor.ndim

    chunks_list = []
    for i in range(chunks):
        start_idx = i * stride
        # Start positions never decrease, so once one is past the end all
        # later ones are too.
        if start_idx >= dim_size:
            break
        end_idx = min(start_idx + chunk_size, dim_size)
        index[dim] = slice(start_idx, end_idx)
        chunks_list.append(tensor[tuple(index)])

    return chunks_list
def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):
    """
    Split a tensor into strided chunks that extend past their own stride along ``dim``.

    The stride between chunk starts is ``ceil(input_tensor.shape[dim] / chunks)`` and
    each chunk is requested with length ``stride + overlap``; the actual slicing is
    delegated to ``strided_chunk_views``.

    Args:
        input_tensor: Tensor to split
        chunks: Number of chunks requested
        dim: Dimension along which to chunk (default: 0)
        overlap: Extra elements added to each chunk's length beyond the stride (default: 1)

    Returns:
        Whatever ``strided_chunk_views`` returns for the computed stride and chunk size.
    """
    step = math.ceil(input_tensor.shape[dim] / chunks)
    window = step + overlap
    return strided_chunk_views(
        input_tensor,
        chunks,
        dim,
        stride=step,
        chunk_size=window,
    )