Compare commits
77 Commits
kd-trainer
...
kd-trainer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f11227a35a | ||
|
|
c434951dd6 | ||
|
|
42d4732aaf | ||
|
|
2c9dfbed2e | ||
|
|
4e4a16cd8a | ||
|
|
67c1c8405e | ||
|
|
bded6df509 | ||
|
|
bb5e6f4b72 | ||
|
|
32258c247e | ||
|
|
04efcb102f | ||
|
|
483defb9ae | ||
|
|
35a84f2cb8 | ||
|
|
510cf45317 | ||
|
|
7232cbdeab | ||
|
|
e8fceb7091 | ||
|
|
a5e0671738 | ||
|
|
b9847553af | ||
|
|
513ec9e03b | ||
|
|
530347856d | ||
|
|
261e4fb619 | ||
|
|
158071e95f | ||
|
|
432f65f5e6 | ||
|
|
1d039f5486 | ||
|
|
b9a42b396f | ||
|
|
ff2fb0fc1b | ||
|
|
317f290186 | ||
|
|
ab690f3f01 | ||
|
|
47932f21c4 | ||
|
|
808328e041 | ||
|
|
6784822cfb | ||
|
|
684b38291f | ||
|
|
01896b1bde | ||
|
|
e659c01646 | ||
|
|
204d6c43b4 | ||
|
|
d3c2b7ce9d | ||
|
|
93dfff92f1 | ||
|
|
6e409d2d88 | ||
|
|
d5bc214300 | ||
|
|
92c6c1087e | ||
|
|
feed96f95e | ||
|
|
cba6165ae1 | ||
|
|
cdfcd69afa | ||
|
|
885653d52e | ||
|
|
27faacbf5a | ||
|
|
c51b0337c1 | ||
|
|
fa055f9f69 | ||
|
|
f60c623af0 | ||
|
|
746891eb5c | ||
|
|
f09b5da60b | ||
|
|
689e1c10ba | ||
|
|
a5c085e003 | ||
|
|
63146300b7 | ||
|
|
ca5e397fc5 | ||
|
|
3416302b0d | ||
|
|
7366efc4ca | ||
|
|
d8d817eaed | ||
|
|
c0757e8a20 | ||
|
|
e565694914 | ||
|
|
081928e55b | ||
|
|
dc90c93894 | ||
|
|
18a46c338a | ||
|
|
119d586cf4 | ||
|
|
c73acd7de0 | ||
|
|
0b59a242d4 | ||
|
|
ed490517da | ||
|
|
00ce77e7ef | ||
|
|
ae545e0165 | ||
|
|
b592c05b93 | ||
|
|
7fe0ad088b | ||
|
|
ddcf5c68b3 | ||
|
|
e633a12dbe | ||
|
|
d584354ee4 | ||
|
|
303cfa71aa | ||
|
|
88b3198894 | ||
|
|
8606093921 | ||
|
|
cba5a457d9 | ||
|
|
19cd83d408 |
@@ -59,7 +59,7 @@ VOLUME_CONFIG = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||||
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
|
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
|
||||||
|
|
||||||
|
|
||||||
def run_cmd(cmd: str, run_folder: str):
|
def run_cmd(cmd: str, run_folder: str):
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from datasets import Dataset
|
|||||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_processor, load_tokenizer
|
from axolotl.utils.models import load_processor, load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
@@ -109,9 +109,9 @@ def load_preference_datasets(
|
|||||||
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
||||||
) -> TrainDatasetMeta:
|
) -> TrainDatasetMeta:
|
||||||
"""
|
"""
|
||||||
Loads one or more training or evaluation datasets for DPO training, calling
|
Loads one or more training or evaluation datasets for RL training using paired
|
||||||
`axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug
|
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
|
||||||
information.
|
Optionally, logs out debug information.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
@@ -121,7 +121,7 @@ def load_preference_datasets(
|
|||||||
Dataclass with fields for training and evaluation datasets and the computed
|
Dataclass with fields for training and evaluation datasets and the computed
|
||||||
`total_num_steps`.
|
`total_num_steps`.
|
||||||
"""
|
"""
|
||||||
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
|
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -697,6 +697,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||||
if self.cfg.kd_alpha is not None:
|
if self.cfg.kd_alpha is not None:
|
||||||
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
||||||
|
if self.cfg.kd_temperature is not None:
|
||||||
|
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
||||||
|
if self.cfg.kd_zscore_base_temp is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"kd_zscore_base_temp"
|
||||||
|
] = self.cfg.kd_zscore_base_temp
|
||||||
|
|
||||||
training_args_cls = (
|
training_args_cls = (
|
||||||
AxolotlTrainingArguments
|
AxolotlTrainingArguments
|
||||||
|
|||||||
@@ -188,6 +188,13 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
kd_zscore_base_temp: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "the base temperature parameter for KL divergence with z-score when using KD"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
|
|||||||
@@ -31,3 +31,4 @@ class KDArgs(BaseModel):
|
|||||||
] = None # loss coefficient for cross-entropy loss during KD
|
] = None # loss coefficient for cross-entropy loss during KD
|
||||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||||
|
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
||||||
|
|||||||
@@ -52,26 +52,62 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
train_on_eos=train_on_eos,
|
train_on_eos=train_on_eos,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_batched(self) -> bool:
|
||||||
|
# batching doesn't work well for logprob data
|
||||||
|
return False
|
||||||
|
|
||||||
def transform_logprobs(self, sample):
|
def transform_logprobs(self, sample):
|
||||||
|
"""
|
||||||
|
Transform logprobs to target format for KD training
|
||||||
|
"""
|
||||||
|
|
||||||
logprobs = sample.pop(self.logprobs_field)
|
logprobs = sample.pop(self.logprobs_field)
|
||||||
target_seq_len = len(logprobs)
|
target_seq_len = len(logprobs)
|
||||||
input_seq_len = len(sample["input_ids"])
|
input_seq_len = len(sample["input_ids"])
|
||||||
input_padding_len = input_seq_len - target_seq_len
|
input_padding_len = input_seq_len - target_seq_len
|
||||||
top_k = len(logprobs[0])
|
# get non-zero top-k (prune None logprobs from vllm data step)
|
||||||
|
top_k_vals = [
|
||||||
|
len(logprobs[i])
|
||||||
|
for i in range(len(logprobs))
|
||||||
|
if logprobs[i] is not None and len(logprobs[i])
|
||||||
|
]
|
||||||
|
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
|
||||||
|
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
|
||||||
|
top_k = min(max_top_k, min_top_k)
|
||||||
|
if top_k == 0:
|
||||||
|
raise ValueError("No non-zero top-k logprobs found.")
|
||||||
|
|
||||||
target_logprobs = []
|
target_logprobs = []
|
||||||
target_token_ids = []
|
target_token_ids = []
|
||||||
target_mask = []
|
target_mask = []
|
||||||
|
|
||||||
|
if input_padding_len < 0:
|
||||||
|
# logprobs is longer than target_seq_len,
|
||||||
|
# so we need to slice from the left/beginning of logprobs
|
||||||
|
logprobs = logprobs[:-input_seq_len]
|
||||||
|
input_padding_len = 0
|
||||||
|
# target_seq_len = input_seq_len
|
||||||
|
|
||||||
|
# truncate the second dimension of the logprobs to top_k
|
||||||
|
logprobs = [row[:top_k] for row in logprobs]
|
||||||
|
|
||||||
# fill with -inf for padding_len tokens for top_k tokens
|
# fill with -inf for padding_len tokens for top_k tokens
|
||||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||||
for _ in range(1, input_padding_len): # start at 1 since this is causal
|
|
||||||
|
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
||||||
|
# otherwise, we need to shift in the trainer
|
||||||
|
shift = 0
|
||||||
|
for _ in range(shift, input_padding_len):
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
target_token_ids.append(list(range(top_k)))
|
target_token_ids.append(list(range(top_k)))
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
|
|
||||||
for _ in range(target_seq_len):
|
for position in range(input_padding_len, input_seq_len):
|
||||||
# TODO also check against sample["labels"]
|
if sample["labels"][position] == -100:
|
||||||
target_mask.append([1] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
|
else:
|
||||||
|
target_mask.append([1] * top_k)
|
||||||
|
|
||||||
for _, token_pos_logprobs in enumerate(logprobs):
|
for _, token_pos_logprobs in enumerate(logprobs):
|
||||||
# Initialize collections for logprobs and token_ids
|
# Initialize collections for logprobs and token_ids
|
||||||
@@ -91,28 +127,28 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
position_token_ids.append(token_id)
|
position_token_ids.append(token_id)
|
||||||
|
|
||||||
# Convert to a tensor for easier manipulation
|
# Convert to a tensor for easier manipulation
|
||||||
# Convert to tensor
|
|
||||||
position_logprobs_tensor = torch.tensor(
|
position_logprobs_tensor = torch.tensor(
|
||||||
position_logprobs, dtype=torch.float
|
position_logprobs, dtype=torch.float
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
|
||||||
|
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
|
||||||
|
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
|
||||||
|
#
|
||||||
|
# Convert from log to probability
|
||||||
|
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||||
if self.kd_temperature != self.gen_temperature:
|
if self.kd_temperature != self.gen_temperature:
|
||||||
#
|
|
||||||
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
|
|
||||||
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
|
|
||||||
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
|
|
||||||
#
|
|
||||||
# Convert from log to probability
|
|
||||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
|
||||||
# Exponentiate by factor (T1 / T2)
|
# Exponentiate by factor (T1 / T2)
|
||||||
exponent = self.gen_temperature / self.kd_temperature
|
exponent = self.gen_temperature / self.kd_temperature
|
||||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||||
# Re-normalize
|
else:
|
||||||
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
teacher_probs_t2 = teacher_probs_t1
|
||||||
dim=0, keepdim=True
|
# Re-normalize
|
||||||
)
|
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||||
# Convert back to log
|
dim=0, keepdim=True
|
||||||
position_logprobs_tensor = torch.log(teacher_probs_t2)
|
)
|
||||||
|
# Convert back to log
|
||||||
|
position_logprobs_tensor = torch.log(teacher_probs_t2)
|
||||||
|
|
||||||
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
|
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
|
||||||
position_logprobs_scaled = position_logprobs_tensor.tolist()
|
position_logprobs_scaled = position_logprobs_tensor.tolist()
|
||||||
@@ -120,10 +156,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
target_logprobs.append(position_logprobs_scaled)
|
target_logprobs.append(position_logprobs_scaled)
|
||||||
target_token_ids.append(position_token_ids)
|
target_token_ids.append(position_token_ids)
|
||||||
|
|
||||||
# since we started at index 1 for causal, we need one more padding token
|
if shift == 1:
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
# since we started at index 1 for causal, we need one more padding token
|
||||||
target_token_ids.append(list(range(top_k)))
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
target_mask.append([0] * top_k)
|
target_token_ids.append(list(range(top_k)))
|
||||||
|
target_mask.append([0] * top_k)
|
||||||
|
|
||||||
# Update sample with transformed logprobs
|
# Update sample with transformed logprobs
|
||||||
sample["target_logprobs"] = target_logprobs
|
sample["target_logprobs"] = target_logprobs
|
||||||
|
|||||||
@@ -16,6 +16,40 @@ loss for top_k KL divergence
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def zscore_standardize(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
mask: torch.Tensor = None,
|
||||||
|
base_temperature: float = 1.0,
|
||||||
|
eps: float = 1e-9,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Z-score standardize along the last dimension of `logits`.
|
||||||
|
i.e., for each [B, seq_len] row, across K entries:
|
||||||
|
z = (logits - mean) / std,
|
||||||
|
then scale by 1 / base_temperature if desired.
|
||||||
|
|
||||||
|
mask can be broadcastable or None. If None, we standardize all elements.
|
||||||
|
"""
|
||||||
|
if mask is None:
|
||||||
|
# shape: [B, seq_len, K]
|
||||||
|
# Mean and std over dim=-1
|
||||||
|
mean = logits.mean(dim=-1, keepdim=True)
|
||||||
|
var = logits.var(dim=-1, unbiased=False, keepdim=True)
|
||||||
|
else:
|
||||||
|
# If you have to exclude some tokens, multiply by mask, etc.
|
||||||
|
float_mask = mask.to(logits.dtype)
|
||||||
|
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
|
||||||
|
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
|
||||||
|
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
|
||||||
|
|
||||||
|
std = torch.sqrt(var.clamp_min(eps))
|
||||||
|
z = (logits - mean) / std
|
||||||
|
|
||||||
|
# Scale by 1 / base_temperature
|
||||||
|
z = z / base_temperature
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def loss(
|
def loss(
|
||||||
student_logits: torch.Tensor,
|
student_logits: torch.Tensor,
|
||||||
@@ -27,8 +61,23 @@ def loss(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
A KD loss function that is TorchScript-friendly.
|
A KD loss function that is TorchScript-friendly.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
student_logits (torch.Tensor): The logits of the student model.
|
||||||
|
Shape: [B, student_seq_len, vocab_size]
|
||||||
|
target_token_ids (torch.Tensor): The top-k teacher/target token IDs
|
||||||
|
Shape: [B, teacher_seq_len, top_k]
|
||||||
|
target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized.
|
||||||
|
Shape: [B, teacher_seq_len, top_k]
|
||||||
|
target_mask (torch.Tensor): The mask for valid tokens.
|
||||||
|
Shape: [B, teacher_seq_len, top_k]
|
||||||
|
num_items_in_batch (int, optional): The number of items in the batch.
|
||||||
|
kd_temperature (float, optional): The temperature for KD.
|
||||||
|
Default: 1.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
target_logprobs = target_logprobs.float()
|
||||||
|
|
||||||
# Determine the teacher sequence length
|
# Determine the teacher sequence length
|
||||||
# target_token_ids shape: [B, teacher_seq_len, K]
|
# target_token_ids shape: [B, teacher_seq_len, K]
|
||||||
# student_logits shape: [B, student_seq_len, vocab_size]
|
# student_logits shape: [B, student_seq_len, vocab_size]
|
||||||
@@ -44,6 +93,8 @@ def loss(
|
|||||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||||
) # [B, teacher_seq_len, K]
|
) # [B, teacher_seq_len, K]
|
||||||
|
|
||||||
|
student_logits_topk = student_logits_topk.float()
|
||||||
|
|
||||||
# Apply KD temperature to student’s logits
|
# Apply KD temperature to student’s logits
|
||||||
if kd_temperature != 1.0:
|
if kd_temperature != 1.0:
|
||||||
student_logits_topk = student_logits_topk / kd_temperature
|
student_logits_topk = student_logits_topk / kd_temperature
|
||||||
@@ -80,3 +131,82 @@ def loss(
|
|||||||
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
|
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
|
||||||
|
|
||||||
return kd_loss
|
return kd_loss
|
||||||
|
|
||||||
|
|
||||||
|
def topk_kd_loss_with_zscore(
|
||||||
|
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
||||||
|
target_token_ids: torch.Tensor, # [B, seq_len, K]
|
||||||
|
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
|
||||||
|
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
|
||||||
|
kd_temperature: float = 1.0, # classic KD temperature
|
||||||
|
zscore_base_temp: float = 1.0, # from the paper
|
||||||
|
num_items_in_batch: int = -1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
A variant of top_k KL divergence with Z-score scaling
|
||||||
|
from "Logit Standardization in Knowledge Distillation".
|
||||||
|
"""
|
||||||
|
|
||||||
|
target_logprobs = target_logprobs.float()
|
||||||
|
|
||||||
|
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
|
||||||
|
# 1) Gather the student's top-k logits to match teacher
|
||||||
|
student_logits_for_kd = student_logits[
|
||||||
|
:, :teacher_seq_len, :
|
||||||
|
] # [B, seq_len, vocab]
|
||||||
|
student_topk_logits = torch.gather(
|
||||||
|
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||||
|
) # [B, seq_len, K]
|
||||||
|
|
||||||
|
student_topk_logits = student_topk_logits.float()
|
||||||
|
|
||||||
|
# 2) If you want to keep the "classical" T scaling, apply it first
|
||||||
|
if kd_temperature != 1.0:
|
||||||
|
student_topk_logits = student_topk_logits / kd_temperature
|
||||||
|
|
||||||
|
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
|
||||||
|
# (They differ by +some_constant from real logits, but in z-score
|
||||||
|
# that constant is subtracted out anyway.)
|
||||||
|
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
|
||||||
|
|
||||||
|
# 4) Z-score teacher and student
|
||||||
|
# If target_mask is 2D, expand to 3D for the K dimension
|
||||||
|
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
|
||||||
|
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
|
||||||
|
|
||||||
|
teacher_z = zscore_standardize(
|
||||||
|
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
|
||||||
|
)
|
||||||
|
student_z = zscore_standardize(
|
||||||
|
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5) Convert to log-probs for KL
|
||||||
|
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
|
||||||
|
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# 6) Restrict to valid tokens if needed
|
||||||
|
valid_mask = target_mask.bool() # shape [B, seq_len, K]
|
||||||
|
teacher_probs_z = teacher_logprobs_z.exp()
|
||||||
|
teacher_probs_z = teacher_probs_z[valid_mask]
|
||||||
|
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
|
||||||
|
student_logprobs_z = student_logprobs_z[valid_mask]
|
||||||
|
|
||||||
|
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
|
||||||
|
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
|
||||||
|
kd_loss = kd_loss_per_token.sum()
|
||||||
|
|
||||||
|
# 8) If using classical KD scaling by T^2
|
||||||
|
if kd_temperature != 1.0:
|
||||||
|
kd_loss = kd_loss * (kd_temperature**2)
|
||||||
|
|
||||||
|
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
|
||||||
|
# kd_loss = kd_loss * (zscore_base_temp**2)
|
||||||
|
|
||||||
|
# 9) Normalize
|
||||||
|
if num_items_in_batch is not None and num_items_in_batch > 0:
|
||||||
|
kd_loss = kd_loss / float(num_items_in_batch)
|
||||||
|
else:
|
||||||
|
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
|
||||||
|
|
||||||
|
return kd_loss
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ KD trainer
|
|||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||||
|
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKDTrainer(AxolotlTrainer):
|
class AxolotlKDTrainer(AxolotlTrainer):
|
||||||
@@ -45,7 +46,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
inputs,
|
inputs,
|
||||||
return_outputs=False,
|
return_outputs=False,
|
||||||
num_items_in_batch=None,
|
num_items_in_batch=None,
|
||||||
shift_targets=False,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||||
@@ -69,25 +69,30 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
# FIXME: account for tokenizer.padding_side
|
# FIXME: account for tokenizer.padding_side
|
||||||
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
|
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
|
||||||
|
|
||||||
if shift_targets:
|
shift_logits = student_logits.contiguous()
|
||||||
shift_logits = student_logits[..., :-1, :].contiguous()
|
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
|
||||||
else:
|
|
||||||
shift_logits = student_logits.contiguous()
|
|
||||||
target_logprobs_for_loss = target_logprobs.contiguous()
|
|
||||||
target_token_ids_for_loss = target_token_ids.contiguous()
|
|
||||||
target_mask_for_loss = target_mask.contiguous()
|
|
||||||
|
|
||||||
loss_kd = topk_kd_loss(
|
if self.args.kd_zscore_base_temp:
|
||||||
shift_logits,
|
loss_kd = topk_kd_loss_with_zscore(
|
||||||
target_token_ids_for_loss,
|
shift_logits,
|
||||||
target_logprobs_for_loss,
|
target_token_ids_for_loss,
|
||||||
target_mask_for_loss,
|
target_logprobs_for_loss,
|
||||||
num_items_in_batch=num_items_in_batch,
|
target_mask_for_loss,
|
||||||
kd_temperature=self.args.kd_temperature,
|
kd_temperature=self.args.kd_temperature,
|
||||||
)
|
zscore_base_temp=self.args.kd_zscore_base_temp,
|
||||||
|
num_items_in_batch=num_items_in_batch,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
loss_kd = topk_kd_loss(
|
||||||
|
shift_logits,
|
||||||
|
target_token_ids_for_loss,
|
||||||
|
target_logprobs_for_loss,
|
||||||
|
target_mask_for_loss,
|
||||||
|
num_items_in_batch=num_items_in_batch,
|
||||||
|
kd_temperature=self.args.kd_temperature,
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.kd_ce_alpha > 0:
|
if self.args.kd_ce_alpha > 0:
|
||||||
kd_alpha = self.args.kd_alpha
|
kd_alpha = self.args.kd_alpha
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
|
|||||||
encode_pretraining,
|
encode_pretraining,
|
||||||
wrap_pretraining_dataset,
|
wrap_pretraining_dataset,
|
||||||
)
|
)
|
||||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
|
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
|
||||||
from axolotl.utils.data.sft import ( # noqa: F401
|
from axolotl.utils.data.sft import ( # noqa: F401
|
||||||
get_dataset_wrapper,
|
get_dataset_wrapper,
|
||||||
load_prepare_datasets,
|
load_prepare_datasets,
|
||||||
|
|||||||
@@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl")
|
|||||||
|
|
||||||
|
|
||||||
def encode_pretraining(
|
def encode_pretraining(
|
||||||
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
max_tokens: int,
|
||||||
|
examples: Dict[str, List],
|
||||||
|
text_column: str = "text",
|
||||||
) -> Dict[str, List]:
|
) -> Dict[str, List]:
|
||||||
res = tokenizer(
|
res = tokenizer(
|
||||||
examples["text"],
|
examples[text_column],
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=max_tokens - 2,
|
max_length=max_tokens - 2,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
@@ -196,7 +199,12 @@ def wrap_pretraining_dataset(
|
|||||||
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
||||||
cfg.micro_batch_size = 1
|
cfg.micro_batch_size = 1
|
||||||
else:
|
else:
|
||||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
encode = functools.partial(
|
||||||
|
encode_pretraining,
|
||||||
|
tokenizer,
|
||||||
|
max_tokens,
|
||||||
|
text_column=cfg.pretraining_dataset[0].text_column or "text",
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.shuffle_merged_datasets:
|
if cfg.shuffle_merged_datasets:
|
||||||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ def drop_long_rl_seq(
|
|||||||
raise ValueError("Unknown RL type")
|
raise ValueError("Unknown RL type")
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_dpo_datasets(cfg):
|
def load_prepare_preference_datasets(cfg):
|
||||||
def load_split(dataset_cfgs, _cfg):
|
def load_split(dataset_cfgs, _cfg):
|
||||||
split_datasets: List[Any] = []
|
split_datasets: List[Any] = []
|
||||||
for i, ds_cfg in enumerate(dataset_cfgs):
|
for i, ds_cfg in enumerate(dataset_cfgs):
|
||||||
|
|||||||
@@ -1057,7 +1057,7 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
hasattr(self.model, "get_input_embeddings")
|
hasattr(self.model, "get_input_embeddings")
|
||||||
and self.model.get_input_embeddings().num_embeddings < embeddings_len
|
and self.model.get_input_embeddings().num_embeddings != embeddings_len
|
||||||
):
|
):
|
||||||
resize_kwargs = {}
|
resize_kwargs = {}
|
||||||
if self.cfg.mean_resizing_embeddings is not None:
|
if self.cfg.mean_resizing_embeddings is not None:
|
||||||
|
|||||||
@@ -279,6 +279,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||||
train_dataset = train_dataset.filter(
|
train_dataset = train_dataset.filter(
|
||||||
drop_long,
|
drop_long,
|
||||||
|
batched=True,
|
||||||
**filter_map_kwargs,
|
**filter_map_kwargs,
|
||||||
**drop_long_kwargs,
|
**drop_long_kwargs,
|
||||||
)
|
)
|
||||||
@@ -310,8 +311,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
"""
|
"""
|
||||||
labels = sample["labels"]
|
labels = sample["labels"]
|
||||||
if not labels:
|
if not labels:
|
||||||
# Edge case: if labels is empty, decide if you want to keep or drop
|
return True
|
||||||
return True # or False
|
|
||||||
|
|
||||||
# Check if single example or batch
|
# Check if single example or batch
|
||||||
# If first element is an int, we assume a single example
|
# If first element is an int, we assume a single example
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ def min_cfg(temp_dir):
|
|||||||
"dataloader_prefetch_factor": 8,
|
"dataloader_prefetch_factor": 8,
|
||||||
"dataloader_num_workers": 4,
|
"dataloader_num_workers": 4,
|
||||||
"dataloader_pin_memory": True,
|
"dataloader_pin_memory": True,
|
||||||
|
# "dataset_prepared_path": str(Path(temp_dir) / "last_run_prepared"),
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
|
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ E2E tests for llama pretrain
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
@@ -12,19 +13,22 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
class TestPretrainLlama(unittest.TestCase):
|
class TestPretrainLlama:
|
||||||
"""
|
"""
|
||||||
Test case for Llama models w pretraining
|
Test case for Llama models w pretraining
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize(
|
||||||
def test_pretrain_w_sample_packing(self, temp_dir):
|
"sample_packing",
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
def test_pretrain(self, temp_dir, sample_packing):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -32,7 +36,7 @@ class TestPretrainLlama(unittest.TestCase):
|
|||||||
"tokenizer_type": "LlamaTokenizer",
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"sample_packing": True,
|
"sample_packing": sample_packing,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"unk_token": "<unk>",
|
||||||
"bos_token": "<s>",
|
"bos_token": "<s>",
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
@@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataset, _ = load_prepare_dpo_datasets(cfg)
|
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||||
|
|
||||||
assert len(train_dataset) == 1800
|
assert len(train_dataset) == 1800
|
||||||
assert "conversation" in train_dataset.features
|
assert "conversation" in train_dataset.features
|
||||||
@@ -329,7 +329,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataset, _ = load_prepare_dpo_datasets(cfg)
|
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||||
|
|
||||||
assert len(train_dataset) == 1800
|
assert len(train_dataset) == 1800
|
||||||
assert "conversation" in train_dataset.features
|
assert "conversation" in train_dataset.features
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from datasets import Dataset
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_processor, load_tokenizer
|
from axolotl.utils.models import load_processor, load_tokenizer
|
||||||
@@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
|
|||||||
"""Verify that loading with deduplication removes duplicates."""
|
"""Verify that loading with deduplication removes duplicates."""
|
||||||
|
|
||||||
# Load the dataset using the deduplication setting
|
# Load the dataset using the deduplication setting
|
||||||
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
|
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
|
||||||
|
|
||||||
# Verify that the dataset has been deduplicated
|
# Verify that the dataset has been deduplicated
|
||||||
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
||||||
@@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
|
|||||||
"""Verify that loading without deduplication retains duplicates."""
|
"""Verify that loading without deduplication retains duplicates."""
|
||||||
self.cfg.dataset_exact_deduplication = False
|
self.cfg.dataset_exact_deduplication = False
|
||||||
# Load the dataset without deduplication
|
# Load the dataset without deduplication
|
||||||
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
|
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
|
||||||
|
|
||||||
# Verify that the dataset retains duplicates
|
# Verify that the dataset retains duplicates
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
Reference in New Issue
Block a user