Fix KD loss so it's causal: shift student logits and teacher targets by one position (fixes repeating tokens)
This commit is contained in:
@@ -164,6 +164,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
target_token_ids = inputs.pop("target_token_ids")
|
target_token_ids = inputs.pop("target_token_ids")
|
||||||
target_mask = inputs.pop("target_mask")
|
target_mask = inputs.pop("target_mask")
|
||||||
|
|
||||||
|
seq_len = target_token_ids.shape[1]
|
||||||
|
|
||||||
if self.model_accepts_loss_kwargs:
|
if self.model_accepts_loss_kwargs:
|
||||||
loss_kwargs = {}
|
loss_kwargs = {}
|
||||||
if num_items_in_batch is not None:
|
if num_items_in_batch is not None:
|
||||||
@@ -171,12 +173,19 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
inputs = {**inputs, **loss_kwargs}
|
inputs = {**inputs, **loss_kwargs}
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
student_logits = outputs["logits"]
|
# FIXME: account for tokenizer.padding_side
|
||||||
|
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
|
||||||
|
|
||||||
|
shift_logits = student_logits[..., :-1, :].contiguous()
|
||||||
|
shift_target_logprobs = target_logprobs[..., 1:, :].contiguous()
|
||||||
|
shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
|
||||||
|
shift_target_mask = target_mask[..., 1:, :].contiguous()
|
||||||
|
|
||||||
loss_kd = kd_loss_function(
|
loss_kd = kd_loss_function(
|
||||||
student_logits,
|
shift_logits,
|
||||||
target_token_ids,
|
shift_target_token_ids,
|
||||||
target_logprobs,
|
shift_target_logprobs,
|
||||||
target_mask,
|
shift_target_mask,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
kd_temperature=self.args.kd_temperature,
|
kd_temperature=self.args.kd_temperature,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user