fix kd loss so it's causal (fixes repeating tokens)
@@ -164,6 +164,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
         target_token_ids = inputs.pop("target_token_ids")
         target_mask = inputs.pop("target_mask")
 
+        seq_len = target_token_ids.shape[1]
+
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
@@ -171,12 +173,19 @@ class AxolotlKDTrainer(AxolotlTrainer):
             inputs = {**inputs, **loss_kwargs}
         outputs = model(**inputs)
 
-        student_logits = outputs["logits"]
+        # FIXME: account for tokenizer.padding_side
+        student_logits = outputs["logits"][:, :seq_len, :].contiguous()
+
+        shift_logits = student_logits[..., :-1, :].contiguous()
+        shift_target_logprobs = target_logprobs[..., 1:, :].contiguous()
+        shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
+        shift_target_mask = target_mask[..., 1:, :].contiguous()
+
         loss_kd = kd_loss_function(
-            student_logits,
-            target_token_ids,
-            target_logprobs,
-            target_mask,
+            shift_logits,
+            shift_target_token_ids,
+            shift_target_logprobs,
+            shift_target_mask,
             num_items_in_batch=num_items_in_batch,
             kd_temperature=self.args.kd_temperature,
         )
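
Why the shift makes the loss causal: in a decoder-only model, the logits at position t are the prediction for token t+1, so the teacher's top-k targets must sit one position ahead of the student logits (drop the last logit step, drop the first target step), exactly as in the standard causal-LM label shift. Below is a minimal sketch of a KD loss with this shift applied; the tensor shapes and the forward-KL-on-top-k loss body are assumptions for illustration, not axolotl's actual kd_loss_function, and only the shift itself mirrors the diff.

    # Minimal sketch, not axolotl's kd_loss_function: assumed shapes are
    # student_logits [batch, seq, vocab] and teacher tensors [batch, seq, top_k].
    import torch
    import torch.nn.functional as F

    def kd_loss_causal(
        student_logits: torch.Tensor,    # [B, T, V] student logits
        target_token_ids: torch.Tensor,  # [B, T, K] teacher top-k token ids
        target_logprobs: torch.Tensor,   # [B, T, K] teacher top-k log-probs
        target_mask: torch.Tensor,       # [B, T, K] 1 = supervised position
        kd_temperature: float = 1.0,
    ) -> torch.Tensor:
        # Causal shift: logit at position t predicts token t+1, so drop the
        # last logit step and the first target step before comparing.
        logits = student_logits[..., :-1, :] / kd_temperature
        ids = target_token_ids[..., 1:, :]
        teacher_lp = target_logprobs[..., 1:, :]
        mask = target_mask[..., 1:, :].bool()

        # Student log-probs gathered at the teacher's top-k token ids.
        student_lp = F.log_softmax(logits, dim=-1).gather(-1, ids)

        # Forward KL restricted to the teacher's top-k support:
        # sum_k p_teacher * (log p_teacher - log p_student), masked and averaged.
        teacher_p = teacher_lp.exp()
        kl = teacher_p * (teacher_lp - student_lp)
        return (kl * mask).sum() / mask.sum().clamp(min=1)

Without the shift, the student is rewarded for predicting the token at its own position, i.e. for copying its input; at inference that bias surfaces as the repeated-token loops the commit message refers to.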