diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 91dae14c3..eb84eff22 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -63,6 +63,12 @@ class AxolotlKDTrainer(AxolotlTrainer): Subclass and override for custom behavior. """ + if ( + self.args.sample_packing + and hasattr(inputs, "attention_mask") + and hasattr(inputs, "position_ids") + ): + del inputs["attention_mask"] if self.model_accepts_loss_kwargs: loss_kwargs = {} diff --git a/src/axolotl/integrations/kd/utils.py b/src/axolotl/integrations/kd/utils.py index 7a3633596..ba60694a5 100644 --- a/src/axolotl/integrations/kd/utils.py +++ b/src/axolotl/integrations/kd/utils.py @@ -1,7 +1,11 @@ """Helper KD utils""" +import math +from typing import List, Union + +import numpy as np import torch -from torch import FloatTensor +from torch import FloatTensor, Tensor def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor: @@ -37,3 +41,60 @@ def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor: final_logprobs_tensor = torch.log(teacher_probs_t_online) return final_logprobs_tensor + + +def strided_chunk_views( + tensor: Union[np.ndarray, torch.Tensor], + chunks: int, + dim: int = 0, + stride: int = 1, + chunk_size: int | None = None, +) -> List[Union[np.ndarray, torch.Tensor]]: + """ + Split a tensor into chunks along a dimension with striding, prioritizing views over copies. + + Args: + tensor: Input tensor (numpy array or torch tensor) + chunks: Number of chunks to create + dim: Dimension along which to chunk (default: 0) + stride: Stride between chunk starting positions (default: 1) + chunk_size: Size of each chunk. If None, calculated automatically (default: None) + + Returns: + List of tensor chunks (views when possible, copies when necessary) + """ + + # Get the size of the specified dimension + dim_size = tensor.shape[dim] + + # Calculate chunk size if not provided + if chunk_size is None: + chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division + + chunks_list = [] + + for i in range(chunks): + start_idx = i * stride + end_idx = min(start_idx + chunk_size, dim_size) + + # Break if we've gone beyond the tensor + if start_idx >= dim_size: + break + + # Create slice objects for all dimensions + slices = [slice(None)] * tensor.ndim + slices[dim] = slice(start_idx, end_idx) + + chunk = tensor[tuple(slices)] + chunks_list.append(chunk) + + return chunks_list + + +def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1): + dim_size = input_tensor.shape[dim] + stride = math.ceil(dim_size / chunks) + + return strided_chunk_views( + input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap + )