don't use triton for now
This commit is contained in:
@@ -100,18 +100,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
# Slice or gather student logits to match teacher seq len
|
||||
# e.g.:
|
||||
teacher_seq_len = target_token_ids.shape[1]
|
||||
student_logits_for_kd = student_logits[:, :teacher_seq_len, :] # [B, seq_len, vocab_size]
|
||||
student_logits_for_kd = student_logits[
|
||||
:, :teacher_seq_len, :
|
||||
] # [B, seq_len, vocab_size]
|
||||
|
||||
# GATHER top-K from student
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids # same shape [B, seq_len, K]
|
||||
student_logits_for_kd,
|
||||
dim=-1,
|
||||
index=target_token_ids, # same shape [B, seq_len, K]
|
||||
)
|
||||
|
||||
# Now call the Triton-based KD loss
|
||||
kd_sum = kd_loss_triton(
|
||||
student_logits_topk,
|
||||
target_logprobs, # teacher logprobs [B, seq_len, K]
|
||||
target_mask, # mask [B, seq_len, K]
|
||||
target_logprobs, # teacher logprobs [B, seq_len, K]
|
||||
target_mask, # mask [B, seq_len, K]
|
||||
)
|
||||
|
||||
# Normalize however you want
|
||||
@@ -140,9 +144,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
|
||||
return self.compute_loss_w_triton(
|
||||
model, inputs, return_outputs, num_items_in_batch
|
||||
)
|
||||
# return self.compute_loss_w_triton(
|
||||
# model, inputs, return_outputs, num_items_in_batch
|
||||
# )
|
||||
|
||||
target_logprobs = inputs.pop("target_logprobs")
|
||||
target_token_ids = inputs.pop("target_token_ids")
|
||||
|
||||
Reference in New Issue
Block a user