From 9fe36db215f96e05e58521724fa9eb6b3e722fda Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 30 Dec 2024 14:16:33 -0500
Subject: [PATCH] remove duplicate code

---
 src/axolotl/core/trainers/kd/__init__.py       | 110 ------------------
 .../core/trainers/kd/topk_logprob/__init__.py  |   0
 .../trainers/kd/topk_logprob/forward_kl.py     |  86 --------------
 3 files changed, 196 deletions(-)
 delete mode 100644 src/axolotl/core/trainers/kd/__init__.py
 delete mode 100644 src/axolotl/core/trainers/kd/topk_logprob/__init__.py
 delete mode 100644 src/axolotl/core/trainers/kd/topk_logprob/forward_kl.py

diff --git a/src/axolotl/core/trainers/kd/__init__.py b/src/axolotl/core/trainers/kd/__init__.py
deleted file mode 100644
index ad68055c9..000000000
--- a/src/axolotl/core/trainers/kd/__init__.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# Copyright 2024 Axolotl AI. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-KD trainer
-"""
-
-import torch
-
-from axolotl.core.trainers.base import AxolotlTrainer
-from axolotl.core.trainers.kd.topk_logprob.forward_kl import loss as topk_kd_loss
-
-
-class AxolotlKDTrainer(AxolotlTrainer):
-    """
-    Custom trainer subclass for Knowledge Distillation (KD)
-    """
-
-    def _set_signature_columns_if_needed(self):
-        super()._set_signature_columns_if_needed()
-        columns_to_add = []
-        if self._signature_columns:
-            if "target_logprobs" not in self._signature_columns:
-                columns_to_add.append("target_logprobs")
-            if "target_token_ids" not in self._signature_columns:
-                columns_to_add.append("target_token_ids")
-            if "target_mask" not in self._signature_columns:
-                columns_to_add.append("target_mask")
-            if columns_to_add:
-                self._signature_columns += columns_to_add
-
-    def compute_loss(
-        self,
-        model,
-        inputs,
-        return_outputs=False,
-        num_items_in_batch=None,
-        shift_targets=False,
-    ):
-        """
-        How the loss is computed by Trainer. By default, all models return the loss in the first element.
-
-        Subclass and override for custom behavior.
-        """
-
-        target_logprobs = inputs.pop("target_logprobs")
-        target_token_ids = inputs.pop("target_token_ids")
-        target_mask = inputs.pop("target_mask")
-
-        seq_len = target_token_ids.shape[1]
-
-        if self.model_accepts_loss_kwargs:
-            loss_kwargs = {}
-            if num_items_in_batch is not None:
-                loss_kwargs["num_items_in_batch"] = num_items_in_batch
-            inputs = {**inputs, **loss_kwargs}
-        outputs = model(**inputs)
-
-        # FIXME: account for tokenizer.padding_side
-        student_logits = outputs["logits"][:, :seq_len, :].contiguous()
-
-        if shift_targets:
-            shift_logits = student_logits[..., :-1, :].contiguous()
-            target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
-            target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
-            target_mask_for_loss = target_mask[..., 1:, :].contiguous()
-        else:
-            shift_logits = student_logits.contiguous()
-            target_logprobs_for_loss = target_logprobs.contiguous()
-            target_token_ids_for_loss = target_token_ids.contiguous()
-            target_mask_for_loss = target_mask.contiguous()
-
-        loss_kd = topk_kd_loss(
-            shift_logits,
-            target_token_ids_for_loss,
-            target_logprobs_for_loss,
-            target_mask_for_loss,
-            num_items_in_batch=num_items_in_batch,
-            kd_temperature=self.args.kd_temperature,
-        )
-
-        if self.args.kd_ce_alpha > 0:
-            kd_alpha = self.args.kd_alpha
-            loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
-        else:
-            loss = loss_kd
-        # Save past state if it exists
-        # TODO: this needs to be fixed and made cleaner later.
-        if self.args.past_index >= 0:
-            self._past = outputs[  # pylint: disable=attribute-defined-outside-init
-                self.args.past_index
-            ]
-
-        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
-            loss *= self.accelerator.num_processes
-
-        torch.cuda.empty_cache()
-
-        return (loss, outputs) if return_outputs else loss
diff --git a/src/axolotl/core/trainers/kd/topk_logprob/__init__.py b/src/axolotl/core/trainers/kd/topk_logprob/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/axolotl/core/trainers/kd/topk_logprob/forward_kl.py b/src/axolotl/core/trainers/kd/topk_logprob/forward_kl.py
deleted file mode 100644
index b050e066f..000000000
--- a/src/axolotl/core/trainers/kd/topk_logprob/forward_kl.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# Copyright 2024 Axolotl AI. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-loss for top_k KL divergence
-"""
-from typing import Optional
-
-import torch
-
-
-def loss(
-    student_logits,
-    target_token_ids,
-    target_logprobs,
-    target_mask,
-    num_items_in_batch: Optional[int] = None,
-    kd_temperature: float = 1.0,
-):
-    # teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding
-
-    # Determine the teacher sequence length
-    # _, teacher_seq_len, top_k = target_token_ids.shape
-    teacher_seq_len = target_token_ids.shape[1]
-
-    # Slice student logits to match the teacher-provided sequence length
-    student_logits_for_kd = student_logits[
-        :, :teacher_seq_len, :
-    ]  # [B, teacher_seq_len, vocab_size]
-
-    # Gather student logits for teacher's top-K tokens
-    # shape -> [B, teacher_seq_len, K]
-    student_logits_topk = torch.gather(
-        student_logits_for_kd, dim=-1, index=target_token_ids
-    )
-
-    # Apply KD temperature to student’s logits:
-    #   z_s(T) = z_s / T
-    if kd_temperature != 1.0:
-        student_logits_topk = student_logits_topk / kd_temperature
-
-    # Convert student top-k logits to logprobs
-    student_logprobs_topk = student_logits_topk - torch.logsumexp(
-        student_logits_topk, dim=-1, keepdim=True
-    )  # [B, seq_len, K]
-
-    # Convert teacher_mask to boolean for indexing
-    valid_mask = target_mask.bool()
-
-    # Prune tensors to only keep valid tokens
-    # This will result in 1D arrays of only valid positions
-    student_logprobs_topk = student_logprobs_topk[valid_mask]  # [N_valid_tokens]
-    target_logprobs = target_logprobs[valid_mask]  # [N_valid_tokens]
-
-    # Since teacher_logprobs are already normalized, just exponentiate to get probabilities
-    teacher_probs = target_logprobs.exp()
-
-    # Compute forward KL:
-    #   KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens.
-    kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
-    kd_loss = kd_loss_per_token.sum()
-
-    # 9) Multiply by T^2 (classical KD scaling)
-    if kd_temperature != 1.0:
-        kd_loss = kd_loss * (kd_temperature**2)
-
-    # Normalize by number of items or mean over valid tokens
-    if num_items_in_batch is not None:
-        # If you know how many items should be considered in the batch
-        kd_loss = kd_loss / num_items_in_batch
-    else:
-        # Otherwise, just average over all valid tokens
-        kd_loss = kd_loss / kd_loss_per_token.size(0)
-
-    return kd_loss
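For reference, the computation removed above can be exercised standalone. Below is a minimal sketch (not the axolotl API; the helper name topk_forward_kl and the tensor shapes are illustrative assumptions) of the top-k forward-KL loss the deleted forward_kl.py implemented, useful for sanity-checking the remaining copy of this code against the same math.

    # Minimal standalone sketch of the top-k forward KL, assuming:
    #   student_logits: [B, seq_len, vocab], target_*: [B, teacher_seq_len, K]
    # teacher logprobs are assumed already normalized over the K kept tokens.
    import torch

    def topk_forward_kl(student_logits, target_token_ids, target_logprobs,
                        target_mask, kd_temperature=1.0):
        seq_len = target_token_ids.shape[1]
        logits = student_logits[:, :seq_len, :]
        # student logits at the teacher's top-K token ids, temperature-scaled
        topk_logits = torch.gather(logits, dim=-1, index=target_token_ids) / kd_temperature
        topk_logprobs = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)
        mask = target_mask.bool()
        teacher_lp = target_logprobs[mask]
        student_lp = topk_logprobs[mask]
        # forward KL: sum p_teacher * (log p_teacher - log p_student), T^2 scaling
        kl = (teacher_lp.exp() * (teacher_lp - student_lp)).sum() * kd_temperature**2
        return kl / mask.sum()  # mean over valid (token, k) entries

    # Example with random tensors (shapes assumed for illustration only):
    B, S, K, V = 2, 8, 5, 32000
    loss_val = topk_forward_kl(
        torch.randn(B, S, V),
        torch.randint(0, V, (B, S, K)),
        torch.log_softmax(torch.randn(B, S, K), dim=-1),
        torch.ones(B, S, K, dtype=torch.long),
    )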