kd sample packing
This commit is contained in:
@@ -824,9 +824,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
kwargs.pop("pad_to_multiple_of", None)
|
||||
kwargs.pop("padding", None)
|
||||
elif self.cfg.kd_trainer:
|
||||
from axolotl.integrations.kd.collator import DataCollatorForKD
|
||||
from axolotl.integrations.kd.collator import (
|
||||
DataCollatorForKD,
|
||||
KDBatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
|
||||
collator = DataCollatorForKD
|
||||
if self.cfg.sample_packing:
|
||||
collator = KDBatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
collator = DataCollatorForKD
|
||||
else:
|
||||
collator = DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@@ -120,6 +120,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
target_logprobs.append(position_logprobs_scaled)
|
||||
target_token_ids.append(position_token_ids)
|
||||
|
||||
# since we started at index 1 for causal, we need one more padding token
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
# Update sample with transformed logprobs
|
||||
sample["target_logprobs"] = target_logprobs
|
||||
sample["target_token_ids"] = target_token_ids
|
||||
|
||||
@@ -186,3 +186,70 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
return features
|
||||
|
||||
|
||||
class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
|
||||
"""
|
||||
Collator for multipack (batch of sub-batches) specifically for KD.
|
||||
Adapts DataCollatorForKD so it can pack multiple sequences in a single batch item.
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
"""
|
||||
Expects that `features` could be either:
|
||||
- a single list of dicts, OR
|
||||
- a list of lists of dicts (the "sub-batches" to be packed).
|
||||
"""
|
||||
# 1) If we are *not* dealing with multiple sequences per batch element,
|
||||
# just pass straight to parent.
|
||||
if not isinstance(features[0], list):
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
# 2) Otherwise, we *are* dealing with multiple sequences in each batch item.
|
||||
# We want to produce a single "merged" feature dict for each sub-batch.
|
||||
out_features = [{} for _ in features]
|
||||
|
||||
for i, sub_features in enumerate(features):
|
||||
# sub_features is a list of dicts, each dict = one sequence’s features
|
||||
# We'll merge them into out_features[i].
|
||||
#
|
||||
# NOTE: You can customize how you combine fields as needed (e.g. summation
|
||||
# or offset for attention_mask). Below is a straightforward concatenation/extension.
|
||||
|
||||
for field_name in sub_features[0].keys():
|
||||
# Some fields you might want to skip or treat specially:
|
||||
if field_name == "length":
|
||||
continue
|
||||
|
||||
# If it’s a KD field that’s a list-of-lists (e.g. target_logprobs),
|
||||
# you typically just want to flatten them by extending.
|
||||
if field_name in ["target_logprobs", "target_token_ids", "target_mask"]:
|
||||
combined = []
|
||||
for feat in sub_features:
|
||||
combined.extend(feat[field_name])
|
||||
out_features[i][field_name] = combined
|
||||
|
||||
elif field_name == "attention_mask":
|
||||
# Here we apply the (j+1) factor to differentiate each sub-sample
|
||||
# within this merged batch item.
|
||||
arrays = []
|
||||
for j, feat in enumerate(sub_features):
|
||||
if field_name in feat:
|
||||
arrays.append((j + 1) * np.array(feat[field_name]))
|
||||
out_features[i][field_name] = np.concatenate(arrays)
|
||||
else:
|
||||
# By default, just concatenate them if they are arrays
|
||||
# or extend them if they are lists.
|
||||
# For example, input_ids or labels are often arrays.
|
||||
arrays = []
|
||||
for feat in sub_features:
|
||||
if field_name in feat:
|
||||
arr = np.array(feat[field_name])
|
||||
arrays.append(arr)
|
||||
out_features[i][field_name] = np.concatenate(arrays)
|
||||
|
||||
# 3) Now call the parent collator, which will do:
|
||||
# - padding of labels/position_ids
|
||||
# - KD-specific padding for target_logprobs, target_token_ids, etc.
|
||||
# - final conversion to return_tensors
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
Reference in New Issue
Block a user