Compare commits

...

13 Commits

Author SHA1 Message Date
Wing Lian
f11227a35a various fixes 2025-01-30 10:39:18 -05:00
Wing Lian
c434951dd6 Always re-normalize teacher distribution 2025-01-29 08:36:40 -05:00
Wing Lian
42d4732aaf kd loss needs to be calculated in full precision 2025-01-28 19:40:35 -05:00
Wing Lian
2c9dfbed2e apply z-score scaling to kd 2025-01-27 14:27:35 -05:00
Wing Lian
4e4a16cd8a fix finding the top-k rather than assuming first position has the correct val 2025-01-21 13:09:20 -05:00
Wing Lian
67c1c8405e use iter instead of tuple 2025-01-21 11:23:38 -05:00
Wing Lian
bded6df509 change up logic so we always truncate to top_k 2025-01-21 11:20:01 -05:00
Wing Lian
bb5e6f4b72 make sure to truncate logprobs if there are more than top_k 2025-01-21 10:26:27 -05:00
Wing Lian
32258c247e no batching for kd chat templates 2025-01-15 08:22:29 -05:00
Wing Lian
04efcb102f don't shift student logits for kd 2025-01-15 01:07:48 -05:00
Wing Lian
483defb9ae try tests for kd on l40s 2025-01-14 23:56:00 -05:00
Wing Lian
35a84f2cb8 more fixes 2025-01-14 22:47:49 -05:00
Wing Lian
510cf45317 improve logprob masking and shift in trainer 2025-01-14 22:47:48 -05:00
9 changed files with 232 additions and 45 deletions

View File

@@ -59,7 +59,7 @@ VOLUME_CONFIG = {
} }
N_GPUS = int(os.environ.get("N_GPUS", 1)) N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS) GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str): def run_cmd(cmd: str, run_folder: str):

View File

@@ -697,6 +697,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None: if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs[
"kd_zscore_base_temp"
] = self.cfg.kd_zscore_base_temp
training_args_cls = ( training_args_cls = (
AxolotlTrainingArguments AxolotlTrainingArguments

View File

@@ -188,6 +188,13 @@ class AxolotlTrainingMixins:
}, },
) )
kd_zscore_base_temp: Optional[float] = field(
default=None,
metadata={
"help": "the base temperature parameter for KL divergence with z-score when using KD"
},
)
@dataclass @dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):

View File

@@ -31,3 +31,4 @@ class KDArgs(BaseModel):
] = None # loss coefficient for cross-entropy loss during KD ] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling

View File

@@ -52,26 +52,62 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
train_on_eos=train_on_eos, train_on_eos=train_on_eos,
) )
@property
def supports_batched(self) -> bool:
# batching doesn't work well for logprob data
return False
def transform_logprobs(self, sample): def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
logprobs = sample.pop(self.logprobs_field) logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs) target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"]) input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len input_padding_len = input_seq_len - target_seq_len
top_k = len(logprobs[0]) # get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [
len(logprobs[i])
for i in range(len(logprobs))
if logprobs[i] is not None and len(logprobs[i])
]
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
top_k = min(max_top_k, min_top_k)
if top_k == 0:
raise ValueError("No non-zero top-k logprobs found.")
target_logprobs = [] target_logprobs = []
target_token_ids = [] target_token_ids = []
target_mask = [] target_mask = []
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
logprobs = [row[:top_k] for row in logprobs]
# fill with -inf for padding_len tokens for top_k tokens # fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(1, input_padding_len): # start at 1 since this is causal
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k) target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k))) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k) target_mask.append([0] * top_k)
for _ in range(target_seq_len): for position in range(input_padding_len, input_seq_len):
# TODO also check against sample["labels"] if sample["labels"][position] == -100:
target_mask.append([1] * top_k) target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs): for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids # Initialize collections for logprobs and token_ids
@@ -91,28 +127,28 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
position_token_ids.append(token_id) position_token_ids.append(token_id)
# Convert to a tensor for easier manipulation # Convert to a tensor for easier manipulation
# Convert to tensor
position_logprobs_tensor = torch.tensor( position_logprobs_tensor = torch.tensor(
position_logprobs, dtype=torch.float position_logprobs, dtype=torch.float
) )
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
if self.kd_temperature != self.gen_temperature: if self.kd_temperature != self.gen_temperature:
#
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# Exponentiate by factor (T1 / T2) # Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent teacher_probs_t2 = teacher_probs_t1**exponent
# Re-normalize else:
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( teacher_probs_t2 = teacher_probs_t1
dim=0, keepdim=True # Re-normalize
) teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
# Convert back to log dim=0, keepdim=True
position_logprobs_tensor = torch.log(teacher_probs_t2) )
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist() position_logprobs_scaled = position_logprobs_tensor.tolist()
@@ -120,10 +156,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled) target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids) target_token_ids.append(position_token_ids)
# since we started at index 1 for causal, we need one more padding token if shift == 1:
target_logprobs.append([-float("inf")] * top_k) # since we started at index 1 for causal, we need one more padding token
target_token_ids.append(list(range(top_k))) target_logprobs.append([-float("inf")] * top_k)
target_mask.append([0] * top_k) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
# Update sample with transformed logprobs # Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs sample["target_logprobs"] = target_logprobs

View File

@@ -16,6 +16,40 @@ loss for top_k KL divergence
import torch import torch
def zscore_standardize(
logits: torch.Tensor,
mask: torch.Tensor = None,
base_temperature: float = 1.0,
eps: float = 1e-9,
):
"""
Z-score standardize along the last dimension of `logits`.
i.e., for each [B, seq_len] row, across K entries:
z = (logits - mean) / std,
then scale by 1 / base_temperature if desired.
mask can be broadcastable or None. If None, we standardize all elements.
"""
if mask is None:
# shape: [B, seq_len, K]
# Mean and std over dim=-1
mean = logits.mean(dim=-1, keepdim=True)
var = logits.var(dim=-1, unbiased=False, keepdim=True)
else:
# If you have to exclude some tokens, multiply by mask, etc.
float_mask = mask.to(logits.dtype)
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
std = torch.sqrt(var.clamp_min(eps))
z = (logits - mean) / std
# Scale by 1 / base_temperature
z = z / base_temperature
return z
@torch.jit.script @torch.jit.script
def loss( def loss(
student_logits: torch.Tensor, student_logits: torch.Tensor,
@@ -27,8 +61,23 @@ def loss(
) -> torch.Tensor: ) -> torch.Tensor:
""" """
A KD loss function that is TorchScript-friendly. A KD loss function that is TorchScript-friendly.
Arguments:
student_logits (torch.Tensor): The logits of the student model.
Shape: [B, student_seq_len, vocab_size]
target_token_ids (torch.Tensor): The top-k teacher/target token IDs
Shape: [B, teacher_seq_len, top_k]
target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized.
Shape: [B, teacher_seq_len, top_k]
target_mask (torch.Tensor): The mask for valid tokens.
Shape: [B, teacher_seq_len, top_k]
num_items_in_batch (int, optional): The number of items in the batch.
kd_temperature (float, optional): The temperature for KD.
Default: 1.0
""" """
target_logprobs = target_logprobs.float()
# Determine the teacher sequence length # Determine the teacher sequence length
# target_token_ids shape: [B, teacher_seq_len, K] # target_token_ids shape: [B, teacher_seq_len, K]
# student_logits shape: [B, student_seq_len, vocab_size] # student_logits shape: [B, student_seq_len, vocab_size]
@@ -44,6 +93,8 @@ def loss(
student_logits_for_kd, dim=-1, index=target_token_ids student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K] ) # [B, teacher_seq_len, K]
student_logits_topk = student_logits_topk.float()
# Apply KD temperature to students logits # Apply KD temperature to students logits
if kd_temperature != 1.0: if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature student_logits_topk = student_logits_topk / kd_temperature
@@ -80,3 +131,82 @@ def loss(
kd_loss = kd_loss / float(kd_loss_per_token.size(0)) kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss return kd_loss
def topk_kd_loss_with_zscore(
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
target_token_ids: torch.Tensor, # [B, seq_len, K]
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
kd_temperature: float = 1.0, # classic KD temperature
zscore_base_temp: float = 1.0, # from the paper
num_items_in_batch: int = -1,
):
"""
A variant of top_k KL divergence with Z-score scaling
from "Logit Standardization in Knowledge Distillation".
"""
target_logprobs = target_logprobs.float()
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
# 1) Gather the student's top-k logits to match teacher
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, seq_len, vocab]
student_topk_logits = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, seq_len, K]
student_topk_logits = student_topk_logits.float()
# 2) If you want to keep the "classical" T scaling, apply it first
if kd_temperature != 1.0:
student_topk_logits = student_topk_logits / kd_temperature
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
# (They differ by +some_constant from real logits, but in z-score
# that constant is subtracted out anyway.)
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
# 4) Z-score teacher and student
# If target_mask is 2D, expand to 3D for the K dimension
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
teacher_z = zscore_standardize(
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
)
student_z = zscore_standardize(
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
)
# 5) Convert to log-probs for KL
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
# 6) Restrict to valid tokens if needed
valid_mask = target_mask.bool() # shape [B, seq_len, K]
teacher_probs_z = teacher_logprobs_z.exp()
teacher_probs_z = teacher_probs_z[valid_mask]
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
student_logprobs_z = student_logprobs_z[valid_mask]
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
kd_loss = kd_loss_per_token.sum()
# 8) If using classical KD scaling by T^2
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
# kd_loss = kd_loss * (zscore_base_temp**2)
# 9) Normalize
if num_items_in_batch is not None and num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss

View File

@@ -19,6 +19,7 @@ KD trainer
from axolotl.core.trainers.base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainer(AxolotlTrainer): class AxolotlKDTrainer(AxolotlTrainer):
@@ -45,7 +46,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
inputs, inputs,
return_outputs=False, return_outputs=False,
num_items_in_batch=None, num_items_in_batch=None,
shift_targets=False,
): ):
""" """
How the loss is computed by Trainer. By default, all models return the loss in the first element. How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -69,25 +69,30 @@ class AxolotlKDTrainer(AxolotlTrainer):
# FIXME: account for tokenizer.padding_side # FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, :seq_len, :].contiguous() student_logits = outputs["logits"][:, :seq_len, :].contiguous()
if shift_targets: shift_logits = student_logits.contiguous()
shift_logits = student_logits[..., :-1, :].contiguous() target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() target_mask_for_loss = target_mask[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
else:
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs.contiguous()
target_token_ids_for_loss = target_token_ids.contiguous()
target_mask_for_loss = target_mask.contiguous()
loss_kd = topk_kd_loss( if self.args.kd_zscore_base_temp:
shift_logits, loss_kd = topk_kd_loss_with_zscore(
target_token_ids_for_loss, shift_logits,
target_logprobs_for_loss, target_token_ids_for_loss,
target_mask_for_loss, target_logprobs_for_loss,
num_items_in_batch=num_items_in_batch, target_mask_for_loss,
kd_temperature=self.args.kd_temperature, kd_temperature=self.args.kd_temperature,
) zscore_base_temp=self.args.kd_zscore_base_temp,
num_items_in_batch=num_items_in_batch,
)
else:
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
)
if self.args.kd_ce_alpha > 0: if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha kd_alpha = self.args.kd_alpha

View File

@@ -279,6 +279,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs["desc"] = "Dropping Long Sequences" drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter( train_dataset = train_dataset.filter(
drop_long, drop_long,
batched=True,
**filter_map_kwargs, **filter_map_kwargs,
**drop_long_kwargs, **drop_long_kwargs,
) )
@@ -310,8 +311,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
""" """
labels = sample["labels"] labels = sample["labels"]
if not labels: if not labels:
# Edge case: if labels is empty, decide if you want to keep or drop return True
return True # or False
# Check if single example or batch # Check if single example or batch
# If first element is an int, we assume a single example # If first element is an int, we assume a single example

View File

@@ -33,6 +33,7 @@ def min_cfg(temp_dir):
"dataloader_prefetch_factor": 8, "dataloader_prefetch_factor": 8,
"dataloader_num_workers": 4, "dataloader_num_workers": 4,
"dataloader_pin_memory": True, "dataloader_pin_memory": True,
# "dataset_prepared_path": str(Path(temp_dir) / "last_run_prepared"),
"datasets": [ "datasets": [
{ {
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", "path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",