From e2aba419390abb697ad5fcad788a030e08ab24b9 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 18 Dec 2024 18:07:27 -0500 Subject: [PATCH] handle padding/collation for KD datasets --- src/axolotl/core/trainer_builder.py | 4 + src/axolotl/datasets.py | 2 +- .../prompt_strategies/chat_template.py | 8 +- src/axolotl/utils/collators/__init__.py | 1 + src/axolotl/utils/collators/kd.py | 153 ++++++++++++++++++ 5 files changed, 161 insertions(+), 7 deletions(-) create mode 100644 src/axolotl/utils/collators/kd.py diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 828929dc7..694bed808 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -63,6 +63,7 @@ from axolotl.utils.callbacks.profiler import PytorchProfilerCallback from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, + DataCollatorForKD, DataCollatorForSeq2Seq, MambaDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, @@ -772,6 +773,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): Union[ V2BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq, + DataCollatorForKD, DataCollatorForSeq2Seq, DataCollatorWithFlattening, RewardDataCollatorWithPadding, @@ -802,6 +804,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator_args.pop(0) kwargs.pop("pad_to_multiple_of", None) kwargs.pop("padding", None) + elif self.cfg.trainer == "kd": + collator = DataCollatorForKD else: collator = DataCollatorForSeq2Seq diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 460e8f1bd..005cc41bc 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -56,7 +56,7 @@ class TokenizedPromptDataset(Dataset): dataset = dataset.filter( self.prompt_tokenizer.filter_rows, num_proc=num_proc, - desc="Filtering Rows", + desc="Strategy Filtering Rows", ) return dataset.map( self.prompt_tokenizer.tokenize_prompt, diff --git 
def tokenize_prompt(self, prompt):
    """Tokenize a prompt while shielding the teacher-logprob column.

    The base-class tokenizer only understands chat-template fields, so the
    KD logprobs are detached before delegating and reattached afterwards,
    then aligned to the tokenized sequence via ``transform_logprobs``.
    """
    # Detach teacher logprobs so super() only sees chat-template fields.
    teacher_logprobs = prompt.pop(self.logprobs_field)
    tokenized = super().tokenize_prompt(prompt)
    # Reattach before aligning the logprobs to the tokenized sequence.
    tokenized[self.logprobs_field] = teacher_logprobs
    return self.transform_logprobs(tokenized)
@dataclass
class DataCollatorForKD(DataCollatorForSeq2Seq):
    """
    Data collator for knowledge-distillation (KD) batches.

    On top of the usual seq2seq handling — manual padding of ``labels`` /
    ``position_ids`` and ``tokenizer.pad`` for the model inputs — this
    collator pads the teacher outputs (``target_logprobs`` and
    ``target_token_ids``), which are ragged along both the sequence
    dimension and the top-K dimension, into rectangular tensors.
    """

    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    position_pad_token_id: int = 0
    return_tensors: str = "pt"

    def _pad_int_feature(self, features, feature_name, pad_token_id):
        """Pad one integer sequence column (labels/position_ids) in place.

        Handles both plain lists and numpy arrays, honoring the tokenizer's
        padding side and ``pad_to_multiple_of``.
        """
        max_len = max(len(f[feature_name]) for f in features)
        if self.pad_to_multiple_of is not None:
            multiple = self.pad_to_multiple_of
            # round up to the next multiple
            max_len = ((max_len + multiple - 1) // multiple) * multiple

        # hoisted: padding_side is invariant across the batch
        padding_side = self.tokenizer.padding_side
        for feature in features:
            value = feature[feature_name]
            remainder = [pad_token_id] * (max_len - len(value))
            if isinstance(value, list):
                feature[feature_name] = (
                    value + remainder
                    if padding_side == "right"
                    else remainder + value
                )
            elif padding_side == "right":
                feature[feature_name] = np.concatenate([value, remainder]).astype(
                    np.int64
                )
            else:
                feature[feature_name] = np.concatenate([remainder, value]).astype(
                    np.int64
                )

    @staticmethod
    def _pad_teacher_outputs(logprobs_list, token_ids_list):
        """Pad ragged teacher top-K outputs into rectangular tensors.

        Args:
            logprobs_list: per-example list of per-position top-K logprobs.
            token_ids_list: per-example list of per-position top-K token ids,
                assumed to be row-aligned with ``logprobs_list``.

        Returns:
            Tuple of ``(logprobs, token_ids)`` tensors of shape
            ``(batch, max_seq_len, max_k)``; float and long respectively.
            Missing logprob slots are filled with ``-inf`` (zero probability
            mass after ``exp``), missing token ids with 0.
        """
        # default=0 guards against a degenerate empty batch/sequence,
        # where bare max() would raise ValueError
        max_seq_len = max((len(seq) for seq in logprobs_list), default=0)
        max_k = max(
            (len(top_k) for seq in logprobs_list for top_k in seq), default=0
        )

        padded_logprobs = []
        padded_token_ids = []
        for seq_logprobs, seq_token_ids in zip(logprobs_list, token_ids_list):
            logprob_rows = []
            token_id_rows = []
            for row_logprobs, row_token_ids in zip(seq_logprobs, seq_token_ids):
                pad_k = max_k - len(row_logprobs)
                if pad_k > 0:
                    # -inf pads exp() to 0 so padded slots carry no mass;
                    # list() normalizes array-like rows before concatenation
                    row_logprobs = list(row_logprobs) + [-float("inf")] * pad_k
                    row_token_ids = list(row_token_ids) + [0] * pad_k
                logprob_rows.append(row_logprobs)
                token_id_rows.append(row_token_ids)

            # pad the sequence dimension up to the batch maximum
            for _ in range(max_seq_len - len(logprob_rows)):
                logprob_rows.append([-float("inf")] * max_k)
                token_id_rows.append([0] * max_k)

            padded_logprobs.append(logprob_rows)
            padded_token_ids.append(token_id_rows)

        return (
            torch.tensor(padded_logprobs, dtype=torch.float),
            torch.tensor(padded_token_ids, dtype=torch.long),
        )

    def __call__(self, features, return_tensors=None):
        """Collate ``features`` into a batch, padding student and teacher fields."""
        if return_tensors is None:
            return_tensors = self.return_tensors

        # labels / position_ids are padded manually because tokenizer.pad
        # only understands the model input names
        for feature_name, pad_token_id in [
            ("labels", self.label_pad_token_id),
            ("position_ids", self.position_pad_token_id),
        ]:
            if feature_name in features[0]:
                self._pad_int_feature(features, feature_name, pad_token_id)

        has_teacher_data = (
            "target_logprobs" in features[0] and "target_token_ids" in features[0]
        )
        if has_teacher_data:
            # pop the teacher columns so tokenizer.pad never sees ragged data
            padded_target_logprobs, padded_target_token_ids = (
                self._pad_teacher_outputs(
                    [f.pop("target_logprobs") for f in features],
                    [f.pop("target_token_ids") for f in features],
                )
            )

        # tokenizer handles the remaining fields (input_ids, attention_mask, ...)
        features = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        if has_teacher_data:
            # NOTE(review): teacher tensors are padded to the teacher's own max
            # sequence length, which may differ from the tokenizer-padded
            # student length — confirm the KD loss aligns these downstream.
            features["target_logprobs"] = padded_target_logprobs
            features["target_token_ids"] = padded_target_token_ids

        # Prepare decoder_input_ids for encoder-decoder models, if supported
        if (
            "labels" in features
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            features["decoder_input_ids"] = (
                self.model.prepare_decoder_input_ids_from_labels(
                    labels=features["labels"]
                )
            )

        return features