fix to remove attention_mask
This commit is contained in:
@@ -63,6 +63,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
|
|
||||||
Subclass and override for custom behavior.
|
Subclass and override for custom behavior.
|
||||||
"""
|
"""
|
||||||
|
if (
|
||||||
|
self.args.sample_packing
|
||||||
|
and hasattr(inputs, "attention_mask")
|
||||||
|
and hasattr(inputs, "position_ids")
|
||||||
|
):
|
||||||
|
del inputs["attention_mask"]
|
||||||
|
|
||||||
if self.model_accepts_loss_kwargs:
|
if self.model_accepts_loss_kwargs:
|
||||||
loss_kwargs = {}
|
loss_kwargs = {}
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
"""Helper KD utils"""
|
"""Helper KD utils"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import FloatTensor
|
from torch import FloatTensor, Tensor
|
||||||
|
|
||||||
|
|
||||||
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
|
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
|
||||||
@@ -37,3 +41,60 @@ def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
|
|||||||
final_logprobs_tensor = torch.log(teacher_probs_t_online)
|
final_logprobs_tensor = torch.log(teacher_probs_t_online)
|
||||||
|
|
||||||
return final_logprobs_tensor
|
return final_logprobs_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def strided_chunk_views(
|
||||||
|
tensor: Union[np.ndarray, torch.Tensor],
|
||||||
|
chunks: int,
|
||||||
|
dim: int = 0,
|
||||||
|
stride: int = 1,
|
||||||
|
chunk_size: int | None = None,
|
||||||
|
) -> List[Union[np.ndarray, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Split a tensor into chunks along a dimension with striding, prioritizing views over copies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: Input tensor (numpy array or torch tensor)
|
||||||
|
chunks: Number of chunks to create
|
||||||
|
dim: Dimension along which to chunk (default: 0)
|
||||||
|
stride: Stride between chunk starting positions (default: 1)
|
||||||
|
chunk_size: Size of each chunk. If None, calculated automatically (default: None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tensor chunks (views when possible, copies when necessary)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get the size of the specified dimension
|
||||||
|
dim_size = tensor.shape[dim]
|
||||||
|
|
||||||
|
# Calculate chunk size if not provided
|
||||||
|
if chunk_size is None:
|
||||||
|
chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division
|
||||||
|
|
||||||
|
chunks_list = []
|
||||||
|
|
||||||
|
for i in range(chunks):
|
||||||
|
start_idx = i * stride
|
||||||
|
end_idx = min(start_idx + chunk_size, dim_size)
|
||||||
|
|
||||||
|
# Break if we've gone beyond the tensor
|
||||||
|
if start_idx >= dim_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Create slice objects for all dimensions
|
||||||
|
slices = [slice(None)] * tensor.ndim
|
||||||
|
slices[dim] = slice(start_idx, end_idx)
|
||||||
|
|
||||||
|
chunk = tensor[tuple(slices)]
|
||||||
|
chunks_list.append(chunk)
|
||||||
|
|
||||||
|
return chunks_list
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):
|
||||||
|
dim_size = input_tensor.shape[dim]
|
||||||
|
stride = math.ceil(dim_size / chunks)
|
||||||
|
|
||||||
|
return strided_chunk_views(
|
||||||
|
input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user