fix to remove attention_mask
This commit is contained in:
@@ -63,6 +63,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
if (
|
||||
self.args.sample_packing
|
||||
and hasattr(inputs, "attention_mask")
|
||||
and hasattr(inputs, "position_ids")
|
||||
):
|
||||
del inputs["attention_mask"]
|
||||
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
"""Helper KD utils"""
|
||||
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import FloatTensor
|
||||
from torch import FloatTensor, Tensor
|
||||
|
||||
|
||||
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
|
||||
@@ -37,3 +41,60 @@ def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
|
||||
final_logprobs_tensor = torch.log(teacher_probs_t_online)
|
||||
|
||||
return final_logprobs_tensor
|
||||
|
||||
|
||||
def strided_chunk_views(
|
||||
tensor: Union[np.ndarray, torch.Tensor],
|
||||
chunks: int,
|
||||
dim: int = 0,
|
||||
stride: int = 1,
|
||||
chunk_size: int | None = None,
|
||||
) -> List[Union[np.ndarray, torch.Tensor]]:
|
||||
"""
|
||||
Split a tensor into chunks along a dimension with striding, prioritizing views over copies.
|
||||
|
||||
Args:
|
||||
tensor: Input tensor (numpy array or torch tensor)
|
||||
chunks: Number of chunks to create
|
||||
dim: Dimension along which to chunk (default: 0)
|
||||
stride: Stride between chunk starting positions (default: 1)
|
||||
chunk_size: Size of each chunk. If None, calculated automatically (default: None)
|
||||
|
||||
Returns:
|
||||
List of tensor chunks (views when possible, copies when necessary)
|
||||
"""
|
||||
|
||||
# Get the size of the specified dimension
|
||||
dim_size = tensor.shape[dim]
|
||||
|
||||
# Calculate chunk size if not provided
|
||||
if chunk_size is None:
|
||||
chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division
|
||||
|
||||
chunks_list = []
|
||||
|
||||
for i in range(chunks):
|
||||
start_idx = i * stride
|
||||
end_idx = min(start_idx + chunk_size, dim_size)
|
||||
|
||||
# Break if we've gone beyond the tensor
|
||||
if start_idx >= dim_size:
|
||||
break
|
||||
|
||||
# Create slice objects for all dimensions
|
||||
slices = [slice(None)] * tensor.ndim
|
||||
slices[dim] = slice(start_idx, end_idx)
|
||||
|
||||
chunk = tensor[tuple(slices)]
|
||||
chunks_list.append(chunk)
|
||||
|
||||
return chunks_list
|
||||
|
||||
|
||||
def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):
|
||||
dim_size = input_tensor.shape[dim]
|
||||
stride = math.ceil(dim_size / chunks)
|
||||
|
||||
return strided_chunk_views(
|
||||
input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user