Don't use the Triton KD loss for now (comment out the compute_loss_w_triton call)

This commit is contained in:
Wing Lian
2024-12-21 16:47:11 -05:00
parent c0757e8a20
commit d8d817eaed

View File

@@ -100,18 +100,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
# Slice or gather student logits to match teacher seq len # Slice or gather student logits to match teacher seq len
# e.g.: # e.g.:
teacher_seq_len = target_token_ids.shape[1] teacher_seq_len = target_token_ids.shape[1]
student_logits_for_kd = student_logits[:, :teacher_seq_len, :] # [B, seq_len, vocab_size] student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, seq_len, vocab_size]
# GATHER top-K from student # GATHER top-K from student
student_logits_topk = torch.gather( student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids # same shape [B, seq_len, K] student_logits_for_kd,
dim=-1,
index=target_token_ids, # same shape [B, seq_len, K]
) )
# Now call the Triton-based KD loss # Now call the Triton-based KD loss
kd_sum = kd_loss_triton( kd_sum = kd_loss_triton(
student_logits_topk, student_logits_topk,
target_logprobs, # teacher logprobs [B, seq_len, K] target_logprobs, # teacher logprobs [B, seq_len, K]
target_mask, # mask [B, seq_len, K] target_mask, # mask [B, seq_len, K]
) )
# Normalize however you want # Normalize however you want
@@ -140,9 +144,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
Subclass and override for custom behavior. Subclass and override for custom behavior.
""" """
return self.compute_loss_w_triton( # return self.compute_loss_w_triton(
model, inputs, return_outputs, num_items_in_batch # model, inputs, return_outputs, num_items_in_batch
) # )
target_logprobs = inputs.pop("target_logprobs") target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids") target_token_ids = inputs.pop("target_token_ids")