Compare commits

..

61 Commits

Author SHA1 Message Date
Wing Lian
4a0ab11fcf chore: lint 2025-01-13 14:05:56 -05:00
Wing Lian
73b6b0a580 chore: lint 2025-01-13 13:56:16 -05:00
Wing Lian
9db5072407 make sure to use tensorboard to capture loss for checks 2025-01-13 13:56:16 -05:00
Wing Lian
42d3e36a6f fix adapter model check 2025-01-13 13:56:15 -05:00
Wing Lian
b12d93bedf make sure to use the correct tokenizer 2025-01-13 13:56:15 -05:00
Wing Lian
08ec9c0e5b make sure to set tokenizer from l3 70b and save safetensors 2025-01-13 13:56:15 -05:00
Wing Lian
9abac55f92 lower lr 2025-01-13 13:56:15 -05:00
Wing Lian
800e7fa41e set lora_dropout explicitly 2025-01-13 13:56:15 -05:00
Wing Lian
5a1c1b82d4 make the kd e2e fit in vram for ci and add lora version 2025-01-13 13:56:15 -05:00
Wing Lian
efb3f70d38 rename test files so it gets picked up 2025-01-13 13:56:15 -05:00
Wing Lian
58d9896777 linting 2025-01-13 13:56:15 -05:00
Wing Lian
f7963083b8 add kd trainer e2e test 2025-01-13 13:56:15 -05:00
Wing Lian
f0b6581f8c reward model doesn't work well with batched 2025-01-13 13:56:15 -05:00
Wing Lian
27bb21c459 improve check for batched 2025-01-13 13:56:15 -05:00
Wing Lian
74d98ca6d8 fix reward trainer calls for tokenization 2025-01-13 13:56:14 -05:00
Wing Lian
ec4dfb02c8 reward can use same batch check 2025-01-13 13:56:14 -05:00
Wing Lian
28ef5e8d5a tweak check for batched prompt data 2025-01-13 13:56:14 -05:00
Wing Lian
5ed2823855 ensure that batch vs single is done properly 2025-01-13 13:56:14 -05:00
Wing Lian
fb0775d264 improve iterable support 2025-01-13 13:56:12 -05:00
Wing Lian
7cd0a317cb support streaming for processing sft datasts? 2025-01-13 13:41:36 -05:00
Wing Lian
1cc3a2d16c make loss torch script compat 2025-01-13 13:41:36 -05:00
Wing Lian
287d2ca8d5 kd sample packing 2025-01-13 13:41:36 -05:00
Wing Lian
03b86df506 be a bit pickier about loading dynamic prompt strategies 2025-01-13 13:41:36 -05:00
Wing Lian
2ed4246949 more info on preprocess for kd and fix import 2025-01-13 13:41:35 -05:00
Wing Lian
35bc2e2d3f remove duplicate code 2025-01-13 13:41:35 -05:00
Wing Lian
94f1094805 add copyrights 2025-01-13 13:41:35 -05:00
Wing Lian
a0070bf94e increase logging around loading plugins 2025-01-13 13:41:35 -05:00
Wing Lian
2ee2ffd834 make plugin setup concise 2025-01-13 13:41:35 -05:00
Wing Lian
723b0a2dee remove moved class from import 2025-01-13 13:41:35 -05:00
Wing Lian
327739c9e3 move more things to kd plugin 2025-01-13 13:41:35 -05:00
Wing Lian
8aafe142f2 refactor kd chat template loader 2025-01-13 13:41:35 -05:00
Wing Lian
a0d6d8895e support for custom trainer classes from plugins 2025-01-13 13:41:34 -05:00
Wing Lian
55b33cc44d handle token/logprob shifting 2025-01-13 13:41:34 -05:00
Wing Lian
69ed25e82c remove references to triton kd for now 2025-01-13 13:41:34 -05:00
Wing Lian
2ea8b7e518 add license block 2025-01-13 13:41:34 -05:00
Wing Lian
aa081e0e76 refactor so we can easily add new loss functions 2025-01-13 13:41:34 -05:00
Wing Lian
3f97ec45fb chore: lint 2025-01-13 13:41:34 -05:00
Wing Lian
7b5a24b0d2 var naming and add todo 2025-01-13 13:41:34 -05:00
Wing Lian
4ddd089d0a fix kd loss so it's causal (fixes repeating tokens) 2025-01-13 13:41:34 -05:00
Wing Lian
b88128d067 use kd_alpha in the correct loss method 2025-01-13 13:41:32 -05:00
Wing Lian
2e6422a711 hash for temperature too 2025-01-13 13:40:19 -05:00
Wing Lian
6ad809287b better rescaling for temperatures 2025-01-13 13:40:19 -05:00
Wing Lian
e376e00386 don't use triton for now 2025-01-13 13:40:19 -05:00
Wing Lian
23d7ae6caa fix kwarg 2025-01-13 13:40:19 -05:00
Wing Lian
19638590d5 v3 2025-01-13 13:40:18 -05:00
Wing Lian
73f5b83431 no torch.tensor 2025-01-13 13:40:18 -05:00
Wing Lian
9b1164b841 no log etc 2025-01-13 13:40:18 -05:00
Wing Lian
5a7d6f6175 no torch.exp inside triton kernel 2025-01-13 13:40:18 -05:00
Wing Lian
a803c3d3ee v2 trial 2025-01-13 13:40:18 -05:00
Wing Lian
48ccf55752 no where support 2025-01-13 13:40:18 -05:00
Wing Lian
bc3326a808 triton wip 2025-01-13 13:40:18 -05:00
Wing Lian
cf8174db75 chore: lint 2025-01-13 13:40:18 -05:00
Wing Lian
222dc27410 make sure to multiply against the correct loss 2025-01-13 13:40:18 -05:00
Wing Lian
1107f1f603 cross entropy loss coefficient during KD 2025-01-13 13:40:18 -05:00
Wing Lian
1c603da96a flipped the slice 2025-01-13 13:40:17 -05:00
Wing Lian
283faf3909 make it work 2025-01-13 13:40:17 -05:00
Wing Lian
472f7048e5 handle padding/collation for KD datasets 2025-01-13 13:40:17 -05:00
Wing Lian
3d1e2dcef4 make batch smaller 2025-01-13 13:40:17 -05:00
Wing Lian
9e218fbcfd filter bad rows 2025-01-13 13:40:17 -05:00
Wing Lian
11caf52529 KD dataset loading and KD with logprobs 2025-01-13 13:40:17 -05:00
Wing Lian
17ba9dcfdb refactor trainer to prevent circular dependencies later
fix loader default
2025-01-13 13:40:17 -05:00
17 changed files with 68 additions and 267 deletions

View File

@@ -59,7 +59,7 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):

View File

@@ -11,7 +11,7 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
@@ -109,9 +109,9 @@ def load_preference_datasets(
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
Optionally, logs out debug information.
Loads one or more training or evaluation datasets for DPO training, calling
`axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug
information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -121,7 +121,7 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)

View File

@@ -697,12 +697,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs[
"kd_zscore_base_temp"
] = self.cfg.kd_zscore_base_temp
training_args_cls = (
AxolotlTrainingArguments

View File

@@ -188,13 +188,6 @@ class AxolotlTrainingMixins:
},
)
kd_zscore_base_temp: Optional[float] = field(
default=None,
metadata={
"help": "the base temperature parameter for KL divergence with z-score when using KD"
},
)
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):

View File

@@ -31,4 +31,3 @@ class KDArgs(BaseModel):
] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling

View File

@@ -52,62 +52,26 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
train_on_eos=train_on_eos,
)
@property
def supports_batched(self) -> bool:
# batching doesn't work well for logprob data
return False
def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
# get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [
len(logprobs[i])
for i in range(len(logprobs))
if logprobs[i] is not None and len(logprobs[i])
]
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
top_k = min(max_top_k, min_top_k)
if top_k == 0:
raise ValueError("No non-zero top-k logprobs found.")
top_k = len(logprobs[0])
target_logprobs = []
target_token_ids = []
target_mask = []
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
logprobs = [row[:top_k] for row in logprobs]
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
for _ in range(1, input_padding_len): # start at 1 since this is causal
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for _ in range(target_seq_len):
# TODO also check against sample["labels"]
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids
@@ -127,28 +91,28 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
position_token_ids.append(token_id)
# Convert to a tensor for easier manipulation
# Convert to tensor
position_logprobs_tensor = torch.tensor(
position_logprobs, dtype=torch.float
)
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
if self.kd_temperature != self.gen_temperature:
#
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
else:
teacher_probs_t2 = teacher_probs_t1
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist()
@@ -156,11 +120,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids)
if shift == 1:
# since we started at index 1 for causal, we need one more padding token
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
# since we started at index 1 for causal, we need one more padding token
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs

View File

@@ -16,40 +16,6 @@ loss for top_k KL divergence
import torch
def zscore_standardize(
logits: torch.Tensor,
mask: torch.Tensor = None,
base_temperature: float = 1.0,
eps: float = 1e-9,
):
"""
Z-score standardize along the last dimension of `logits`.
i.e., for each [B, seq_len] row, across K entries:
z = (logits - mean) / std,
then scale by 1 / base_temperature if desired.
mask can be broadcastable or None. If None, we standardize all elements.
"""
if mask is None:
# shape: [B, seq_len, K]
# Mean and std over dim=-1
mean = logits.mean(dim=-1, keepdim=True)
var = logits.var(dim=-1, unbiased=False, keepdim=True)
else:
# If you have to exclude some tokens, multiply by mask, etc.
float_mask = mask.to(logits.dtype)
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
std = torch.sqrt(var.clamp_min(eps))
z = (logits - mean) / std
# Scale by 1 / base_temperature
z = z / base_temperature
return z
@torch.jit.script
def loss(
student_logits: torch.Tensor,
@@ -61,23 +27,8 @@ def loss(
) -> torch.Tensor:
"""
A KD loss function that is TorchScript-friendly.
Arguments:
student_logits (torch.Tensor): The logits of the student model.
Shape: [B, student_seq_len, vocab_size]
target_token_ids (torch.Tensor): The top-k teacher/target token IDs
Shape: [B, teacher_seq_len, top_k]
target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized.
Shape: [B, teacher_seq_len, top_k]
target_mask (torch.Tensor): The mask for valid tokens.
Shape: [B, teacher_seq_len, top_k]
num_items_in_batch (int, optional): The number of items in the batch.
kd_temperature (float, optional): The temperature for KD.
Default: 1.0
"""
target_logprobs = target_logprobs.float()
# Determine the teacher sequence length
# target_token_ids shape: [B, teacher_seq_len, K]
# student_logits shape: [B, student_seq_len, vocab_size]
@@ -93,8 +44,6 @@ def loss(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
student_logits_topk = student_logits_topk.float()
# Apply KD temperature to students logits
if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature
@@ -131,82 +80,3 @@ def loss(
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss
def topk_kd_loss_with_zscore(
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
target_token_ids: torch.Tensor, # [B, seq_len, K]
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
kd_temperature: float = 1.0, # classic KD temperature
zscore_base_temp: float = 1.0, # from the paper
num_items_in_batch: int = -1,
):
"""
A variant of top_k KL divergence with Z-score scaling
from "Logit Standardization in Knowledge Distillation".
"""
target_logprobs = target_logprobs.float()
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
# 1) Gather the student's top-k logits to match teacher
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, seq_len, vocab]
student_topk_logits = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, seq_len, K]
student_topk_logits = student_topk_logits.float()
# 2) If you want to keep the "classical" T scaling, apply it first
if kd_temperature != 1.0:
student_topk_logits = student_topk_logits / kd_temperature
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
# (They differ by +some_constant from real logits, but in z-score
# that constant is subtracted out anyway.)
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
# 4) Z-score teacher and student
# If target_mask is 2D, expand to 3D for the K dimension
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
teacher_z = zscore_standardize(
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
)
student_z = zscore_standardize(
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
)
# 5) Convert to log-probs for KL
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
# 6) Restrict to valid tokens if needed
valid_mask = target_mask.bool() # shape [B, seq_len, K]
teacher_probs_z = teacher_logprobs_z.exp()
teacher_probs_z = teacher_probs_z[valid_mask]
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
student_logprobs_z = student_logprobs_z[valid_mask]
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
kd_loss = kd_loss_per_token.sum()
# 8) If using classical KD scaling by T^2
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
# kd_loss = kd_loss * (zscore_base_temp**2)
# 9) Normalize
if num_items_in_batch is not None and num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss

View File

@@ -19,7 +19,6 @@ KD trainer
from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainer(AxolotlTrainer):
@@ -46,6 +45,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
inputs,
return_outputs=False,
num_items_in_batch=None,
shift_targets=False,
):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -69,30 +69,25 @@ class AxolotlKDTrainer(AxolotlTrainer):
# FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
if self.args.kd_zscore_base_temp:
loss_kd = topk_kd_loss_with_zscore(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
kd_temperature=self.args.kd_temperature,
zscore_base_temp=self.args.kd_zscore_base_temp,
num_items_in_batch=num_items_in_batch,
)
if shift_targets:
shift_logits = student_logits[..., :-1, :].contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
else:
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
)
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs.contiguous()
target_token_ids_for_loss = target_token_ids.contiguous()
target_mask_for_loss = target_mask.contiguous()
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
)
if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha

View File

@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper,
load_prepare_datasets,

View File

@@ -18,13 +18,10 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
) -> Dict[str, List]:
res = tokenizer(
examples[text_column],
examples["text"],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
@@ -199,12 +196,7 @@ def wrap_pretraining_dataset(
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
)
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)

View File

@@ -115,7 +115,7 @@ def drop_long_rl_seq(
raise ValueError("Unknown RL type")
def load_prepare_preference_datasets(cfg):
def load_prepare_dpo_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
for i, ds_cfg in enumerate(dataset_cfgs):

View File

@@ -1057,7 +1057,7 @@ class ModelLoader:
)
if (
hasattr(self.model, "get_input_embeddings")
and self.model.get_input_embeddings().num_embeddings != embeddings_len
and self.model.get_input_embeddings().num_embeddings < embeddings_len
):
resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None:

View File

@@ -279,7 +279,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
@@ -311,7 +310,8 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
"""
labels = sample["labels"]
if not labels:
return True
# Edge case: if labels is empty, decide if you want to keep or drop
return True # or False
# Check if single example or batch
# If first element is an int, we assume a single example

View File

@@ -33,7 +33,6 @@ def min_cfg(temp_dir):
"dataloader_prefetch_factor": 8,
"dataloader_num_workers": 4,
"dataloader_pin_memory": True,
# "dataset_prepared_path": str(Path(temp_dir) / "last_run_prepared"),
"datasets": [
{
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",

View File

@@ -4,8 +4,7 @@ E2E tests for llama pretrain
import logging
import os
import pytest
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
@@ -13,22 +12,19 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPretrainLlama:
class TestPretrainLlama(unittest.TestCase):
"""
Test case for Llama models w pretraining
"""
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_pretrain(self, temp_dir, sample_packing):
@with_temp_dir
def test_pretrain_w_sample_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
@@ -36,7 +32,7 @@ class TestPretrainLlama:
"tokenizer_type": "LlamaTokenizer",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": sample_packing,
"sample_packing": True,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",

View File

@@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault
@@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
train_dataset, _ = load_prepare_preference_datasets(cfg)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
@@ -329,7 +329,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
train_dataset, _ = load_prepare_preference_datasets(cfg)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features

View File

@@ -12,7 +12,7 @@ from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
@@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
"""Verify that loading with deduplication removes duplicates."""
# Load the dataset using the deduplication setting
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
"""Verify that loading without deduplication retains duplicates."""
self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
# Verify that the dataset retains duplicates
assert (