don't use triton for now
This commit is contained in:
@@ -100,18 +100,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
# Slice or gather student logits to match teacher seq len
|
# Slice or gather student logits to match teacher seq len
|
||||||
# e.g.:
|
# e.g.:
|
||||||
teacher_seq_len = target_token_ids.shape[1]
|
teacher_seq_len = target_token_ids.shape[1]
|
||||||
student_logits_for_kd = student_logits[:, :teacher_seq_len, :] # [B, seq_len, vocab_size]
|
student_logits_for_kd = student_logits[
|
||||||
|
:, :teacher_seq_len, :
|
||||||
|
] # [B, seq_len, vocab_size]
|
||||||
|
|
||||||
# GATHER top-K from student
|
# GATHER top-K from student
|
||||||
student_logits_topk = torch.gather(
|
student_logits_topk = torch.gather(
|
||||||
student_logits_for_kd, dim=-1, index=target_token_ids # same shape [B, seq_len, K]
|
student_logits_for_kd,
|
||||||
|
dim=-1,
|
||||||
|
index=target_token_ids, # same shape [B, seq_len, K]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now call the Triton-based KD loss
|
# Now call the Triton-based KD loss
|
||||||
kd_sum = kd_loss_triton(
|
kd_sum = kd_loss_triton(
|
||||||
student_logits_topk,
|
student_logits_topk,
|
||||||
target_logprobs, # teacher logprobs [B, seq_len, K]
|
target_logprobs, # teacher logprobs [B, seq_len, K]
|
||||||
target_mask, # mask [B, seq_len, K]
|
target_mask, # mask [B, seq_len, K]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Normalize however you want
|
# Normalize however you want
|
||||||
@@ -140,9 +144,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
Subclass and override for custom behavior.
|
Subclass and override for custom behavior.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.compute_loss_w_triton(
|
# return self.compute_loss_w_triton(
|
||||||
model, inputs, return_outputs, num_items_in_batch
|
# model, inputs, return_outputs, num_items_in_batch
|
||||||
)
|
# )
|
||||||
|
|
||||||
target_logprobs = inputs.pop("target_logprobs")
|
target_logprobs = inputs.pop("target_logprobs")
|
||||||
target_token_ids = inputs.pop("target_token_ids")
|
target_token_ids = inputs.pop("target_token_ids")
|
||||||
|
|||||||
Reference in New Issue
Block a user