From d3c2b7ce9dfc9a8658f1933045a724d4261b93fa Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 30 Dec 2024 20:10:47 -0500 Subject: [PATCH] kd sample packing --- src/axolotl/core/trainer_builder.py | 10 ++- src/axolotl/integrations/kd/chat_template.py | 5 ++ src/axolotl/integrations/kd/collator.py | 67 ++++++++++++++++++++ 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index e78b16e65..8be180c95 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -824,9 +824,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): kwargs.pop("pad_to_multiple_of", None) kwargs.pop("padding", None) elif self.cfg.kd_trainer: - from axolotl.integrations.kd.collator import DataCollatorForKD + from axolotl.integrations.kd.collator import ( + DataCollatorForKD, + KDBatchSamplerDataCollatorForSeq2Seq, + ) - collator = DataCollatorForKD + if self.cfg.sample_packing: + collator = KDBatchSamplerDataCollatorForSeq2Seq + else: + collator = DataCollatorForKD else: collator = DataCollatorForSeq2Seq diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index 1b031c490..67402a033 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -120,6 +120,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_logprobs.append(position_logprobs_scaled) target_token_ids.append(position_token_ids) + # since we started at index 1 for causal, we need one more padding token + target_logprobs.append([-float("inf")] * top_k) + target_token_ids.append(list(range(top_k))) + target_mask.append([0] * top_k) + # Update sample with transformed logprobs sample["target_logprobs"] = target_logprobs sample["target_token_ids"] = target_token_ids diff --git a/src/axolotl/integrations/kd/collator.py b/src/axolotl/integrations/kd/collator.py index b63f8f971..de63869c7 100644 --- a/src/axolotl/integrations/kd/collator.py +++ b/src/axolotl/integrations/kd/collator.py @@ -186,3 +186,70 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): features["decoder_input_ids"] = decoder_input_ids return features + + +class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD): + """ + Collator for multipack (batch of sub-batches) specifically for KD. + Adapts DataCollatorForKD so it can pack multiple sequences in a single batch item. + """ + + def __call__(self, features, return_tensors=None): + """ + Expects that `features` could be either: + - a single list of dicts, OR + - a list of lists of dicts (the "sub-batches" to be packed). + """ + # 1) If we are *not* dealing with multiple sequences per batch element, + # just pass straight to parent. + if not isinstance(features[0], list): + return super().__call__(features, return_tensors=return_tensors) + + # 2) Otherwise, we *are* dealing with multiple sequences in each batch item. + # We want to produce a single "merged" feature dict for each sub-batch. + out_features = [{} for _ in features] + + for i, sub_features in enumerate(features): + # sub_features is a list of dicts, each dict = one sequence’s features + # We'll merge them into out_features[i]. + # + # NOTE: You can customize how you combine fields as needed (e.g. summation + # or offset for attention_mask). Below is a straightforward concatenation/extension. + + for field_name in sub_features[0].keys(): + # Some fields you might want to skip or treat specially: + if field_name == "length": + continue + + # If it’s a KD field that’s a list-of-lists (e.g. target_logprobs), + # you typically just want to flatten them by extending. + if field_name in ["target_logprobs", "target_token_ids", "target_mask"]: + combined = [] + for feat in sub_features: + combined.extend(feat[field_name]) + out_features[i][field_name] = combined + + elif field_name == "attention_mask": + # Here we apply the (j+1) factor to differentiate each sub-sample + # within this merged batch item. + arrays = [] + for j, feat in enumerate(sub_features): + if field_name in feat: + arrays.append((j + 1) * np.array(feat[field_name])) + out_features[i][field_name] = np.concatenate(arrays) + else: + # By default, just concatenate them if they are arrays + # or extend them if they are lists. + # For example, input_ids or labels are often arrays. + arrays = [] + for feat in sub_features: + if field_name in feat: + arr = np.array(feat[field_name]) + arrays.append(arr) + out_features[i][field_name] = np.concatenate(arrays) + + # 3) Now call the parent collator, which will do: + # - padding of labels/position_ids + # - KD-specific padding for target_logprobs, target_token_ids, etc. + # - final conversion to return_tensors + return super().__call__(out_features, return_tensors=return_tensors)