handle padding/collation for KD datasets

Wing Lian
2024-12-18 18:07:27 -05:00
parent 21caaaa2e9
commit e2aba41939
5 changed files with 161 additions and 7 deletions

View File

@@ -63,6 +63,7 @@ from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
    BatchSamplerDataCollatorForSeq2Seq,
    DataCollatorForKD,
    DataCollatorForSeq2Seq,
    MambaDataCollator,
    V2BatchSamplerDataCollatorForSeq2Seq,
@@ -772,6 +773,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            Union[
                V2BatchSamplerDataCollatorForSeq2Seq,
                BatchSamplerDataCollatorForSeq2Seq,
                DataCollatorForKD,
                DataCollatorForSeq2Seq,
                DataCollatorWithFlattening,
                RewardDataCollatorWithPadding,
@@ -802,6 +804,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                collator_args.pop(0)
                kwargs.pop("pad_to_multiple_of", None)
                kwargs.pop("padding", None)
            elif self.cfg.trainer == "kd":
                collator = DataCollatorForKD
            else:
                collator = DataCollatorForSeq2Seq
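
The builder change above routes KD runs to the new collator and leaves every other trainer on the existing path. A minimal sketch of that selection, with a hypothetical `pick_collator` helper standing in for the branch inside `build_collator` and `cfg` standing in for the builder's config object (the multipack/flattening/reward branches are omitted):

```python
from axolotl.utils.collators import DataCollatorForKD, DataCollatorForSeq2Seq


def pick_collator(cfg):
    # Simplified view of the branch added to build_collator:
    # a "kd" trainer gets the KD-aware collator, everything else
    # falls back to the plain seq2seq collator.
    if getattr(cfg, "trainer", None) == "kd":
        return DataCollatorForKD
    return DataCollatorForSeq2Seq
```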

View File

@@ -56,7 +56,7 @@ class TokenizedPromptDataset(Dataset):
            dataset = dataset.filter(
                self.prompt_tokenizer.filter_rows,
                num_proc=num_proc,
                desc="Filtering Rows",
                desc="Strategy Filtering Rows",
            )
        return dataset.map(
            self.prompt_tokenizer.tokenize_prompt,

View File

@@ -479,12 +479,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
        self.logprobs_field = logprobs_field
        self.temperature = temperature

        # remove rows where the logprob field is not available
        self.filter_rows = (
            lambda row: self.logprobs_field in row
            and row[self.logprobs_field] is not None
        )

        super().__init__(
            prompter,
            tokenizer,
@@ -541,7 +535,9 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
        return sample

    def tokenize_prompt(self, prompt):
        logprobs = prompt.pop(self.logprobs_field)
        tokenized_prompt = super().tokenize_prompt(prompt)
        tokenized_prompt[self.logprobs_field] = logprobs

        return self.transform_logprobs(tokenized_prompt)
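
The `tokenize_prompt` change is a pop/reattach pattern: the teacher-logprob column is removed before the base chat-template tokenization runs, then put back so `transform_logprobs` can align it with the tokenized sample. A standalone illustration of the same pattern (the `"logprobs"` field name and the `base_tokenize` callable are placeholders, not the strategy's actual defaults):

```python
def tokenize_with_kd(row, base_tokenize, logprobs_field="logprobs"):
    # Hide the KD field from the base tokenizer, which only expects chat turns.
    logprobs = row.pop(logprobs_field)
    tokenized = base_tokenize(row)
    # Reattach so downstream KD processing can pair logprobs with token positions.
    tokenized[logprobs_field] = logprobs
    return tokenized
```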

View File

@@ -7,4 +7,5 @@ from .batching import ( # noqa: F401
    PretrainingBatchSamplerDataCollatorForSeq2Seq,
    V2BatchSamplerDataCollatorForSeq2Seq,
)
from .kd import DataCollatorForKD  # noqa: F401
from .mamba import MambaDataCollator  # noqa: F401

View File

@@ -0,0 +1,153 @@
"""
DataCollator for axolotl to handle KD fields
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
import numpy as np
import torch
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.utils.collators.batching import DataCollatorForSeq2Seq
@dataclass
class DataCollatorForKD(DataCollatorForSeq2Seq):
"""
Data collator for KD, including handling KD-specific fields.
"""
tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
# Extract labels and position_ids first (as in original code)
for feature_name, pad_token_id in [
("labels", self.label_pad_token_id),
("position_ids", self.position_pad_token_id),
]:
if feature_name in features[0]:
feat = [f[feature_name] for f in features]
max_len = max(len(x) for x in feat)
if self.pad_to_multiple_of is not None:
max_len = (
(max_len + self.pad_to_multiple_of - 1)
// self.pad_to_multiple_of
) * self.pad_to_multiple_of
padding_side = self.tokenizer.padding_side
for f in features: # pylint: disable=invalid-name
remainder = [pad_token_id] * (max_len - len(f[feature_name]))
if isinstance(f[feature_name], list):
f[feature_name] = (
f[feature_name] + remainder
if padding_side == "right"
else remainder + f[feature_name]
)
else:
# If they are numpy arrays
if padding_side == "right":
f[feature_name] = np.concatenate(
[f[feature_name], remainder]
).astype(np.int64)
else:
f[feature_name] = np.concatenate(
[remainder, f[feature_name]]
).astype(np.int64)
# Handle target_logprobs and target_token_ids manually
target_logprobs_list = []
target_token_ids_list = []
has_teacher_data = ("target_logprobs" in features[0]) and (
"target_token_ids" in features[0]
)
if has_teacher_data:
# Extract these fields
for f in features: # pylint: disable=invalid-name
target_logprobs_list.append(f.pop("target_logprobs"))
target_token_ids_list.append(f.pop("target_token_ids"))
# Determine max lengths to pad
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
# Pad target_logprobs and target_token_ids
padded_target_logprobs = []
padded_target_token_ids = []
for t_logprobs, t_ids in zip(target_logprobs_list, target_token_ids_list):
# Pad seq dimension
t_logprobs_padded = []
t_ids_padded = []
for i in range( # pylint: disable=consider-using-enumerate
len(t_logprobs)
):
lp = t_logprobs[i] # pylint: disable=invalid-name
ids = t_ids[i]
# Pad K dimension
lp_len = len(lp)
if lp_len < max_k:
lp = lp + [-float("inf")] * ( # pylint: disable=invalid-name
max_k - lp_len
) # or some pad value that won't break exp()
ids = ids + [0] * (max_k - lp_len)
t_logprobs_padded.append(lp)
t_ids_padded.append(ids)
# If sequence is shorter than max_teacher_seq_len
seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
if seq_len_diff > 0:
t_logprobs_padded.extend(
[[-float("inf")] * max_k for _ in range(seq_len_diff)]
)
t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
padded_target_logprobs.append(t_logprobs_padded)
padded_target_token_ids.append(t_ids_padded)
# Convert to tensors
padded_target_logprobs = torch.tensor(
padded_target_logprobs, dtype=torch.float
)
# We can store token_ids as long tensor
padded_target_token_ids = torch.tensor(
padded_target_token_ids, dtype=torch.long
)
# Now pad using tokenizer for the remaining fields (input_ids, attention_mask, etc.)
features = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
# Add back the teacher data if it exists
if has_teacher_data:
features["target_logprobs"] = padded_target_logprobs
features["target_token_ids"] = padded_target_token_ids
# Prepare decoder_input_ids if applicable
if (
"labels" in features
and self.model is not None
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
):
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
labels=features["labels"]
)
features["decoder_input_ids"] = decoder_input_ids
return features
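
As a rough sanity check on what the new collator produces, here is a self-contained sketch (the tokenizer choice and all feature values are invented for illustration; only the `target_logprobs`/`target_token_ids` field names come from the diff):

```python
from transformers import AutoTokenizer

from axolotl.utils.collators import DataCollatorForKD

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer; gpt2 needs a pad token set
tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForKD(tokenizer=tokenizer, padding=True)

features = [
    {
        "input_ids": [1, 2, 3],
        "attention_mask": [1, 1, 1],
        "labels": [1, 2, 3],
        # per-position top-K teacher data: 3 positions, K=2
        "target_logprobs": [[-0.1, -2.3], [-0.2, -1.9], [-0.3, -1.5]],
        "target_token_ids": [[5, 7], [6, 8], [9, 4]],
    },
    {
        "input_ids": [4, 5],
        "attention_mask": [1, 1],
        "labels": [4, 5],
        # shorter sequence and smaller K, so both dimensions need padding
        "target_logprobs": [[-0.4], [-0.6]],
        "target_token_ids": [[2], [3]],
    },
]

batch = collator(features)
# Teacher tensors come back padded to [batch, max_seq_len, max_K]:
# missing K slots get -inf logprobs (token id 0), missing positions get all -inf rows.
print(batch["target_logprobs"].shape)   # torch.Size([2, 3, 2])
print(batch["target_token_ids"].shape)  # torch.Size([2, 3, 2])
```

The `-inf` pad value is deliberate: a KD loss typically exponentiates the teacher logprobs, so padded slots carry zero probability mass, which is what the in-code comment about not breaking `exp()` is pointing at.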