refactor kd chat template loader
This commit is contained in:
145
src/axolotl/integrations/kd/chat_template.py
Normal file
145
src/axolotl/integrations/kd/chat_template.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Chat template prompt strategy loader with KD support
|
||||
"""
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
|
||||
|
||||
|
||||
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
"""
|
||||
Handle fields for logprob KD
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter,
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
roles_to_train=None,
|
||||
train_on_eos=None,
|
||||
logprobs_field="logprobs",
|
||||
gen_temperature=1.0,
|
||||
kd_temperature=1.0,
|
||||
):
|
||||
self.logprobs_field = logprobs_field
|
||||
self.gen_temperature = gen_temperature
|
||||
self.kd_temperature = kd_temperature
|
||||
|
||||
super().__init__(
|
||||
prompter,
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
roles_to_train=roles_to_train,
|
||||
train_on_eos=train_on_eos,
|
||||
)
|
||||
|
||||
def transform_logprobs(self, sample):
|
||||
logprobs = sample.pop(self.logprobs_field)
|
||||
target_seq_len = len(logprobs)
|
||||
input_seq_len = len(sample["input_ids"])
|
||||
input_padding_len = input_seq_len - target_seq_len
|
||||
top_k = len(logprobs[0])
|
||||
target_logprobs = []
|
||||
target_token_ids = []
|
||||
target_mask = []
|
||||
|
||||
# fill with -inf for padding_len tokens for top_k tokens
|
||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||
for _ in range(1, input_padding_len): # start at 1 since this is causal
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
for _ in range(target_seq_len):
|
||||
# TODO also check against sample["labels"]
|
||||
target_mask.append([1] * top_k)
|
||||
|
||||
for _, token_pos_logprobs in enumerate(logprobs):
|
||||
# Initialize collections for logprobs and token_ids
|
||||
position_logprobs = []
|
||||
position_token_ids = []
|
||||
|
||||
# Process each token probability entry
|
||||
for entry in token_pos_logprobs:
|
||||
# Extract logprob value
|
||||
logprob = entry["logprob"]
|
||||
|
||||
# Parse token_id from the "token_id:###" format
|
||||
token_id = int(entry["token"].split(":")[1])
|
||||
|
||||
# Append to our collections
|
||||
position_logprobs.append(logprob)
|
||||
position_token_ids.append(token_id)
|
||||
|
||||
# Convert to a tensor for easier manipulation
|
||||
# Convert to tensor
|
||||
position_logprobs_tensor = torch.tensor(
|
||||
position_logprobs, dtype=torch.float
|
||||
)
|
||||
|
||||
if self.kd_temperature != self.gen_temperature:
|
||||
#
|
||||
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
|
||||
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
|
||||
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
|
||||
#
|
||||
# Convert from log to probability
|
||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||
# Exponentiate by factor (T1 / T2)
|
||||
exponent = self.gen_temperature / self.kd_temperature
|
||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||
# Re-normalize
|
||||
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||
dim=0, keepdim=True
|
||||
)
|
||||
# Convert back to log
|
||||
position_logprobs_tensor = torch.log(teacher_probs_t2)
|
||||
|
||||
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
|
||||
position_logprobs_scaled = position_logprobs_tensor.tolist()
|
||||
|
||||
target_logprobs.append(position_logprobs_scaled)
|
||||
target_token_ids.append(position_token_ids)
|
||||
|
||||
# Update sample with transformed logprobs
|
||||
sample["target_logprobs"] = target_logprobs
|
||||
sample["target_token_ids"] = target_token_ids
|
||||
sample["target_mask"] = target_mask
|
||||
|
||||
return sample
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
logprobs = prompt.pop(self.logprobs_field)
|
||||
tokenized_prompt = super().tokenize_prompt(prompt)
|
||||
tokenized_prompt[self.logprobs_field] = logprobs
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class KDStrategyLoader(StrategyLoader):
|
||||
"""
|
||||
Load ChatTemplateStrategy with KD support using StrategyLoader.
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self):
|
||||
return ChatTemplateStrategyWithKD
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
strategy_params = super()._get_strategy_params(cfg, ds_cfg)
|
||||
if logprobs_field := ds_cfg.get("logprobs_field"):
|
||||
strategy_params["logprobs_field"] = logprobs_field
|
||||
if gen_temperature := ds_cfg.get("temperature"):
|
||||
strategy_params["gen_temperature"] = gen_temperature
|
||||
if kd_temperature := cfg.get("kd_temperature"):
|
||||
strategy_params["kd_temperature"] = kd_temperature
|
||||
|
||||
return strategy_params
|
||||
|
||||
|
||||
load = KDStrategyLoader()
|
||||
58
src/axolotl/integrations/kd/topk_logprob/LICENSE.md
Normal file
58
src/axolotl/integrations/kd/topk_logprob/LICENSE.md
Normal file
@@ -0,0 +1,58 @@
|
||||
### AXOLOTL COMMUNITY LICENSE AGREEMENT
|
||||
|
||||
This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
|
||||
any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
|
||||
and conditions set forth in this Agreement.
|
||||
|
||||
1. Definitions
|
||||
1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
|
||||
1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
|
||||
which may be licensed separately by their respective authors and/or licensors.
|
||||
1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
|
||||
https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
|
||||
permits Plugin Integrations to integrate with the Axolotl service.
|
||||
2. Grant of License
|
||||
2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
|
||||
publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:
|
||||
- Licensee must comply with all the terms and conditions of this Agreement.
|
||||
- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial
|
||||
portions of the Software.
|
||||
2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.
|
||||
3. Restrictions
|
||||
3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for
|
||||
free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such
|
||||
third parties to fine-tune artificial intelligence models.
|
||||
3.2 Licensee shall not:
|
||||
- Use the Software for any illegal or unauthorized purpose.
|
||||
- Reverse engineer, decompile, or disassemble the Software.
|
||||
- Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
|
||||
- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the
|
||||
Software or interfere with any third-party use of the Software.
|
||||
3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.
|
||||
4. Intellectual Property Rights
|
||||
4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee
|
||||
acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to
|
||||
Licensee.
|
||||
5. Disclaimer of Warranty
|
||||
5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
|
||||
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL
|
||||
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF
|
||||
CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
6. Termination
|
||||
6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and
|
||||
conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any
|
||||
copies in its possession.
|
||||
7. Governing Law
|
||||
7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,
|
||||
without regards to conflicts of laws provisions thereof.
|
||||
8. Entire Agreement
|
||||
8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter
|
||||
hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning
|
||||
the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and
|
||||
Licensee’s continued use of the Software after any such updates shall constitute acceptance of updated terms
|
||||
on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any
|
||||
material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be
|
||||
bound by the terms and conditions of this Agreement.
|
||||
|
||||
This Agreement was last updated on August 23, 2024.
|
||||
84
src/axolotl/integrations/kd/topk_logprob/forward_kl.py
Normal file
84
src/axolotl/integrations/kd/topk_logprob/forward_kl.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# This software may be used and distributed according to
|
||||
# the terms of the Axolotl Community License Agreement (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations under
|
||||
# the License.
|
||||
|
||||
"""
|
||||
loss for top_k KL divergence
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def loss(
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
kd_temperature: float = 1.0,
|
||||
):
|
||||
# teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding
|
||||
|
||||
# Determine the teacher sequence length
|
||||
# _, teacher_seq_len, top_k = target_token_ids.shape
|
||||
teacher_seq_len = target_token_ids.shape[1]
|
||||
|
||||
# Slice student logits to match the teacher-provided sequence length
|
||||
student_logits_for_kd = student_logits[
|
||||
:, :teacher_seq_len, :
|
||||
] # [B, teacher_seq_len, vocab_size]
|
||||
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
# shape -> [B, teacher_seq_len, K]
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
)
|
||||
|
||||
# Apply KD temperature to student’s logits:
|
||||
# z_s(T) = z_s / T
|
||||
if kd_temperature != 1.0:
|
||||
student_logits_topk = student_logits_topk / kd_temperature
|
||||
|
||||
# Convert student top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - torch.logsumexp(
|
||||
student_logits_topk, dim=-1, keepdim=True
|
||||
) # [B, seq_len, K]
|
||||
|
||||
# Convert teacher_mask to boolean for indexing
|
||||
valid_mask = target_mask.bool()
|
||||
|
||||
# Prune tensors to only keep valid tokens
|
||||
# This will result in 1D arrays of only valid positions
|
||||
student_logprobs_topk = student_logprobs_topk[valid_mask] # [N_valid_tokens]
|
||||
target_logprobs = target_logprobs[valid_mask] # [N_valid_tokens]
|
||||
|
||||
# Since teacher_logprobs are already normalized, just exponentiate to get probabilities
|
||||
teacher_probs = target_logprobs.exp()
|
||||
|
||||
# Compute forward KL:
|
||||
# KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens.
|
||||
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
|
||||
kd_loss = kd_loss_per_token.sum()
|
||||
|
||||
# 9) Multiply by T^2 (classical KD scaling)
|
||||
if kd_temperature != 1.0:
|
||||
kd_loss = kd_loss * (kd_temperature**2)
|
||||
|
||||
# Normalize by number of items or mean over valid tokens
|
||||
if num_items_in_batch is not None:
|
||||
# If you know how many items should be considered in the batch
|
||||
kd_loss = kd_loss / num_items_in_batch
|
||||
else:
|
||||
# Otherwise, just average over all valid tokens
|
||||
kd_loss = kd_loss / kd_loss_per_token.size(0)
|
||||
|
||||
return kd_loss
|
||||
110
src/axolotl/integrations/kd/trainer.py
Normal file
110
src/axolotl/integrations/kd/trainer.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
KD trainer
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.core.trainers.kd.topk_logprob.forward_kl import loss as topk_kd_loss
|
||||
|
||||
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Custom trainer subclass for Knowledge Distillation (KD)
|
||||
"""
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
super()._set_signature_columns_if_needed()
|
||||
columns_to_add = []
|
||||
if self._signature_columns:
|
||||
if "target_logprobs" not in self._signature_columns:
|
||||
columns_to_add.append("target_logprobs")
|
||||
if "target_token_ids" not in self._signature_columns:
|
||||
columns_to_add.append("target_token_ids")
|
||||
if "target_mask" not in self._signature_columns:
|
||||
columns_to_add.append("target_mask")
|
||||
if columns_to_add:
|
||||
self._signature_columns += columns_to_add
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
shift_targets=False,
|
||||
):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
|
||||
target_logprobs = inputs.pop("target_logprobs")
|
||||
target_token_ids = inputs.pop("target_token_ids")
|
||||
target_mask = inputs.pop("target_mask")
|
||||
|
||||
seq_len = target_token_ids.shape[1]
|
||||
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
if num_items_in_batch is not None:
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
outputs = model(**inputs)
|
||||
|
||||
# FIXME: account for tokenizer.padding_side
|
||||
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
|
||||
|
||||
if shift_targets:
|
||||
shift_logits = student_logits[..., :-1, :].contiguous()
|
||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||
else:
|
||||
shift_logits = student_logits.contiguous()
|
||||
target_logprobs_for_loss = target_logprobs.contiguous()
|
||||
target_token_ids_for_loss = target_token_ids.contiguous()
|
||||
target_mask_for_loss = target_mask.contiguous()
|
||||
|
||||
loss_kd = topk_kd_loss(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
)
|
||||
|
||||
if self.args.kd_ce_alpha > 0:
|
||||
kd_alpha = self.args.kd_alpha
|
||||
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
||||
else:
|
||||
loss = loss_kd
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
|
||||
self.args.past_index
|
||||
]
|
||||
|
||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||
loss *= self.accelerator.num_processes
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
@@ -16,10 +16,18 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||
|
||||
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
|
||||
load_fn = "load"
|
||||
package = "axolotl.prompt_strategies"
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
||||
else:
|
||||
try:
|
||||
importlib.import_module(".".join(strategy.split(".")[:-1]))
|
||||
package = ".".join(strategy.split(".")[:-1])
|
||||
strategy = strategy.split(".")[-1]
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
mod = importlib.import_module(f".{strategy}", package)
|
||||
func = getattr(mod, load_fn)
|
||||
load_kwargs = {}
|
||||
if strategy == "user_defined":
|
||||
|
||||
@@ -5,7 +5,6 @@ HF Chat Templates prompt strategy
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from transformers import ProcessorMixin
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
@@ -460,168 +459,62 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return prompt.get(self.images, None)
|
||||
|
||||
|
||||
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
class StrategyLoader:
|
||||
"""
|
||||
Handle fields for logprob KD
|
||||
Load chat template strategy based on configuration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter,
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
roles_to_train=None,
|
||||
train_on_eos=None,
|
||||
logprobs_field="logprobs",
|
||||
gen_temperature=1.0,
|
||||
kd_temperature=1.0,
|
||||
def _get_strategy_cls(self):
|
||||
return ChatTemplateStrategy
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
return {
|
||||
"train_on_inputs": cfg.train_on_inputs,
|
||||
"sequence_len": cfg.sequence_len,
|
||||
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
|
||||
):
|
||||
self.logprobs_field = logprobs_field
|
||||
self.gen_temperature = gen_temperature
|
||||
self.kd_temperature = kd_temperature
|
||||
# pylint: disable=duplicate-code
|
||||
ds_cfg = ds_cfg or {}
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
)
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||
|
||||
super().__init__(
|
||||
prompter,
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
roles_to_train=roles_to_train,
|
||||
train_on_eos=train_on_eos,
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_template_string,
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_field_training_detail",
|
||||
None,
|
||||
),
|
||||
"roles": ds_cfg.get("roles"),
|
||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
"max_length": cfg.sequence_len + 1,
|
||||
"processor": processor,
|
||||
}
|
||||
|
||||
strategy_params = self._get_strategy_params(cfg, ds_cfg)
|
||||
strategy_cls = self._get_strategy_cls()
|
||||
|
||||
strategy = strategy_cls(
|
||||
ChatTemplatePrompter(**prompter_params),
|
||||
tokenizer=tokenizer,
|
||||
**strategy_params,
|
||||
)
|
||||
|
||||
def transform_logprobs(self, sample):
|
||||
logprobs = sample.pop(self.logprobs_field)
|
||||
target_seq_len = len(logprobs)
|
||||
input_seq_len = len(sample["input_ids"])
|
||||
input_padding_len = input_seq_len - target_seq_len
|
||||
top_k = len(logprobs[0])
|
||||
target_logprobs = []
|
||||
target_token_ids = []
|
||||
target_mask = []
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
# fill with -inf for padding_len tokens for top_k tokens
|
||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||
for _ in range(1, input_padding_len): # start at 1 since this is causal
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
for _ in range(target_seq_len):
|
||||
# TODO also check against sample["labels"]
|
||||
target_mask.append([1] * top_k)
|
||||
|
||||
for _, token_pos_logprobs in enumerate(logprobs):
|
||||
# Initialize collections for logprobs and token_ids
|
||||
position_logprobs = []
|
||||
position_token_ids = []
|
||||
|
||||
# Process each token probability entry
|
||||
for entry in token_pos_logprobs:
|
||||
# Extract logprob value
|
||||
logprob = entry["logprob"]
|
||||
|
||||
# Parse token_id from the "token_id:###" format
|
||||
token_id = int(entry["token"].split(":")[1])
|
||||
|
||||
# Append to our collections
|
||||
position_logprobs.append(logprob)
|
||||
position_token_ids.append(token_id)
|
||||
|
||||
# Convert to a tensor for easier manipulation
|
||||
# Convert to tensor
|
||||
position_logprobs_tensor = torch.tensor(
|
||||
position_logprobs, dtype=torch.float
|
||||
)
|
||||
|
||||
if self.kd_temperature != self.gen_temperature:
|
||||
#
|
||||
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
|
||||
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
|
||||
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
|
||||
#
|
||||
# Convert from log to probability
|
||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||
# Exponentiate by factor (T1 / T2)
|
||||
exponent = self.gen_temperature / self.kd_temperature
|
||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||
# Re-normalize
|
||||
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||
dim=0, keepdim=True
|
||||
)
|
||||
# Convert back to log
|
||||
position_logprobs_tensor = torch.log(teacher_probs_t2)
|
||||
|
||||
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
|
||||
position_logprobs_scaled = position_logprobs_tensor.tolist()
|
||||
|
||||
target_logprobs.append(position_logprobs_scaled)
|
||||
target_token_ids.append(position_token_ids)
|
||||
|
||||
# Update sample with transformed logprobs
|
||||
sample["target_logprobs"] = target_logprobs
|
||||
sample["target_token_ids"] = target_token_ids
|
||||
sample["target_mask"] = target_mask
|
||||
|
||||
return sample
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
logprobs = prompt.pop(self.logprobs_field)
|
||||
tokenized_prompt = super().tokenize_prompt(prompt)
|
||||
tokenized_prompt[self.logprobs_field] = logprobs
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
|
||||
return tokenized_prompt
|
||||
return strategy
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
|
||||
# pylint: disable=duplicate-code
|
||||
ds_cfg = ds_cfg or {}
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
)
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_template_string,
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_field_training_detail",
|
||||
None,
|
||||
),
|
||||
"roles": ds_cfg.get("roles"),
|
||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
"max_length": cfg.sequence_len + 1,
|
||||
"processor": processor,
|
||||
}
|
||||
|
||||
strategy_params = {
|
||||
"train_on_inputs": cfg.train_on_inputs,
|
||||
"sequence_len": cfg.sequence_len,
|
||||
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||
}
|
||||
|
||||
strategy_cls = ChatTemplateStrategy
|
||||
|
||||
if cfg.trainer == "kd":
|
||||
strategy_cls = ChatTemplateStrategyWithKD
|
||||
if logprobs_field := ds_cfg.get("logprobs_field"):
|
||||
strategy_params["logprobs_field"] = logprobs_field
|
||||
if gen_temperature := ds_cfg.get("temperature"):
|
||||
strategy_params["gen_temperature"] = gen_temperature
|
||||
if kd_temperature := cfg.get("kd_temperature"):
|
||||
strategy_params["kd_temperature"] = kd_temperature
|
||||
|
||||
strategy = strategy_cls(
|
||||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||
)
|
||||
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
return strategy
|
||||
load = StrategyLoader()
|
||||
|
||||
Reference in New Issue
Block a user