refactor kd chat template loader

This commit is contained in:
Wing Lian
2024-12-30 12:57:11 -05:00
parent c51b0337c1
commit 27faacbf5a
7 changed files with 455 additions and 157 deletions

View File

@@ -0,0 +1,145 @@
"""
Chat template prompt strategy loader with KD support
"""
from typing import Any, Dict
import torch
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
"""
Handle fields for logprob KD
"""
def __init__(
self,
prompter,
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=None,
train_on_eos=None,
logprobs_field="logprobs",
gen_temperature=1.0,
kd_temperature=1.0,
):
self.logprobs_field = logprobs_field
self.gen_temperature = gen_temperature
self.kd_temperature = kd_temperature
super().__init__(
prompter,
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=roles_to_train,
train_on_eos=train_on_eos,
)
def transform_logprobs(self, sample):
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
top_k = len(logprobs[0])
target_logprobs = []
target_token_ids = []
target_mask = []
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(1, input_padding_len): # start at 1 since this is causal
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for _ in range(target_seq_len):
# TODO also check against sample["labels"]
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids
position_logprobs = []
position_token_ids = []
# Process each token probability entry
for entry in token_pos_logprobs:
# Extract logprob value
logprob = entry["logprob"]
# Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1])
# Append to our collections
position_logprobs.append(logprob)
position_token_ids.append(token_id)
# Convert to a tensor for easier manipulation
# Convert to tensor
position_logprobs_tensor = torch.tensor(
position_logprobs, dtype=torch.float
)
if self.kd_temperature != self.gen_temperature:
#
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist()
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
sample["target_mask"] = target_mask
return sample
def tokenize_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super().tokenize_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
class KDStrategyLoader(StrategyLoader):
"""
Load ChatTemplateStrategy with KD support using StrategyLoader.
"""
def _get_strategy_cls(self):
return ChatTemplateStrategyWithKD
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
strategy_params = super()._get_strategy_params(cfg, ds_cfg)
if logprobs_field := ds_cfg.get("logprobs_field"):
strategy_params["logprobs_field"] = logprobs_field
if gen_temperature := ds_cfg.get("temperature"):
strategy_params["gen_temperature"] = gen_temperature
if kd_temperature := cfg.get("kd_temperature"):
strategy_params["kd_temperature"] = kd_temperature
return strategy_params
load = KDStrategyLoader()

View File

@@ -0,0 +1,58 @@
### AXOLOTL COMMUNITY LICENSE AGREEMENT
This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
and conditions set forth in this Agreement.
1. Definitions
1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
which may be licensed separately by their respective authors and/or licensors.
1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
permits Plugin Integrations to integrate with the Axolotl service.
2. Grant of License
2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:
- Licensee must comply with all the terms and conditions of this Agreement.
- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial
portions of the Software.
2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.
3. Restrictions
3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for
free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such
third parties to fine-tune artificial intelligence models.
3.2 Licensee shall not:
- Use the Software for any illegal or unauthorized purpose.
- Reverse engineer, decompile, or disassemble the Software.
- Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the
Software or interfere with any third-party use of the Software.
3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.
4. Intellectual Property Rights
4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee
acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to
Licensee.
5. Disclaimer of Warranty
5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
6. Termination
6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and
conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any
copies in its possession.
7. Governing Law
7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,
without regards to conflicts of laws provisions thereof.
8. Entire Agreement
8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter
hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning
the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and
Licensees continued use of the Software after any such updates shall constitute acceptance of updated terms
on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any
material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be
bound by the terms and conditions of this Agreement.
This Agreement was last updated on August 23, 2024.

View File

@@ -0,0 +1,84 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""
loss for top_k KL divergence
"""
from typing import Optional
import torch
def loss(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch: Optional[int] = None,
kd_temperature: float = 1.0,
):
# teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding
# Determine the teacher sequence length
# _, teacher_seq_len, top_k = target_token_ids.shape
teacher_seq_len = target_token_ids.shape[1]
# Slice student logits to match the teacher-provided sequence length
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, teacher_seq_len, vocab_size]
# Gather student logits for teacher's top-K tokens
# shape -> [B, teacher_seq_len, K]
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
)
# Apply KD temperature to students logits:
# z_s(T) = z_s / T
if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature
# Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp(
student_logits_topk, dim=-1, keepdim=True
) # [B, seq_len, K]
# Convert teacher_mask to boolean for indexing
valid_mask = target_mask.bool()
# Prune tensors to only keep valid tokens
# This will result in 1D arrays of only valid positions
student_logprobs_topk = student_logprobs_topk[valid_mask] # [N_valid_tokens]
target_logprobs = target_logprobs[valid_mask] # [N_valid_tokens]
# Since teacher_logprobs are already normalized, just exponentiate to get probabilities
teacher_probs = target_logprobs.exp()
# Compute forward KL:
# KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens.
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum()
# 9) Multiply by T^2 (classical KD scaling)
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items or mean over valid tokens
if num_items_in_batch is not None:
# If you know how many items should be considered in the batch
kd_loss = kd_loss / num_items_in_batch
else:
# Otherwise, just average over all valid tokens
kd_loss = kd_loss / kd_loss_per_token.size(0)
return kd_loss

View File

@@ -0,0 +1,110 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
KD trainer
"""
import torch
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.core.trainers.kd.topk_logprob.forward_kl import loss as topk_kd_loss
class AxolotlKDTrainer(AxolotlTrainer):
"""
Custom trainer subclass for Knowledge Distillation (KD)
"""
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
columns_to_add = []
if self._signature_columns:
if "target_logprobs" not in self._signature_columns:
columns_to_add.append("target_logprobs")
if "target_token_ids" not in self._signature_columns:
columns_to_add.append("target_token_ids")
if "target_mask" not in self._signature_columns:
columns_to_add.append("target_mask")
if columns_to_add:
self._signature_columns += columns_to_add
def compute_loss(
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None,
shift_targets=False,
):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask")
seq_len = target_token_ids.shape[1]
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
# FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
if shift_targets:
shift_logits = student_logits[..., :-1, :].contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
else:
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs.contiguous()
target_token_ids_for_loss = target_token_ids.contiguous()
target_mask_for_loss = target_mask.contiguous()
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
)
if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
self.args.past_index
]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes
torch.cuda.empty_cache()
return (loss, outputs) if return_outputs else loss

View File

@@ -16,10 +16,18 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
load_fn = "load"
package = "axolotl.prompt_strategies"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
else:
try:
importlib.import_module(".".join(strategy.split(".")[:-1]))
package = ".".join(strategy.split(".")[:-1])
strategy = strategy.split(".")[-1]
except ModuleNotFoundError:
pass
mod = importlib.import_module(f".{strategy}", package)
func = getattr(mod, load_fn)
load_kwargs = {}
if strategy == "user_defined":

View File

@@ -5,7 +5,6 @@ HF Chat Templates prompt strategy
import logging
from typing import Any, Dict, List, Optional
import torch
from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
@@ -460,168 +459,62 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return prompt.get(self.images, None)
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
class StrategyLoader:
"""
Handle fields for logprob KD
Load chat template strategy based on configuration.
"""
def __init__(
self,
prompter,
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=None,
train_on_eos=None,
logprobs_field="logprobs",
gen_temperature=1.0,
kd_temperature=1.0,
def _get_strategy_cls(self):
return ChatTemplateStrategy
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
return {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
def __call__(
self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
):
self.logprobs_field = logprobs_field
self.gen_temperature = gen_temperature
self.kd_temperature = kd_temperature
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
super().__init__(
prompter,
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=roles_to_train,
train_on_eos=train_on_eos,
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail",
None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
}
strategy_params = self._get_strategy_params(cfg, ds_cfg)
strategy_cls = self._get_strategy_cls()
strategy = strategy_cls(
ChatTemplatePrompter(**prompter_params),
tokenizer=tokenizer,
**strategy_params,
)
def transform_logprobs(self, sample):
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
top_k = len(logprobs[0])
target_logprobs = []
target_token_ids = []
target_mask = []
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(1, input_padding_len): # start at 1 since this is causal
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for _ in range(target_seq_len):
# TODO also check against sample["labels"]
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids
position_logprobs = []
position_token_ids = []
# Process each token probability entry
for entry in token_pos_logprobs:
# Extract logprob value
logprob = entry["logprob"]
# Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1])
# Append to our collections
position_logprobs.append(logprob)
position_token_ids.append(token_id)
# Convert to a tensor for easier manipulation
# Convert to tensor
position_logprobs_tensor = torch.tensor(
position_logprobs, dtype=torch.float
)
if self.kd_temperature != self.gen_temperature:
#
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist()
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
sample["target_mask"] = target_mask
return sample
def tokenize_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super().tokenize_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
return strategy
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail",
None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
}
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
strategy_cls = ChatTemplateStrategy
if cfg.trainer == "kd":
strategy_cls = ChatTemplateStrategyWithKD
if logprobs_field := ds_cfg.get("logprobs_field"):
strategy_params["logprobs_field"] = logprobs_field
if gen_temperature := ds_cfg.get("temperature"):
strategy_params["gen_temperature"] = gen_temperature
if kd_temperature := cfg.get("kd_temperature"):
strategy_params["kd_temperature"] = kd_temperature
strategy = strategy_cls(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy
load = StrategyLoader()