fix KD loss so it's causal (fixes repeating tokens)

Wing Lian
2024-12-25 18:59:30 -05:00
parent 53ec07d44c
commit f03fa703b7

@@ -164,6 +164,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
         target_token_ids = inputs.pop("target_token_ids")
         target_mask = inputs.pop("target_mask")
 
+        seq_len = target_token_ids.shape[1]
+
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
@@ -171,12 +173,19 @@ class AxolotlKDTrainer(AxolotlTrainer):
             inputs = {**inputs, **loss_kwargs}
         outputs = model(**inputs)
-        student_logits = outputs["logits"]
+
+        # FIXME: account for tokenizer.padding_side
+        student_logits = outputs["logits"][:, :seq_len, :].contiguous()
+
+        shift_logits = student_logits[..., :-1, :].contiguous()
+        shift_target_logprobs = target_logprobs[..., 1:, :].contiguous()
+        shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
+        shift_target_mask = target_mask[..., 1:, :].contiguous()
 
         loss_kd = kd_loss_function(
-            student_logits,
-            target_token_ids,
-            target_logprobs,
-            target_mask,
+            shift_logits,
+            shift_target_token_ids,
+            shift_target_logprobs,
+            shift_target_mask,
             num_items_in_batch=num_items_in_batch,
             kd_temperature=self.args.kd_temperature,
         )
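
For context, below is a minimal sketch of how a top-k KD loss would consume these shifted tensors. It is an illustration, not axolotl's actual kd_loss_function: the name kd_loss_topk, the exact shapes, and the assumption that the stored teacher logprobs are already at the desired temperature are all hypothetical. The shift in the diff mirrors the standard causal LM label shift: student logits at position t are scored against the teacher's target distribution for the token at position t+1. Without the shift, the student was being distilled toward the distribution of the token it had just seen, which plausibly explains the repeating-token behavior the commit message mentions.

import torch
import torch.nn.functional as F


def kd_loss_topk(
    shift_logits,            # student logits after the causal shift: [B, T-1, V]
    shift_target_token_ids,  # teacher top-k token ids:               [B, T-1, K]
    shift_target_logprobs,   # teacher top-k logprobs:                [B, T-1, K]
    shift_target_mask,       # 1 for supervised positions, else 0:    [B, T-1, K]
    num_items_in_batch=None,
    kd_temperature=1.0,
):
    # Temperature-scale the student distribution, then gather the student
    # log-probabilities at the teacher's top-k token ids.
    student_logprobs = F.log_softmax(shift_logits / kd_temperature, dim=-1)
    student_topk = torch.gather(
        student_logprobs, dim=-1, index=shift_target_token_ids.long()
    )

    # Forward KL restricted to the teacher's top-k support:
    # sum_k p_teacher * (log p_teacher - log p_student), masked to valid
    # positions. Assumes the stored teacher logprobs are already at the
    # desired temperature (the real implementation may rescale them).
    teacher_probs = shift_target_logprobs.exp()
    kl = teacher_probs * (shift_target_logprobs - student_topk) * shift_target_mask

    denom = (
        num_items_in_batch
        if num_items_in_batch is not None
        else shift_target_mask.sum().clamp(min=1)
    )
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
    return kl.sum() * (kd_temperature**2) / denom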