From 11caf52529a80d6509eaa968cce7199585c1bc78 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 18 Dec 2024 15:16:45 -0500
Subject: [PATCH] KD dataset loading and KD with logprobs

---
 src/axolotl/core/trainer_builder.py           |  15 +-
 src/axolotl/core/trainers/base.py             |   1 +
 src/axolotl/core/trainers/kd.py               | 131 ++++++++++++++++++
 .../prompt_strategies/chat_template.py        | 111 ++++++++++++++-
 .../config/models/input/v0_4_1/__init__.py    |   4 +
 5 files changed, 255 insertions(+), 7 deletions(-)
 create mode 100644 src/axolotl/core/trainers/kd.py

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index dfc750c44..828929dc7 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -16,7 +16,11 @@ from typing import List, Type, Union
 
 import torch
 import transformers
-from transformers import EarlyStoppingCallback, TrainerCallback, DataCollatorWithFlattening
+from transformers import (
+    DataCollatorWithFlattening,
+    EarlyStoppingCallback,
+    TrainerCallback,
+)
 from trl.trainer.utils import RewardDataCollatorWithPadding
 
 from axolotl.core.trainers.base import (
@@ -29,6 +33,7 @@ from axolotl.core.trainers.base import (
     AxolotlTrainer,
     ReLoRATrainer,
 )
+from axolotl.core.trainers.kd import AxolotlKDTrainer
 from axolotl.core.training_args import (
     AxolotlCPOConfig,
     AxolotlDPOConfig,
@@ -38,7 +43,6 @@ from axolotl.core.training_args import (
     AxolotlTrainingArguments,
 )
 from axolotl.integrations.base import PluginManager
-from axolotl.integrations.liger.trainer.dpo_trainer import AxolotlLigerDPOTrainer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback
 from axolotl.utils import is_comet_available, is_mlflow_available
@@ -282,6 +286,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             return AxolotlMambaTrainer
         if self.cfg.reward_model:
             return AxolotlRewardTrainer
+        if self.cfg.trainer == "kd":
+            return AxolotlKDTrainer
         return AxolotlTrainer
 
     def build(self, total_num_steps):
@@ -988,10 +994,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                     "precompute_ref_log_probs"
                 ] = self.cfg.precompute_ref_log_probs
         if self.cfg.rl in ["dpo", "ipo"]:
-            if self.cfg.liger_pref_rl:
-                trainer_cls = AxolotlLigerDPOTrainer
-            else:
-                trainer_cls = AxolotlDPOTrainer
+            trainer_cls = AxolotlDPOTrainer
             trainer_cls_args = [self.model, self.model_ref]
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 8014d67a7..d9b40c1c6 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -176,6 +176,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         self.bench_data_collator = bench_data_collator
         self.eval_data_collator = eval_data_collator
         self.dataset_tags = dataset_tags
+        self._signature_columns = None  # workaround for pylint
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py
new file mode 100644
index 000000000..e84036079
--- /dev/null
+++ b/src/axolotl/core/trainers/kd.py
@@ -0,0 +1,131 @@
+"""
+KD trainer
+"""
+
+from typing import Optional
+
+import torch
+
+from axolotl.core.trainers.base import AxolotlTrainer
+
+
+def kd_loss_function(
+    student_logits,
+    target_token_ids,
+    target_logprobs,
+    num_items_in_batch: Optional[int] = None,
+    **kwargs,  # pylint: disable=unused-argument
+):
+    """
+    Forward KL between the teacher's top-K distribution and the student's
+    distribution renormalized over those same K tokens.
+
+    Args:
+        student_logits: [B, seq_len, vocab_size] from the student's forward pass
+        target_token_ids: [B, teacher_seq_len, K] top-K token IDs from teacher
+        target_logprobs: [B, teacher_seq_len, K] teacher logprobs for these
+            top-K tokens
+        num_items_in_batch: when truthy, the summed loss is divided by this
+            value; otherwise the mean over all positions is returned
+        **kwargs: accepted for call-site compatibility and ignored
+
+    Returns:
+        Scalar loss tensor
+    """
+    teacher_seq_len = target_token_ids.shape[1]
+
+    # Slice the student logits to match the teacher-provided seq length
+    student_logits_for_kd = student_logits[
+        :, -teacher_seq_len:, :
+    ]  # Now [B, teacher_seq_len, vocab_size]
+
+    # Gather student logits for teacher's top-K tokens
+    student_logits_topk = torch.gather(
+        student_logits_for_kd, dim=-1, index=target_token_ids
+    )  # [B, teacher_seq_len, K]
+
+    # Convert student top-K logits to logprobs
+    student_logprobs_topk = student_logits_topk - torch.logsumexp(
+        student_logits_topk, dim=-1, keepdim=True
+    )
+
+    # teacher_probs are simply exp of teacher_logprobs (already scaled)
+    teacher_probs = target_logprobs.exp()
+
+    # Compute forward KL
+    # L_kl = sum_k p^T_k (log p^T_k - log p^S_k)
+    kd_loss_per_position = (
+        teacher_probs * (target_logprobs - student_logprobs_topk)
+    ).sum(
+        dim=-1
+    )  # [B, teacher_seq_len]
+
+    # gradient accumulation fixes
+    if num_items_in_batch:
+        kd_loss = kd_loss_per_position.sum() / num_items_in_batch  # Scalar
+    else:
+        kd_loss = kd_loss_per_position.mean()  # Scalar
+
+    return kd_loss
+
+
+class AxolotlKDTrainer(AxolotlTrainer):
+    """
+    Custom trainer subclass for Knowledge Distillation (KD)
+    """
+
+    def _set_signature_columns_if_needed(self):
+        """
+        Append the KD target columns to the signature columns set by the parent
+        (presumably so they are not dropped as unused dataset columns).
+        """
+        super()._set_signature_columns_if_needed()
+        columns_to_add = []
+        if self._signature_columns:
+            if "target_logprobs" not in self._signature_columns:
+                columns_to_add.append("target_logprobs")
+            if "target_token_ids" not in self._signature_columns:
+                columns_to_add.append("target_token_ids")
+            if columns_to_add:
+                self._signature_columns += columns_to_add
+
+    def compute_loss(
+        self, model, inputs, return_outputs=False, num_items_in_batch=None
+    ):
+        """
+        Compute the KD loss from the student's logits and the teacher targets.
+
+        `target_logprobs` and `target_token_ids` are popped from `inputs`
+        before the forward pass so they are not passed to the model.
+
+        Returns `(loss, outputs)` when `return_outputs` is set, else the loss.
+        """
+        target_logprobs = inputs.pop("target_logprobs")
+        target_token_ids = inputs.pop("target_token_ids")
+
+        if self.model_accepts_loss_kwargs:
+            loss_kwargs = {}
+            if num_items_in_batch is not None:
+                loss_kwargs["num_items_in_batch"] = num_items_in_batch
+            inputs = {**inputs, **loss_kwargs}
+        outputs = model(**inputs)
+
+        student_logits = outputs["logits"]
+        loss_kd = kd_loss_function(
+            student_logits,
+            target_token_ids,
+            target_logprobs,
+            num_items_in_batch=num_items_in_batch,
+        )
+
+        # Save past state if it exists
+        # TODO: this needs to be fixed and made cleaner later.
+        if self.args.past_index >= 0:
+            self._past = outputs[  # pylint: disable=attribute-defined-outside-init
+                self.args.past_index
+            ]
+
+        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
+            loss_kd *= self.accelerator.num_processes
+
+        return (loss_kd, outputs) if return_outputs else loss_kd
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 5b12130d7..c2a1060aa 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -5,6 +5,7 @@ HF Chat Templates prompt strategy
 import logging
 from typing import Any, Dict, List, Optional
 
+import torch
 from transformers import ProcessorMixin
 
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
@@ -459,6 +460,103 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         return prompt.get(self.images, None)
 
 
+class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
+    """
+    Handle fields for logprob KD
+    """
+
+    def __init__(
+        self,
+        prompter,
+        tokenizer,
+        train_on_inputs,
+        sequence_len,
+        roles_to_train=None,
+        train_on_eos=None,
+        logprobs_field="logprobs",
+        temperature=1.0,
+    ):
+        if temperature <= 0:
+            # dividing logprobs by a non-positive temperature yields inf/NaN or
+            # an inverted distribution rather than an error, so fail early
+            raise ValueError(f"temperature must be > 0, got {temperature}")
+        self.logprobs_field = logprobs_field
+        self.temperature = temperature
+        super().__init__(
+            prompter,
+            tokenizer,
+            train_on_inputs,
+            sequence_len,
+            roles_to_train=roles_to_train,
+            train_on_eos=train_on_eos,
+        )
+
+    def transform_logprobs(self, sample):
+        """
+        Replace the raw teacher logprobs in `sample` with `target_logprobs` and
+        `target_token_ids`, one list of top-K values per position.
+
+        Each entry is read as a mapping with a "logprob" value and a "token"
+        string in "token_id:###" format. Logprobs are temperature-scaled and
+        renormalized over the K entries of each position.
+        """
+        logprobs = sample.pop(self.logprobs_field)
+        target_logprobs = []
+        target_token_ids = []
+
+        for token_pos_logprobs in logprobs:
+            # Initialize collections for logprobs and token_ids
+            position_logprobs = []
+            position_token_ids = []
+
+            # Process each token probability entry
+            for entry in token_pos_logprobs:
+                # Extract logprob value
+                logprob = entry["logprob"]
+
+                # Parse token_id from the "token_id:###" format
+                token_id = int(entry["token"].split(":")[1])
+
+                # Append to our collections
+                position_logprobs.append(logprob)
+                position_token_ids.append(token_id)
+
+            # Convert to a tensor for easier manipulation
+            position_logprobs_tensor = torch.tensor(
+                position_logprobs, dtype=torch.float
+            )
+
+            # Apply temperature scaling at data load time
+            # log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
+            position_logprobs_tensor = position_logprobs_tensor / self.temperature
+            position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
+                position_logprobs_tensor, dim=0, keepdim=True
+            )
+
+            position_logprobs_scaled = position_logprobs_tensor.tolist()
+
+            target_logprobs.append(position_logprobs_scaled)
+            target_token_ids.append(position_token_ids)
+
+        # Update sample with transformed logprobs
+        sample["target_logprobs"] = target_logprobs
+        sample["target_token_ids"] = target_token_ids
+
+        return sample
+
+    def tokenize_prompt(self, prompt):
+        """
+        Tokenize via the parent strategy, then attach the KD target fields.
+        """
+        # NOTE(review): the parent's return value is not guaranteed to keep the
+        # raw logprobs field, so carry it over from the prompt when it is absent
+        raw_logprobs = prompt.get(self.logprobs_field)
+        tokenized_prompt = super().tokenize_prompt(prompt)
+        if raw_logprobs is not None and self.logprobs_field not in tokenized_prompt:
+            tokenized_prompt[self.logprobs_field] = raw_logprobs
+        return self.transform_logprobs(tokenized_prompt)
+
+
 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
     # pylint: disable=duplicate-code
     ds_cfg = ds_cfg or {}
@@ -491,7 +589,18 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
         "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
     }
 
-    strategy = ChatTemplateStrategy(
+    strategy_cls = ChatTemplateStrategy
+    logprobs_field = ds_cfg.get("logprobs_field")
+    if cfg.trainer == "kd" or logprobs_field:
+        strategy_cls = ChatTemplateStrategyWithKD
+        # only the KD strategy declares these keyword arguments, so they are
+        # forwarded only when it is the class being constructed
+        if logprobs_field:
+            strategy_params["logprobs_field"] = logprobs_field
+        if temperature := ds_cfg.get("temperature"):
+            strategy_params["temperature"] = temperature
+
+    strategy = strategy_cls(
         ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
     )
 
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 4f368994a..83ede4514 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -177,6 +177,8 @@ class SFTDataset(BaseModel):
     message_field_content: Optional[str] = None
     message_field_training: Optional[str] = None
     message_field_training_detail: Optional[str] = None
+    logprobs_field: Optional[str] = None
+    temperature: Optional[float] = None
     roles_to_train: Optional[List[str]] = None
     train_on_eos: Optional[str] = None
     roles: Optional[Dict[str, List[str]]] = None
@@ -621,6 +623,8 @@ class AxolotlInputConfig(
         bool
     ] = None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
 
+    trainer: Optional[Literal["kd"]] = None
+
     datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
     test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
     shuffle_merged_datasets: Optional[bool] = True