Compare commits
4 Commits
v0.16.1
...
online-top
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b8d52a2193 | ||
|
|
002b1ac967 | ||
|
|
17b01bfe36 | ||
|
|
a0669335e2 |
@@ -39,7 +39,10 @@ class KDPlugin(BasePlugin):
|
||||
|
||||
def get_trainer_cls(self, cfg):
|
||||
if cfg.kd_trainer:
|
||||
from .trainer import AxolotlKDTrainer
|
||||
from .trainer import AxolotlKDTrainer, AxolotlOnlineKDTrainer
|
||||
|
||||
if cfg.kd_online_server_base_url:
|
||||
return AxolotlOnlineKDTrainer
|
||||
|
||||
return AxolotlKDTrainer
|
||||
return None
|
||||
|
||||
@@ -53,7 +53,9 @@ class KDArgs(BaseModel):
|
||||
kd_online_server: InferenceServerType | None = Field(
|
||||
default_factory=lambda: InferenceServerType.vllm
|
||||
)
|
||||
kd_online_server_model: str | None = None
|
||||
kd_online_timeout: int | None = 120
|
||||
kd_online_max_new_tokens: int | None = 2048
|
||||
kd_temperature_min: float | None = (
|
||||
None # kd temperature scheduling during online kd
|
||||
)
|
||||
@@ -74,3 +76,4 @@ class KDTrainingArgsMixin:
|
||||
kd_normalize_topk: float | None = (
|
||||
None # whether to normalize student logits during KD
|
||||
)
|
||||
kd_online_max_new_tokens: int | None = None
|
||||
|
||||
47
src/axolotl/integrations/kd/online_chat_template.py
Normal file
47
src/axolotl/integrations/kd/online_chat_template.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
# Configure the logger
|
||||
LOG = get_logger(__name__)
|
||||
LOG.setLevel("INFO")
|
||||
|
||||
|
||||
class ChatTemplateStrategyWithOnlineKD(ChatTemplateStrategy):
|
||||
@property
|
||||
def supports_batched(self) -> bool:
|
||||
# batching doesn't work well for logprob data
|
||||
return False
|
||||
|
||||
def _get_messages(self, prompt):
|
||||
input_prompt = prompt.get("problem")
|
||||
return [
|
||||
{"role": "user", "content": input_prompt},
|
||||
]
|
||||
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
tools = self._get_tools(prompt)
|
||||
input_ids = self.prompter.build_prompt(
|
||||
turns, tools=tools, add_generation_prompt=True
|
||||
) # type: ignore
|
||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"prompts": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": [1] * len(input_ids),
|
||||
}
|
||||
|
||||
|
||||
class OnlineKDStrategyLoader(StrategyLoader):
|
||||
"""
|
||||
Load ChatTemplateStrategy with KD support using StrategyLoader.
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self, cfg):
|
||||
return ChatTemplateStrategyWithOnlineKD
|
||||
|
||||
|
||||
load = OnlineKDStrategyLoader()
|
||||
@@ -16,6 +16,14 @@
|
||||
KD trainer
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GenerationConfig
|
||||
from trl.models import unwrap_model_for_generation
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
@@ -101,3 +109,214 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
loss = outputs.loss if hasattr(outputs, "loss") else outputs
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
|
||||
class AxolotlOnlineKDTrainer(AxolotlKDTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.generation_config = GenerationConfig(
|
||||
max_new_tokens=kwargs.get("kd_online_max_new_tokens"),
|
||||
temperature=1.0,
|
||||
do_sample=True,
|
||||
top_k=0,
|
||||
use_cache=False if kwargs.get("gradient_checkpointing") else True,
|
||||
pad_token_id=self.processing_class.pad_token_id,
|
||||
)
|
||||
# Set custom EOS tokens if they are specified by the model's generation
|
||||
# config. This is important for models with the Llama 3 chat template,
|
||||
# which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
|
||||
# turns or messages.
|
||||
if (
|
||||
hasattr(self.model.generation_config, "eos_token_id")
|
||||
and self.model.generation_config.eos_token_id is not None
|
||||
):
|
||||
self.generation_config.eos_token_id = (
|
||||
self.model.generation_config.eos_token_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
|
||||
# Generate output with respect to the prompt-only
|
||||
generated_outputs = model.generate(
|
||||
input_ids=inputs["prompts"],
|
||||
attention_mask=inputs.get("prompt_attention_mask", None),
|
||||
generation_config=generation_config,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
# Get the generated token IDs
|
||||
generated_tokens = generated_outputs.sequences
|
||||
# Calculate new attention mask
|
||||
new_attention_mask = torch.ones_like(generated_tokens)
|
||||
new_labels = generated_tokens.clone()
|
||||
|
||||
# If there's pad_token_id, set attention mask to 0 for padding tokens
|
||||
if pad_token_id is not None:
|
||||
new_labels[new_labels == pad_token_id] = -100
|
||||
new_attention_mask[generated_tokens == pad_token_id] = 0
|
||||
|
||||
return generated_tokens, new_attention_mask, new_labels
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step for the Generalized Knowledge Distillation (GKD) model.
|
||||
|
||||
This method implements the on-policy learning approach described in the GKD paper. With probability
|
||||
`self.lmbda`, it generates new responses using the student model, which are then used for training instead of
|
||||
the original inputs.
|
||||
"""
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||
new_input_ids, new_attention_mask, new_labels = (
|
||||
self.generate_on_policy_outputs(
|
||||
unwrapped_model,
|
||||
inputs,
|
||||
self.generation_config,
|
||||
self.processing_class.pad_token_id,
|
||||
)
|
||||
)
|
||||
inputs["input_ids"] = new_input_ids
|
||||
inputs["attention_mask"] = new_attention_mask
|
||||
inputs["labels"] = new_labels
|
||||
|
||||
target_token_ids, target_logprobs, target_mask = self.get_teacher_logprobs(
|
||||
inputs["input_ids"], inputs["labels"]
|
||||
)
|
||||
inputs["target_token_ids"] = target_token_ids
|
||||
inputs["target_logprobs"] = target_logprobs
|
||||
inputs["target_mask"] = target_mask
|
||||
|
||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||
return loss
|
||||
|
||||
def get_teacher_logprobs(self, input_ids, labels):
|
||||
request_body = {
|
||||
"model": self.axolotl_cfg.kd_online_server_model,
|
||||
"prompt": input_ids,
|
||||
"logprobs": self.axolotl_cfg.kd_online_topk,
|
||||
"echo": True,
|
||||
"skip_special_tokens": False,
|
||||
"n": 1,
|
||||
"max_tokens": 0,
|
||||
"temperature": 1.0,
|
||||
}
|
||||
base_url = self.args.kd_online_server_base_url
|
||||
api_url = f"{base_url}/v1/completions"
|
||||
bearer_token = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
headers = {"Authorization": f"Bearer {bearer_token}"}
|
||||
response = requests.post(
|
||||
api_url, json=request_body, headers=headers, timeout=30
|
||||
)
|
||||
prompt_logprobs = response.choices[0].logprobs.top_logprobs[
|
||||
1:
|
||||
] # prune first null position
|
||||
return self.transform_logprobs(input_ids, labels, prompt_logprobs)
|
||||
|
||||
def transform_logprobs(self, input_ids, labels, logprobs):
|
||||
"""
|
||||
Transform logprobs to target format for KD training
|
||||
"""
|
||||
|
||||
target_seq_len = len(logprobs)
|
||||
input_seq_len = len(input_ids)
|
||||
input_padding_len = input_seq_len - target_seq_len
|
||||
# get non-zero top-k (prune None logprobs from vllm data step)
|
||||
top_k_vals = [
|
||||
len(logprobs[i])
|
||||
for i in range(len(logprobs))
|
||||
if logprobs[i] is not None and len(logprobs[i])
|
||||
]
|
||||
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
|
||||
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
|
||||
top_k = min(max_top_k, min_top_k)
|
||||
if top_k == 0:
|
||||
raise ValueError("No non-zero top-k logprobs found.")
|
||||
|
||||
target_logprobs = []
|
||||
target_token_ids = []
|
||||
target_mask = []
|
||||
|
||||
if input_padding_len < 0:
|
||||
# logprobs is longer than target_seq_len,
|
||||
# so we need to slice from the left/beginning of logprobs
|
||||
logprobs = logprobs[:-input_seq_len]
|
||||
input_padding_len = 0
|
||||
# target_seq_len = input_seq_len
|
||||
|
||||
# truncate the second dimension of the logprobs to top_k
|
||||
logprobs = [row[:top_k] for row in logprobs]
|
||||
|
||||
# fill with -inf for padding_len tokens for top_k tokens
|
||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||
|
||||
# we shift for causal models in the trainer, so start the range from 0
|
||||
for _ in range(0, input_padding_len):
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
for position in range(input_padding_len, input_seq_len):
|
||||
if labels[position] == -100:
|
||||
target_mask.append([0] * top_k)
|
||||
else:
|
||||
target_mask.append([1] * top_k)
|
||||
|
||||
for _, token_pos_logprobs in enumerate(logprobs):
|
||||
# Initialize collections for logprobs and token_ids
|
||||
position_logprobs = []
|
||||
position_token_ids = []
|
||||
|
||||
# Process each token probability entry
|
||||
for entry in token_pos_logprobs:
|
||||
# Extract logprob value
|
||||
logprob = entry["logprob"]
|
||||
|
||||
# Parse token_id from the "token_id:###" format
|
||||
token_id = int(entry["token"].split(":")[1])
|
||||
|
||||
# Append to our collections
|
||||
position_logprobs.append(logprob)
|
||||
position_token_ids.append(token_id)
|
||||
|
||||
# Convert to a tensor for easier manipulation
|
||||
position_logprobs_tensor = torch.tensor(
|
||||
position_logprobs, dtype=torch.float
|
||||
)
|
||||
|
||||
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
|
||||
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
|
||||
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
|
||||
#
|
||||
# Convert from log to probability
|
||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||
# normalize probabilities to sum to 1 in case they aren't already
|
||||
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
|
||||
if teacher_probs_t1_sum > 1e-9:
|
||||
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
|
||||
if self.kd_temperature != self.gen_temperature:
|
||||
# Exponentiate by factor (T1 / T2)
|
||||
exponent = self.gen_temperature / self.kd_temperature
|
||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||
else:
|
||||
teacher_probs_t2 = teacher_probs_t1
|
||||
# Re-normalize
|
||||
# teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||
# dim=0, keepdim=True
|
||||
# )
|
||||
# Convert back to log
|
||||
position_logprobs_tensor = torch.log(teacher_probs_t2)
|
||||
|
||||
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
|
||||
position_logprobs_scaled = position_logprobs_tensor.tolist()
|
||||
|
||||
target_logprobs.append(position_logprobs_scaled)
|
||||
target_token_ids.append(position_token_ids)
|
||||
|
||||
# Update sample with transformed logprobs
|
||||
return target_token_ids, target_logprobs, target_mask
|
||||
|
||||
@@ -320,7 +320,7 @@ class PatchManager:
|
||||
else:
|
||||
has_remote_code = False
|
||||
|
||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||
if has_remote_code and self.cfg.trust_remote_code is not None:
|
||||
# If explicitly set in YAML, prefer that
|
||||
has_remote_code = self.cfg.trust_remote_code
|
||||
|
||||
|
||||
@@ -179,7 +179,7 @@ def check_tensorboard(
|
||||
tag: str,
|
||||
lt_val: float,
|
||||
assertion_err: str,
|
||||
rtol: float = 0.02,
|
||||
rtol: float = 0.05,
|
||||
) -> None:
|
||||
"""
|
||||
helper function to parse and check tensorboard logs
|
||||
|
||||
Reference in New Issue
Block a user