Compare commits

..

77 Commits

Author SHA1 Message Date
Wing Lian
f11227a35a various fixes 2025-01-30 10:39:18 -05:00
Wing Lian
c434951dd6 Always re-normalize teacher distribution 2025-01-29 08:36:40 -05:00
Wing Lian
42d4732aaf kd loss needs to be calculated in full precision 2025-01-28 19:40:35 -05:00
Wing Lian
2c9dfbed2e apply z-score scaling to kd 2025-01-27 14:27:35 -05:00
Wing Lian
4e4a16cd8a fix finding the top-k rather than assuming first position has the correct val 2025-01-21 13:09:20 -05:00
Wing Lian
67c1c8405e use iter instead of tuple 2025-01-21 11:23:38 -05:00
Wing Lian
bded6df509 change up logic so we always truncate to top_k 2025-01-21 11:20:01 -05:00
Wing Lian
bb5e6f4b72 make sure to truncate logprobs if there are more than top_k 2025-01-21 10:26:27 -05:00
Wing Lian
32258c247e no batching for kd chat templates 2025-01-15 08:22:29 -05:00
Wing Lian
04efcb102f don't shift student logits for kd 2025-01-15 01:07:48 -05:00
Wing Lian
483defb9ae try tests for kd on l40s 2025-01-14 23:56:00 -05:00
Wing Lian
35a84f2cb8 more fixes 2025-01-14 22:47:49 -05:00
Wing Lian
510cf45317 improve logprob masking and shift in trainer 2025-01-14 22:47:48 -05:00
Wing Lian
7232cbdeab chore: lint 2025-01-14 22:47:48 -05:00
Wing Lian
e8fceb7091 chore: lint 2025-01-14 22:47:48 -05:00
Wing Lian
a5e0671738 make sure to use tensorboard to capture loss for checks 2025-01-14 22:47:48 -05:00
Wing Lian
b9847553af fix adapter model check 2025-01-14 22:47:48 -05:00
Wing Lian
513ec9e03b make sure to use the correct tokenizer 2025-01-14 22:47:48 -05:00
Wing Lian
530347856d make sure to set tokenizer from l3 70b and save safetensors 2025-01-14 22:47:47 -05:00
Wing Lian
261e4fb619 lower lr 2025-01-14 22:47:47 -05:00
Wing Lian
158071e95f set lora_dropout explicitly 2025-01-14 22:47:47 -05:00
Wing Lian
432f65f5e6 make the kd e2e fit in vram for ci and add lora version 2025-01-14 22:47:47 -05:00
Wing Lian
1d039f5486 rename test files so it gets picked up 2025-01-14 22:47:47 -05:00
Wing Lian
b9a42b396f linting 2025-01-14 22:47:47 -05:00
Wing Lian
ff2fb0fc1b add kd trainer e2e test 2025-01-14 22:47:47 -05:00
Wing Lian
317f290186 reward model doesn't work well with batched 2025-01-14 22:47:46 -05:00
Wing Lian
ab690f3f01 improve check for batched 2025-01-14 22:47:46 -05:00
Wing Lian
47932f21c4 fix reward trainer calls for tokenization 2025-01-14 22:47:46 -05:00
Wing Lian
808328e041 reward can use same batch check 2025-01-14 22:47:46 -05:00
Wing Lian
6784822cfb tweak check for batched prompt data 2025-01-14 22:47:46 -05:00
Wing Lian
684b38291f ensure that batch vs single is done properly 2025-01-14 22:47:46 -05:00
Wing Lian
01896b1bde improve iterable support 2025-01-14 22:47:46 -05:00
Wing Lian
e659c01646 support streaming for processing sft datasts? 2025-01-14 22:47:45 -05:00
Wing Lian
204d6c43b4 make loss torch script compat 2025-01-14 22:47:45 -05:00
Wing Lian
d3c2b7ce9d kd sample packing 2025-01-14 22:47:45 -05:00
Wing Lian
93dfff92f1 be a bit pickier about loading dynamic prompt strategies 2025-01-14 22:47:45 -05:00
Wing Lian
6e409d2d88 more info on preprocess for kd and fix import 2025-01-14 22:47:45 -05:00
Wing Lian
d5bc214300 remove duplicate code 2025-01-14 22:47:45 -05:00
Wing Lian
92c6c1087e add copyrights 2025-01-14 22:47:45 -05:00
Wing Lian
feed96f95e increase logging around loading plugins 2025-01-14 22:47:44 -05:00
Wing Lian
cba6165ae1 make plugin setup concise 2025-01-14 22:47:44 -05:00
Wing Lian
cdfcd69afa remove moved class from import 2025-01-14 22:47:44 -05:00
Wing Lian
885653d52e move more things to kd plugin 2025-01-14 22:47:44 -05:00
Wing Lian
27faacbf5a refactor kd chat template loader 2025-01-14 22:47:44 -05:00
Wing Lian
c51b0337c1 support for custom trainer classes from plugins 2025-01-14 22:47:44 -05:00
Wing Lian
fa055f9f69 handle token/logprob shifting 2025-01-14 22:47:43 -05:00
Wing Lian
f60c623af0 remove references to triton kd for now 2025-01-14 22:47:43 -05:00
Wing Lian
746891eb5c add license block 2025-01-14 22:47:43 -05:00
Wing Lian
f09b5da60b refactor so we can easily add new loss functions 2025-01-14 22:47:43 -05:00
Wing Lian
689e1c10ba chore: lint 2025-01-14 22:47:43 -05:00
Wing Lian
a5c085e003 var naming and add todo 2025-01-14 22:47:43 -05:00
Wing Lian
63146300b7 fix kd loss so it's causal (fixes repeating tokens) 2025-01-14 22:47:43 -05:00
Wing Lian
ca5e397fc5 use kd_alpha in the correct loss method 2025-01-14 22:47:42 -05:00
Wing Lian
3416302b0d hash for temperature too 2025-01-14 22:47:42 -05:00
Wing Lian
7366efc4ca better rescaling for temperatures 2025-01-14 22:47:42 -05:00
Wing Lian
d8d817eaed don't use triton for now 2025-01-14 22:47:42 -05:00
Wing Lian
c0757e8a20 fix kwarg 2025-01-14 22:47:42 -05:00
Wing Lian
e565694914 v3 2025-01-14 22:47:42 -05:00
Wing Lian
081928e55b no torch.tensor 2025-01-14 22:47:42 -05:00
Wing Lian
dc90c93894 no log etc 2025-01-14 22:47:41 -05:00
Wing Lian
18a46c338a no torch.exp inside triton kernel 2025-01-14 22:47:41 -05:00
Wing Lian
119d586cf4 v2 trial 2025-01-14 22:47:41 -05:00
Wing Lian
c73acd7de0 no where support 2025-01-14 22:47:41 -05:00
Wing Lian
0b59a242d4 triton wip 2025-01-14 22:47:41 -05:00
Wing Lian
ed490517da chore: lint 2025-01-14 22:47:41 -05:00
Wing Lian
00ce77e7ef make sure to multiply against the correct loss 2025-01-14 22:47:41 -05:00
Wing Lian
ae545e0165 cross entropy loss coefficient during KD 2025-01-14 22:47:40 -05:00
Wing Lian
b592c05b93 flipped the slice 2025-01-14 22:47:40 -05:00
Wing Lian
7fe0ad088b make it work 2025-01-14 22:47:40 -05:00
Wing Lian
ddcf5c68b3 handle padding/collation for KD datasets 2025-01-14 22:47:40 -05:00
Wing Lian
e633a12dbe make batch smaller 2025-01-14 22:47:40 -05:00
Wing Lian
d584354ee4 filter bad rows 2025-01-14 22:47:40 -05:00
Wing Lian
303cfa71aa KD dataset loading and KD with logprobs 2025-01-14 22:47:40 -05:00
Wing Lian
88b3198894 refactor trainer to prevent circular dependencies later
fix loader default
2025-01-14 22:47:39 -05:00
jwongTensora
8606093921 fix for indexing error from token/embeddings mismatch (#2257)
Co-authored-by: jwong <jwongTensora@gmail.com>
2025-01-14 22:09:29 -05:00
NanoCode012
cba5a457d9 fix: use text_column even when not packing for pretraining (#2254)
* fix: use text_column even when not packing for pretraining

* feat: update test to check when not packing

* chore: lint

* Update src/axolotl/utils/data/pretraining.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-01-14 22:08:56 -05:00
Wing Lian
19cd83d408 rename references to dpo dataset prep to pref data (#2258) 2025-01-14 22:07:55 -05:00
17 changed files with 267 additions and 68 deletions

View File

@@ -59,7 +59,7 @@ VOLUME_CONFIG = {
} }
N_GPUS = int(os.environ.get("N_GPUS", 1)) N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS) GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str): def run_cmd(cmd: str, run_folder: str):

View File

@@ -11,7 +11,7 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
@@ -109,9 +109,9 @@ def load_preference_datasets(
cli_args: Union[PreprocessCliArgs, TrainerCliArgs], cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
""" """
Loads one or more training or evaluation datasets for DPO training, calling Loads one or more training or evaluation datasets for RL training using paired
`axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
information. Optionally, logs out debug information.
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
@@ -121,7 +121,7 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`. `total_num_steps`.
""" """
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
) )

View File

@@ -697,6 +697,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None: if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs[
"kd_zscore_base_temp"
] = self.cfg.kd_zscore_base_temp
training_args_cls = ( training_args_cls = (
AxolotlTrainingArguments AxolotlTrainingArguments

View File

@@ -188,6 +188,13 @@ class AxolotlTrainingMixins:
}, },
) )
kd_zscore_base_temp: Optional[float] = field(
default=None,
metadata={
"help": "the base temperature parameter for KL divergence with z-score when using KD"
},
)
@dataclass @dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):

View File

@@ -31,3 +31,4 @@ class KDArgs(BaseModel):
] = None # loss coefficient for cross-entropy loss during KD ] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling

View File

@@ -52,26 +52,62 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
train_on_eos=train_on_eos, train_on_eos=train_on_eos,
) )
@property
def supports_batched(self) -> bool:
# batching doesn't work well for logprob data
return False
def transform_logprobs(self, sample): def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
logprobs = sample.pop(self.logprobs_field) logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs) target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"]) input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len input_padding_len = input_seq_len - target_seq_len
top_k = len(logprobs[0]) # get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [
len(logprobs[i])
for i in range(len(logprobs))
if logprobs[i] is not None and len(logprobs[i])
]
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
top_k = min(max_top_k, min_top_k)
if top_k == 0:
raise ValueError("No non-zero top-k logprobs found.")
target_logprobs = [] target_logprobs = []
target_token_ids = [] target_token_ids = []
target_mask = [] target_mask = []
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
logprobs = [row[:top_k] for row in logprobs]
# fill with -inf for padding_len tokens for top_k tokens # fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(1, input_padding_len): # start at 1 since this is causal
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k) target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k))) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k) target_mask.append([0] * top_k)
for _ in range(target_seq_len): for position in range(input_padding_len, input_seq_len):
# TODO also check against sample["labels"] if sample["labels"][position] == -100:
target_mask.append([1] * top_k) target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs): for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids # Initialize collections for logprobs and token_ids
@@ -91,28 +127,28 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
position_token_ids.append(token_id) position_token_ids.append(token_id)
# Convert to a tensor for easier manipulation # Convert to a tensor for easier manipulation
# Convert to tensor
position_logprobs_tensor = torch.tensor( position_logprobs_tensor = torch.tensor(
position_logprobs, dtype=torch.float position_logprobs, dtype=torch.float
) )
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
if self.kd_temperature != self.gen_temperature: if self.kd_temperature != self.gen_temperature:
#
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# Exponentiate by factor (T1 / T2) # Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent teacher_probs_t2 = teacher_probs_t1**exponent
# Re-normalize else:
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( teacher_probs_t2 = teacher_probs_t1
dim=0, keepdim=True # Re-normalize
) teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
# Convert back to log dim=0, keepdim=True
position_logprobs_tensor = torch.log(teacher_probs_t2) )
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist() position_logprobs_scaled = position_logprobs_tensor.tolist()
@@ -120,10 +156,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled) target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids) target_token_ids.append(position_token_ids)
# since we started at index 1 for causal, we need one more padding token if shift == 1:
target_logprobs.append([-float("inf")] * top_k) # since we started at index 1 for causal, we need one more padding token
target_token_ids.append(list(range(top_k))) target_logprobs.append([-float("inf")] * top_k)
target_mask.append([0] * top_k) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
# Update sample with transformed logprobs # Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs sample["target_logprobs"] = target_logprobs

View File

@@ -16,6 +16,40 @@ loss for top_k KL divergence
import torch import torch
def zscore_standardize(
logits: torch.Tensor,
mask: torch.Tensor = None,
base_temperature: float = 1.0,
eps: float = 1e-9,
):
"""
Z-score standardize along the last dimension of `logits`.
i.e., for each [B, seq_len] row, across K entries:
z = (logits - mean) / std,
then scale by 1 / base_temperature if desired.
mask can be broadcastable or None. If None, we standardize all elements.
"""
if mask is None:
# shape: [B, seq_len, K]
# Mean and std over dim=-1
mean = logits.mean(dim=-1, keepdim=True)
var = logits.var(dim=-1, unbiased=False, keepdim=True)
else:
# If you have to exclude some tokens, multiply by mask, etc.
float_mask = mask.to(logits.dtype)
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
std = torch.sqrt(var.clamp_min(eps))
z = (logits - mean) / std
# Scale by 1 / base_temperature
z = z / base_temperature
return z
@torch.jit.script @torch.jit.script
def loss( def loss(
student_logits: torch.Tensor, student_logits: torch.Tensor,
@@ -27,8 +61,23 @@ def loss(
) -> torch.Tensor: ) -> torch.Tensor:
""" """
A KD loss function that is TorchScript-friendly. A KD loss function that is TorchScript-friendly.
Arguments:
student_logits (torch.Tensor): The logits of the student model.
Shape: [B, student_seq_len, vocab_size]
target_token_ids (torch.Tensor): The top-k teacher/target token IDs
Shape: [B, teacher_seq_len, top_k]
target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized.
Shape: [B, teacher_seq_len, top_k]
target_mask (torch.Tensor): The mask for valid tokens.
Shape: [B, teacher_seq_len, top_k]
num_items_in_batch (int, optional): The number of items in the batch.
kd_temperature (float, optional): The temperature for KD.
Default: 1.0
""" """
target_logprobs = target_logprobs.float()
# Determine the teacher sequence length # Determine the teacher sequence length
# target_token_ids shape: [B, teacher_seq_len, K] # target_token_ids shape: [B, teacher_seq_len, K]
# student_logits shape: [B, student_seq_len, vocab_size] # student_logits shape: [B, student_seq_len, vocab_size]
@@ -44,6 +93,8 @@ def loss(
student_logits_for_kd, dim=-1, index=target_token_ids student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K] ) # [B, teacher_seq_len, K]
student_logits_topk = student_logits_topk.float()
# Apply KD temperature to students logits # Apply KD temperature to students logits
if kd_temperature != 1.0: if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature student_logits_topk = student_logits_topk / kd_temperature
@@ -80,3 +131,82 @@ def loss(
kd_loss = kd_loss / float(kd_loss_per_token.size(0)) kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss return kd_loss
def topk_kd_loss_with_zscore(
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
target_token_ids: torch.Tensor, # [B, seq_len, K]
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
kd_temperature: float = 1.0, # classic KD temperature
zscore_base_temp: float = 1.0, # from the paper
num_items_in_batch: int = -1,
):
"""
A variant of top_k KL divergence with Z-score scaling
from "Logit Standardization in Knowledge Distillation".
"""
target_logprobs = target_logprobs.float()
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
# 1) Gather the student's top-k logits to match teacher
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, seq_len, vocab]
student_topk_logits = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, seq_len, K]
student_topk_logits = student_topk_logits.float()
# 2) If you want to keep the "classical" T scaling, apply it first
if kd_temperature != 1.0:
student_topk_logits = student_topk_logits / kd_temperature
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
# (They differ by +some_constant from real logits, but in z-score
# that constant is subtracted out anyway.)
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
# 4) Z-score teacher and student
# If target_mask is 2D, expand to 3D for the K dimension
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
teacher_z = zscore_standardize(
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
)
student_z = zscore_standardize(
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
)
# 5) Convert to log-probs for KL
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
# 6) Restrict to valid tokens if needed
valid_mask = target_mask.bool() # shape [B, seq_len, K]
teacher_probs_z = teacher_logprobs_z.exp()
teacher_probs_z = teacher_probs_z[valid_mask]
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
student_logprobs_z = student_logprobs_z[valid_mask]
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
kd_loss = kd_loss_per_token.sum()
# 8) If using classical KD scaling by T^2
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
# kd_loss = kd_loss * (zscore_base_temp**2)
# 9) Normalize
if num_items_in_batch is not None and num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss

View File

@@ -19,6 +19,7 @@ KD trainer
from axolotl.core.trainers.base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainer(AxolotlTrainer): class AxolotlKDTrainer(AxolotlTrainer):
@@ -45,7 +46,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
inputs, inputs,
return_outputs=False, return_outputs=False,
num_items_in_batch=None, num_items_in_batch=None,
shift_targets=False,
): ):
""" """
How the loss is computed by Trainer. By default, all models return the loss in the first element. How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -69,25 +69,30 @@ class AxolotlKDTrainer(AxolotlTrainer):
# FIXME: account for tokenizer.padding_side # FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, :seq_len, :].contiguous() student_logits = outputs["logits"][:, :seq_len, :].contiguous()
if shift_targets: shift_logits = student_logits.contiguous()
shift_logits = student_logits[..., :-1, :].contiguous() target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() target_mask_for_loss = target_mask[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
else:
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs.contiguous()
target_token_ids_for_loss = target_token_ids.contiguous()
target_mask_for_loss = target_mask.contiguous()
loss_kd = topk_kd_loss( if self.args.kd_zscore_base_temp:
shift_logits, loss_kd = topk_kd_loss_with_zscore(
target_token_ids_for_loss, shift_logits,
target_logprobs_for_loss, target_token_ids_for_loss,
target_mask_for_loss, target_logprobs_for_loss,
num_items_in_batch=num_items_in_batch, target_mask_for_loss,
kd_temperature=self.args.kd_temperature, kd_temperature=self.args.kd_temperature,
) zscore_base_temp=self.args.kd_zscore_base_temp,
num_items_in_batch=num_items_in_batch,
)
else:
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
)
if self.args.kd_ce_alpha > 0: if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha kd_alpha = self.args.kd_alpha

View File

@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining, encode_pretraining,
wrap_pretraining_dataset, wrap_pretraining_dataset,
) )
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401 from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401 from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper, get_dataset_wrapper,
load_prepare_datasets, load_prepare_datasets,

View File

@@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining( def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List] tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
) -> Dict[str, List]: ) -> Dict[str, List]:
res = tokenizer( res = tokenizer(
examples["text"], examples[text_column],
truncation=True, truncation=True,
max_length=max_tokens - 2, max_length=max_tokens - 2,
add_special_tokens=True, add_special_tokens=True,
@@ -196,7 +199,12 @@ def wrap_pretraining_dataset(
# set this to 1 so downstream data_loader doesn't try to increase the batch again # set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1 cfg.micro_batch_size = 1
else: else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens) encode = functools.partial(
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
)
if cfg.shuffle_merged_datasets: if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)

View File

@@ -115,7 +115,7 @@ def drop_long_rl_seq(
raise ValueError("Unknown RL type") raise ValueError("Unknown RL type")
def load_prepare_dpo_datasets(cfg): def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg): def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = [] split_datasets: List[Any] = []
for i, ds_cfg in enumerate(dataset_cfgs): for i, ds_cfg in enumerate(dataset_cfgs):

View File

@@ -1057,7 +1057,7 @@ class ModelLoader:
) )
if ( if (
hasattr(self.model, "get_input_embeddings") hasattr(self.model, "get_input_embeddings")
and self.model.get_input_embeddings().num_embeddings < embeddings_len and self.model.get_input_embeddings().num_embeddings != embeddings_len
): ):
resize_kwargs = {} resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None: if self.cfg.mean_resizing_embeddings is not None:

View File

@@ -279,6 +279,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs["desc"] = "Dropping Long Sequences" drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter( train_dataset = train_dataset.filter(
drop_long, drop_long,
batched=True,
**filter_map_kwargs, **filter_map_kwargs,
**drop_long_kwargs, **drop_long_kwargs,
) )
@@ -310,8 +311,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
""" """
labels = sample["labels"] labels = sample["labels"]
if not labels: if not labels:
# Edge case: if labels is empty, decide if you want to keep or drop return True
return True # or False
# Check if single example or batch # Check if single example or batch
# If first element is an int, we assume a single example # If first element is an int, we assume a single example

View File

@@ -33,6 +33,7 @@ def min_cfg(temp_dir):
"dataloader_prefetch_factor": 8, "dataloader_prefetch_factor": 8,
"dataloader_num_workers": 4, "dataloader_num_workers": 4,
"dataloader_pin_memory": True, "dataloader_pin_memory": True,
# "dataset_prepared_path": str(Path(temp_dir) / "last_run_prepared"),
"datasets": [ "datasets": [
{ {
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", "path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",

View File

@@ -4,7 +4,8 @@ E2E tests for llama pretrain
import logging import logging
import os import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
@@ -12,19 +13,22 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir from .utils import check_model_output_exists
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
class TestPretrainLlama(unittest.TestCase): class TestPretrainLlama:
""" """
Test case for Llama models w pretraining Test case for Llama models w pretraining
""" """
@with_temp_dir @pytest.mark.parametrize(
def test_pretrain_w_sample_packing(self, temp_dir): "sample_packing",
[True, False],
)
def test_pretrain(self, temp_dir, sample_packing):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -32,7 +36,7 @@ class TestPretrainLlama(unittest.TestCase):
"tokenizer_type": "LlamaTokenizer", "tokenizer_type": "LlamaTokenizer",
"flash_attention": True, "flash_attention": True,
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": True, "sample_packing": sample_packing,
"special_tokens": { "special_tokens": {
"unk_token": "<unk>", "unk_token": "<unk>",
"bos_token": "<s>", "bos_token": "<s>",

View File

@@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
train_dataset, _ = load_prepare_dpo_datasets(cfg) train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800 assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features assert "conversation" in train_dataset.features
@@ -329,7 +329,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
train_dataset, _ = load_prepare_dpo_datasets(cfg) train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800 assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features assert "conversation" in train_dataset.features

View File

@@ -12,7 +12,7 @@ from datasets import Dataset
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.models import load_processor, load_tokenizer
@@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
"""Verify that loading with deduplication removes duplicates.""" """Verify that loading with deduplication removes duplicates."""
# Load the dataset using the deduplication setting # Load the dataset using the deduplication setting
train_dataset, _ = load_prepare_dpo_datasets(self.cfg) train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset has been deduplicated # Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated" assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
"""Verify that loading without deduplication retains duplicates.""" """Verify that loading without deduplication retains duplicates."""
self.cfg.dataset_exact_deduplication = False self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication # Load the dataset without deduplication
train_dataset, _ = load_prepare_dpo_datasets(self.cfg) train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset retains duplicates # Verify that the dataset retains duplicates
assert ( assert (