handle padding/collation for KD datasets
This commit is contained in:
@@ -63,6 +63,7 @@ from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
|||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.collators import (
|
from axolotl.utils.collators import (
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
DataCollatorForKD,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
@@ -772,6 +773,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
Union[
|
Union[
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
DataCollatorForKD,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
DataCollatorWithFlattening,
|
DataCollatorWithFlattening,
|
||||||
RewardDataCollatorWithPadding,
|
RewardDataCollatorWithPadding,
|
||||||
@@ -802,6 +804,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator_args.pop(0)
|
collator_args.pop(0)
|
||||||
kwargs.pop("pad_to_multiple_of", None)
|
kwargs.pop("pad_to_multiple_of", None)
|
||||||
kwargs.pop("padding", None)
|
kwargs.pop("padding", None)
|
||||||
|
elif self.cfg.trainer == "kd":
|
||||||
|
collator = DataCollatorForKD
|
||||||
else:
|
else:
|
||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
self.prompt_tokenizer.filter_rows,
|
self.prompt_tokenizer.filter_rows,
|
||||||
num_proc=num_proc,
|
num_proc=num_proc,
|
||||||
desc="Filtering Rows",
|
desc="Strategy Filtering Rows",
|
||||||
)
|
)
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
|
|||||||
@@ -479,12 +479,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
self.logprobs_field = logprobs_field
|
self.logprobs_field = logprobs_field
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
|
||||||
# remove rows where the logprob field is not available
|
|
||||||
self.filter_rows = (
|
|
||||||
lambda row: self.logprobs_field in row
|
|
||||||
and row[self.logprobs_field] is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
prompter,
|
prompter,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -541,7 +535,9 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
|
logprobs = prompt.pop(self.logprobs_field)
|
||||||
tokenized_prompt = super().tokenize_prompt(prompt)
|
tokenized_prompt = super().tokenize_prompt(prompt)
|
||||||
|
tokenized_prompt[self.logprobs_field] = logprobs
|
||||||
return self.transform_logprobs(tokenized_prompt)
|
return self.transform_logprobs(tokenized_prompt)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,4 +7,5 @@ from .batching import ( # noqa: F401
|
|||||||
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
|
from .kd import DataCollatorForKD # noqa: F401
|
||||||
from .mamba import MambaDataCollator # noqa: F401
|
from .mamba import MambaDataCollator # noqa: F401
|
||||||
|
|||||||
153
src/axolotl/utils/collators/kd.py
Normal file
153
src/axolotl/utils/collators/kd.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""
|
||||||
|
DataCollator for axolotl to handle KD fields
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
from axolotl.utils.collators.batching import DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DataCollatorForKD(DataCollatorForSeq2Seq):
    """
    Data collator for KD, including handling KD-specific fields.

    On top of the usual tokenizer padding, this collator:

    * pads ``labels`` and ``position_ids`` to a common length, honouring
      ``pad_to_multiple_of`` and the tokenizer's ``padding_side``;
    * pads the optional teacher fields ``target_logprobs`` and
      ``target_token_ids`` (nested ``[seq][k]`` lists per sample) into dense
      tensors and re-attaches them to the batch.
    """

    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    position_pad_token_id: int = 0
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        """
        Collate a list of per-sample feature dicts into one padded batch.

        Args:
            features: list of dicts, one per sample. The dicts are modified in
                place: ``labels``/``position_ids`` are padded and the teacher
                fields are popped before the tokenizer pads the rest.
            return_tensors: tensor type forwarded to ``tokenizer.pad``;
                defaults to ``self.return_tensors``.

        Returns:
            The batch returned by ``tokenizer.pad`` with ``target_logprobs``
            (float) and ``target_token_ids`` (long) added when the first
            sample carries both fields, and ``decoder_input_ids`` added when
            the model can derive them from the labels.
        """
        if return_tensors is None:
            return_tensors = self.return_tensors

        # labels / position_ids must be padded by hand before tokenizer.pad
        # is asked to turn the batch into tensors.
        self._pad_label_like_features(features)

        # Teacher data is only handled when BOTH fields are present; presence
        # is decided from the first sample, as for labels/position_ids.
        has_teacher_data = (
            "target_logprobs" in features[0] and "target_token_ids" in features[0]
        )

        padded_target_logprobs = None
        padded_target_token_ids = None
        if has_teacher_data:
            # Pop the nested teacher fields so tokenizer.pad never sees them.
            target_logprobs_list = []
            target_token_ids_list = []
            for feature in features:
                target_logprobs_list.append(feature.pop("target_logprobs"))
                target_token_ids_list.append(feature.pop("target_token_ids"))

            padded_target_logprobs, padded_target_token_ids = self._pad_teacher_data(
                target_logprobs_list, target_token_ids_list
            )

        # Now pad using tokenizer for the remaining fields
        # (input_ids, attention_mask, etc.)
        features = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        # Add back the teacher data if it exists
        if has_teacher_data:
            features["target_logprobs"] = padded_target_logprobs
            features["target_token_ids"] = padded_target_token_ids

        # Prepare decoder_input_ids if applicable
        if (
            "labels" in features
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            features["decoder_input_ids"] = (
                self.model.prepare_decoder_input_ids_from_labels(
                    labels=features["labels"]
                )
            )

        return features

    def _pad_label_like_features(self, features):
        """Pad ``labels`` and ``position_ids`` in place to one common length."""
        pad_on_right = self.tokenizer.padding_side == "right"

        for feature_name, pad_token_id in (
            ("labels", self.label_pad_token_id),
            ("position_ids", self.position_pad_token_id),
        ):
            if feature_name not in features[0]:
                continue

            max_len = max(len(feature[feature_name]) for feature in features)
            if self.pad_to_multiple_of is not None:
                # Round up to the next multiple of pad_to_multiple_of.
                max_len = (
                    (max_len + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                ) * self.pad_to_multiple_of

            for feature in features:
                value = feature[feature_name]
                remainder = [pad_token_id] * (max_len - len(value))
                if isinstance(value, list):
                    feature[feature_name] = (
                        value + remainder if pad_on_right else remainder + value
                    )
                else:
                    # Anything that is not a list is treated as a numpy array.
                    parts = [value, remainder] if pad_on_right else [remainder, value]
                    feature[feature_name] = np.concatenate(parts).astype(np.int64)

    @staticmethod
    def _pad_teacher_data(target_logprobs_list, target_token_ids_list):
        """
        Pad nested teacher log-probs / token ids into dense tensors.

        Log-probs are padded with ``-inf`` (so ``exp()`` of a padded entry is
        0) and token ids with ``0``, along both the K dimension and the
        sequence dimension.

        Returns:
            Tuple ``(logprobs, token_ids)`` of a float tensor and a long
            tensor, each built from ``[batch][seq][k]`` nested lists.
        """
        # default=0 keeps max() from raising ValueError when every teacher
        # sequence in the batch is empty.
        max_teacher_seq_len = max(
            (len(seq) for seq in target_logprobs_list), default=0
        )
        max_k = max(
            (len(seq_k) for seq in target_logprobs_list for seq_k in seq),
            default=0,
        )

        # NOTE(review): teacher data is always padded on the right, even when
        # tokenizer.padding_side is "left" and labels are left-padded --
        # confirm that downstream code expects this alignment.
        padded_target_logprobs = []
        padded_target_token_ids = []
        for t_logprobs, t_ids in zip(target_logprobs_list, target_token_ids_list):
            t_logprobs_padded = []
            t_ids_padded = []
            for i, logprob_row in enumerate(t_logprobs):
                id_row = t_ids[i]
                # Pad K dimension. Each row is padded by its OWN length
                # deficit so a token-id row whose length differs from its
                # logprob row still comes out max_k wide. list() also
                # tolerates non-list sequences such as tuples.
                t_logprobs_padded.append(
                    list(logprob_row)
                    + [-float("inf")] * (max_k - len(logprob_row))
                )
                t_ids_padded.append(list(id_row) + [0] * (max_k - len(id_row)))

            # Pad seq dimension for samples shorter than the longest one.
            seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
            if seq_len_diff > 0:
                t_logprobs_padded.extend(
                    [[-float("inf")] * max_k for _ in range(seq_len_diff)]
                )
                t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])

            padded_target_logprobs.append(t_logprobs_padded)
            padded_target_token_ids.append(t_ids_padded)

        # Convert to tensors; token ids are stored as a long tensor.
        return (
            torch.tensor(padded_target_logprobs, dtype=torch.float),
            torch.tensor(padded_target_token_ids, dtype=torch.long),
        )
|
||||||
Reference in New Issue
Block a user