fix collator setup
This commit is contained in:
@@ -439,6 +439,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
collator_args = [self.tokenizer]
|
||||
if self.cfg.reward_model:
|
||||
collator = RewardDataCollatorWithPadding
|
||||
elif self.cfg.kd_trainer:
|
||||
from axolotl.integrations.kd.collator import (
|
||||
DataCollatorForKD,
|
||||
KDBatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
|
||||
if self.cfg.sample_packing and use_batch_sampler_collator:
|
||||
collator = KDBatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
collator = DataCollatorForKD
|
||||
elif use_batch_sampler_collator:
|
||||
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
||||
# supported multipack models, or non-flash-attention llama
|
||||
@@ -468,16 +478,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
collator_args.pop(0)
|
||||
kwargs.pop("pad_to_multiple_of", None)
|
||||
kwargs.pop("padding", None)
|
||||
elif self.cfg.kd_trainer:
|
||||
from axolotl.integrations.kd.collator import (
|
||||
DataCollatorForKD,
|
||||
KDBatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
|
||||
if self.cfg.sample_packing:
|
||||
collator = KDBatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
collator = DataCollatorForKD
|
||||
else:
|
||||
collator = DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
loss for top_k KL divergence
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def zscore_standardize(
|
||||
@@ -235,3 +236,76 @@ def topk_kd_loss_with_zscore(
|
||||
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
|
||||
|
||||
return kd_loss
|
||||
|
||||
|
||||
class ChunkedTopKKDLoss(nn.Module):
|
||||
"""
|
||||
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
|
||||
to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
|
||||
|
||||
Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to your top-K teacher logprobs.
|
||||
"""
|
||||
|
||||
def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
|
||||
super().__init__()
|
||||
self.num_output_chunks = num_output_chunks
|
||||
self.kd_temperature = kd_temperature
|
||||
|
||||
def forward(
|
||||
self,
|
||||
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
||||
target_token_ids: torch.Tensor, # [B, seq_len, K]
|
||||
target_logprobs: torch.Tensor, # [B, seq_len, K]
|
||||
target_mask: torch.Tensor, # [B, seq_len, K]
|
||||
num_items_in_batch: int = -1, # optional batch size for normalization
|
||||
) -> torch.Tensor:
|
||||
|
||||
# 1. Split along the "token" dimension (dim=1).
|
||||
student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1)
|
||||
token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1)
|
||||
logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1)
|
||||
mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1)
|
||||
|
||||
# We'll accumulate a global "sum of losses" and "sum of valid tokens"
|
||||
# so that our final average is consistent with the entire sequence/batch.
|
||||
total_loss = 0.0
|
||||
total_valid_tokens = 0
|
||||
|
||||
# 2. Loop over each chunk and compute a chunk-specific loss.
|
||||
for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip(
|
||||
student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks
|
||||
):
|
||||
# We pass num_items_in_batch=-1 so that the kd_loss
|
||||
# will average over *this chunk's* valid tokens only.
|
||||
chunk_loss = loss(
|
||||
student_logits=st_chunk,
|
||||
target_token_ids=tid_chunk,
|
||||
target_logprobs=lp_chunk,
|
||||
target_mask=msk_chunk,
|
||||
num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens
|
||||
kd_temperature=self.kd_temperature,
|
||||
)
|
||||
|
||||
# kd_loss returns an average over the chunk's valid tokens.
|
||||
# We want a global average in the end, so we need to re‐weight
|
||||
# by the number of valid tokens in this chunk and keep track of the total.
|
||||
chunk_valid_mask = msk_chunk.to(torch.bool)
|
||||
chunk_valid_count = chunk_valid_mask.sum() # scalar tensor
|
||||
|
||||
# Re-scale "chunk average" back to "chunk sum"
|
||||
chunk_loss_sum = chunk_loss * chunk_valid_count
|
||||
|
||||
total_loss += chunk_loss_sum
|
||||
total_valid_tokens += chunk_valid_count
|
||||
|
||||
# 3. Normalize *once* at the end.
|
||||
if num_items_in_batch > 0:
|
||||
# If the user gave us a manual denominator (e.g. total items in batch),
|
||||
# we divide by it. Typically used if each item is of different length.
|
||||
final_loss = total_loss / float(num_items_in_batch)
|
||||
else:
|
||||
# Otherwise, divide by total valid tokens across all chunks.
|
||||
# to get the same result as a non-chunked approach.
|
||||
final_loss = total_loss / float(total_valid_tokens)
|
||||
|
||||
return final_loss
|
||||
|
||||
@@ -18,8 +18,7 @@ KD trainer
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
||||
from .topk_logprob.forward_kl import ChunkedTopKKDLoss, topk_kd_loss_with_zscore
|
||||
|
||||
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
|
||||
@@ -27,6 +26,13 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
Custom trainer subclass for Knowledge Distillation (KD)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model_accepts_loss_kwargs = True
|
||||
self.kd_loss_fn = ChunkedTopKKDLoss(
|
||||
num_output_chunks=8, kd_temperature=self.args.kd_temperature
|
||||
)
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
super()._set_signature_columns_if_needed()
|
||||
columns_to_add = []
|
||||
@@ -85,14 +91,14 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
else:
|
||||
loss_kd = topk_kd_loss(
|
||||
shift_logits,
|
||||
loss_kd = self.kd_loss_fn(
|
||||
# shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
||||
# top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
||||
)
|
||||
|
||||
if self.args.kd_ce_alpha > 0:
|
||||
|
||||
Reference in New Issue
Block a user