chunking not necessary
This commit is contained in:
@@ -465,18 +465,17 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
|
|
||||||
|
|
||||||
# Wrapper function for chunked computation
|
# Wrapper function for chunked computation
|
||||||
def kl_div_loss_chunked(
|
def loss(
|
||||||
student_logits,
|
student_logits: torch.Tensor,
|
||||||
target_token_ids,
|
target_token_ids: torch.Tensor,
|
||||||
target_logprobs,
|
target_logprobs: torch.Tensor,
|
||||||
target_mask,
|
target_mask: torch.Tensor,
|
||||||
num_items_in_batch=-1,
|
num_items_in_batch: int = -1,
|
||||||
kd_temperature=1.0,
|
kd_temperature: float = 1.0,
|
||||||
top_k_before_softmax=0,
|
top_k_before_softmax: int = 0,
|
||||||
n_chunks=1,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Memory-efficient KL divergence loss computation.
|
Triton-accelerated Memory-efficient KL divergence loss computation for knowledge distillation
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
student_logits: Student logits [B, seq_len, vocab_size]
|
student_logits: Student logits [B, seq_len, vocab_size]
|
||||||
@@ -486,81 +485,15 @@ def kl_div_loss_chunked(
|
|||||||
num_items_in_batch: Number of items for normalization (-1 for auto)
|
num_items_in_batch: Number of items for normalization (-1 for auto)
|
||||||
kd_temperature: Temperature for KD
|
kd_temperature: Temperature for KD
|
||||||
top_k_before_softmax: Flag for softmax application order
|
top_k_before_softmax: Flag for softmax application order
|
||||||
n_chunks: Number of chunks to process (for memory efficiency)
|
|
||||||
"""
|
"""
|
||||||
batch_size = student_logits.shape[0]
|
total_loss = TopKKLDivergence.apply(
|
||||||
|
|
||||||
# If n_chunks <= 0, use the entire batch size
|
|
||||||
if n_chunks <= 0:
|
|
||||||
n_chunks = batch_size
|
|
||||||
|
|
||||||
# Determine the actual number of chunks to use (find largest factor <= n_chunks)
|
|
||||||
factors = [i for i in range(1, batch_size + 1) if batch_size % i == 0]
|
|
||||||
actual_chunks = factors[
|
|
||||||
min(
|
|
||||||
len(factors) - 1,
|
|
||||||
max(
|
|
||||||
0,
|
|
||||||
next(
|
|
||||||
(i for i, f in enumerate(factors) if f >= n_chunks),
|
|
||||||
len(factors) - 1,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Compute chunk size
|
|
||||||
chunk_size = batch_size // actual_chunks
|
|
||||||
total_loss = 0.0
|
|
||||||
|
|
||||||
# Process in chunks
|
|
||||||
for i in range(0, batch_size, chunk_size):
|
|
||||||
chunk_end = min(i + chunk_size, batch_size)
|
|
||||||
chunk_loss = TopKKLDivergence.apply(
|
|
||||||
student_logits[i:chunk_end],
|
|
||||||
target_token_ids[i:chunk_end],
|
|
||||||
target_logprobs[i:chunk_end],
|
|
||||||
target_mask[i:chunk_end],
|
|
||||||
-1 if num_items_in_batch <= 0 else num_items_in_batch // actual_chunks,
|
|
||||||
kd_temperature,
|
|
||||||
top_k_before_softmax,
|
|
||||||
)
|
|
||||||
total_loss += chunk_loss
|
|
||||||
|
|
||||||
# Normalize by the number of chunks
|
|
||||||
return total_loss / actual_chunks
|
|
||||||
|
|
||||||
|
|
||||||
def loss(
|
|
||||||
student_logits: torch.Tensor,
|
|
||||||
target_token_ids: torch.Tensor,
|
|
||||||
target_logprobs: torch.Tensor,
|
|
||||||
target_mask: torch.Tensor,
|
|
||||||
num_items_in_batch: int = -1,
|
|
||||||
kd_temperature: float = 1.0,
|
|
||||||
top_k_before_softmax: int = 0,
|
|
||||||
n_chunks: int = 1,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Triton-accelerated KL divergence loss for knowledge distillation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
student_logits: Student model logits [B, seq_len, vocab_size]
|
|
||||||
target_token_ids: Teacher's top-k token IDs [B, seq_len, top_k]
|
|
||||||
target_logprobs: Teacher's top-k logprobs [B, seq_len, top_k]
|
|
||||||
target_mask: Mask for valid tokens [B, seq_len, top_k]
|
|
||||||
num_items_in_batch: Number of items for normalization (-1 for auto)
|
|
||||||
kd_temperature: Temperature for KD
|
|
||||||
top_k_before_softmax: Flag for softmax application order
|
|
||||||
n_chunks: Number of chunks for memory efficiency
|
|
||||||
"""
|
|
||||||
return kl_div_loss_chunked(
|
|
||||||
student_logits,
|
student_logits,
|
||||||
target_token_ids,
|
target_token_ids,
|
||||||
target_logprobs,
|
target_logprobs,
|
||||||
target_mask,
|
target_mask,
|
||||||
num_items_in_batch,
|
-1 if num_items_in_batch <= 0 else num_items_in_batch,
|
||||||
kd_temperature,
|
kd_temperature,
|
||||||
top_k_before_softmax,
|
top_k_before_softmax,
|
||||||
n_chunks,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return total_loss
|
||||||
|
|||||||
Reference in New Issue
Block a user