remove duplicate code
This commit is contained in:
@@ -1,110 +0,0 @@
|
|||||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""
|
|
||||||
KD trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
|
||||||
from axolotl.core.trainers.kd.topk_logprob.forward_kl import loss as topk_kd_loss
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKDTrainer(AxolotlTrainer):
|
|
||||||
"""
|
|
||||||
Custom trainer subclass for Knowledge Distillation (KD)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _set_signature_columns_if_needed(self):
|
|
||||||
super()._set_signature_columns_if_needed()
|
|
||||||
columns_to_add = []
|
|
||||||
if self._signature_columns:
|
|
||||||
if "target_logprobs" not in self._signature_columns:
|
|
||||||
columns_to_add.append("target_logprobs")
|
|
||||||
if "target_token_ids" not in self._signature_columns:
|
|
||||||
columns_to_add.append("target_token_ids")
|
|
||||||
if "target_mask" not in self._signature_columns:
|
|
||||||
columns_to_add.append("target_mask")
|
|
||||||
if columns_to_add:
|
|
||||||
self._signature_columns += columns_to_add
|
|
||||||
|
|
||||||
def compute_loss(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
inputs,
|
|
||||||
return_outputs=False,
|
|
||||||
num_items_in_batch=None,
|
|
||||||
shift_targets=False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
|
||||||
|
|
||||||
Subclass and override for custom behavior.
|
|
||||||
"""
|
|
||||||
|
|
||||||
target_logprobs = inputs.pop("target_logprobs")
|
|
||||||
target_token_ids = inputs.pop("target_token_ids")
|
|
||||||
target_mask = inputs.pop("target_mask")
|
|
||||||
|
|
||||||
seq_len = target_token_ids.shape[1]
|
|
||||||
|
|
||||||
if self.model_accepts_loss_kwargs:
|
|
||||||
loss_kwargs = {}
|
|
||||||
if num_items_in_batch is not None:
|
|
||||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
|
||||||
inputs = {**inputs, **loss_kwargs}
|
|
||||||
outputs = model(**inputs)
|
|
||||||
|
|
||||||
# FIXME: account for tokenizer.padding_side
|
|
||||||
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
|
|
||||||
|
|
||||||
if shift_targets:
|
|
||||||
shift_logits = student_logits[..., :-1, :].contiguous()
|
|
||||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
|
||||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
|
||||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
|
||||||
else:
|
|
||||||
shift_logits = student_logits.contiguous()
|
|
||||||
target_logprobs_for_loss = target_logprobs.contiguous()
|
|
||||||
target_token_ids_for_loss = target_token_ids.contiguous()
|
|
||||||
target_mask_for_loss = target_mask.contiguous()
|
|
||||||
|
|
||||||
loss_kd = topk_kd_loss(
|
|
||||||
shift_logits,
|
|
||||||
target_token_ids_for_loss,
|
|
||||||
target_logprobs_for_loss,
|
|
||||||
target_mask_for_loss,
|
|
||||||
num_items_in_batch=num_items_in_batch,
|
|
||||||
kd_temperature=self.args.kd_temperature,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.kd_ce_alpha > 0:
|
|
||||||
kd_alpha = self.args.kd_alpha
|
|
||||||
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
|
||||||
else:
|
|
||||||
loss = loss_kd
|
|
||||||
# Save past state if it exists
|
|
||||||
# TODO: this needs to be fixed and made cleaner later.
|
|
||||||
if self.args.past_index >= 0:
|
|
||||||
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.args.past_index
|
|
||||||
]
|
|
||||||
|
|
||||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
|
||||||
loss *= self.accelerator.num_processes
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return (loss, outputs) if return_outputs else loss
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""
|
|
||||||
loss for top_k KL divergence
|
|
||||||
"""
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def loss(
|
|
||||||
student_logits,
|
|
||||||
target_token_ids,
|
|
||||||
target_logprobs,
|
|
||||||
target_mask,
|
|
||||||
num_items_in_batch: Optional[int] = None,
|
|
||||||
kd_temperature: float = 1.0,
|
|
||||||
):
|
|
||||||
# teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding
|
|
||||||
|
|
||||||
# Determine the teacher sequence length
|
|
||||||
# _, teacher_seq_len, top_k = target_token_ids.shape
|
|
||||||
teacher_seq_len = target_token_ids.shape[1]
|
|
||||||
|
|
||||||
# Slice student logits to match the teacher-provided sequence length
|
|
||||||
student_logits_for_kd = student_logits[
|
|
||||||
:, :teacher_seq_len, :
|
|
||||||
] # [B, teacher_seq_len, vocab_size]
|
|
||||||
|
|
||||||
# Gather student logits for teacher's top-K tokens
|
|
||||||
# shape -> [B, teacher_seq_len, K]
|
|
||||||
student_logits_topk = torch.gather(
|
|
||||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply KD temperature to student’s logits:
|
|
||||||
# z_s(T) = z_s / T
|
|
||||||
if kd_temperature != 1.0:
|
|
||||||
student_logits_topk = student_logits_topk / kd_temperature
|
|
||||||
|
|
||||||
# Convert student top-k logits to logprobs
|
|
||||||
student_logprobs_topk = student_logits_topk - torch.logsumexp(
|
|
||||||
student_logits_topk, dim=-1, keepdim=True
|
|
||||||
) # [B, seq_len, K]
|
|
||||||
|
|
||||||
# Convert teacher_mask to boolean for indexing
|
|
||||||
valid_mask = target_mask.bool()
|
|
||||||
|
|
||||||
# Prune tensors to only keep valid tokens
|
|
||||||
# This will result in 1D arrays of only valid positions
|
|
||||||
student_logprobs_topk = student_logprobs_topk[valid_mask] # [N_valid_tokens]
|
|
||||||
target_logprobs = target_logprobs[valid_mask] # [N_valid_tokens]
|
|
||||||
|
|
||||||
# Since teacher_logprobs are already normalized, just exponentiate to get probabilities
|
|
||||||
teacher_probs = target_logprobs.exp()
|
|
||||||
|
|
||||||
# Compute forward KL:
|
|
||||||
# KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens.
|
|
||||||
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
|
|
||||||
kd_loss = kd_loss_per_token.sum()
|
|
||||||
|
|
||||||
# 9) Multiply by T^2 (classical KD scaling)
|
|
||||||
if kd_temperature != 1.0:
|
|
||||||
kd_loss = kd_loss * (kd_temperature**2)
|
|
||||||
|
|
||||||
# Normalize by number of items or mean over valid tokens
|
|
||||||
if num_items_in_batch is not None:
|
|
||||||
# If you know how many items should be considered in the batch
|
|
||||||
kd_loss = kd_loss / num_items_in_batch
|
|
||||||
else:
|
|
||||||
# Otherwise, just average over all valid tokens
|
|
||||||
kd_loss = kd_loss / kd_loss_per_token.size(0)
|
|
||||||
|
|
||||||
return kd_loss
|
|
||||||
Reference in New Issue
Block a user