don't use triton for now

Author: Wing Lian
Date:   2024-12-21 16:47:11 -05:00
Parent: c0757e8a20
Commit: d8d817eaed


@@ -100,18 +100,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
         # Slice or gather student logits to match teacher seq len
         # e.g.:
         teacher_seq_len = target_token_ids.shape[1]
-        student_logits_for_kd = student_logits[:, :teacher_seq_len, :] # [B, seq_len, vocab_size]
+        student_logits_for_kd = student_logits[
+            :, :teacher_seq_len, :
+        ]  # [B, seq_len, vocab_size]

         # GATHER top-K from student
         student_logits_topk = torch.gather(
-            student_logits_for_kd, dim=-1, index=target_token_ids # same shape [B, seq_len, K]
+            student_logits_for_kd,
+            dim=-1,
+            index=target_token_ids,  # same shape [B, seq_len, K]
         )

         # Now call the Triton-based KD loss
         kd_sum = kd_loss_triton(
             student_logits_topk,
-            target_logprobs, # teacher logprobs [B, seq_len, K]
-            target_mask, # mask [B, seq_len, K]
+            target_logprobs,  # teacher logprobs [B, seq_len, K]
+            target_mask,  # mask [B, seq_len, K]
         )

         # Normalize however you want
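The gather above relies on torch.gather's shape rule: the index tensor must have the same number of dimensions as the input, and the output takes the index's shape. A minimal standalone check (batch, sequence, vocab, and K sizes here are illustrative, not the repository's defaults):

import torch

# Illustrative sizes; the real values come from the student model and the
# teacher's dumped top-K logprobs.
B, seq_len, vocab_size, K = 2, 8, 32000, 64
student_logits = torch.randn(B, seq_len, vocab_size)
target_token_ids = torch.randint(0, vocab_size, (B, seq_len, K))

# A [B, seq_len, K] index picks, for each (batch, position), the K vocab
# entries named by the teacher; the result has the index's shape.
student_logits_topk = torch.gather(student_logits, dim=-1, index=target_token_ids)
assert student_logits_topk.shape == (B, seq_len, K)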
@@ -140,9 +144,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
         Subclass and override for custom behavior.
         """
-        return self.compute_loss_w_triton(
-            model, inputs, return_outputs, num_items_in_batch
-        )
+        # return self.compute_loss_w_triton(
+        #     model, inputs, return_outputs, num_items_in_batch
+        # )
         target_logprobs = inputs.pop("target_logprobs")
         target_token_ids = inputs.pop("target_token_ids")
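With the Triton call commented out, compute_loss falls through to an eager path that pops the teacher tensors from inputs. The actual loss used by the repository is not shown in this diff; as a rough pure-PyTorch stand-in, a forward-KL KD loss over the teacher's top-K tokens could look like the sketch below. The name kd_loss_topk_pytorch, the choice of forward KL, and the renormalization over the K gathered logits are assumptions for illustration, not the repository's implementation.

import torch
import torch.nn.functional as F

def kd_loss_topk_pytorch(
    student_logits_topk: torch.Tensor,  # [B, seq_len, K] gathered student logits
    target_logprobs: torch.Tensor,  # [B, seq_len, K] teacher logprobs
    target_mask: torch.Tensor,  # [B, seq_len, K] 1 where the entry counts
) -> torch.Tensor:
    # Renormalize the student over the K gathered tokens so both
    # distributions live on the same top-K support (an approximation:
    # student probability mass outside the top-K is discarded).
    student_logprobs_topk = F.log_softmax(student_logits_topk, dim=-1)
    teacher_probs = target_logprobs.exp()
    # Forward KL per entry: p_teacher * (log p_teacher - log p_student)
    kd_per_entry = teacher_probs * (target_logprobs - student_logprobs_topk)
    # Sum over unmasked entries; the caller normalizes however it wants,
    # e.g. by the number of unmasked tokens.
    return (kd_per_entry * target_mask).sum()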