kd sample packing
This commit is contained in:
@@ -824,9 +824,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
kwargs.pop("pad_to_multiple_of", None)
|
kwargs.pop("pad_to_multiple_of", None)
|
||||||
kwargs.pop("padding", None)
|
kwargs.pop("padding", None)
|
||||||
elif self.cfg.kd_trainer:
|
elif self.cfg.kd_trainer:
|
||||||
from axolotl.integrations.kd.collator import DataCollatorForKD
|
from axolotl.integrations.kd.collator import (
|
||||||
|
DataCollatorForKD,
|
||||||
|
KDBatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
)
|
||||||
|
|
||||||
collator = DataCollatorForKD
|
if self.cfg.sample_packing:
|
||||||
|
collator = KDBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
else:
|
||||||
|
collator = DataCollatorForKD
|
||||||
else:
|
else:
|
||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
|||||||
@@ -120,6 +120,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
target_logprobs.append(position_logprobs_scaled)
|
target_logprobs.append(position_logprobs_scaled)
|
||||||
target_token_ids.append(position_token_ids)
|
target_token_ids.append(position_token_ids)
|
||||||
|
|
||||||
|
# since we started at index 1 for causal, we need one more padding token
|
||||||
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
|
target_token_ids.append(list(range(top_k)))
|
||||||
|
target_mask.append([0] * top_k)
|
||||||
|
|
||||||
# Update sample with transformed logprobs
|
# Update sample with transformed logprobs
|
||||||
sample["target_logprobs"] = target_logprobs
|
sample["target_logprobs"] = target_logprobs
|
||||||
sample["target_token_ids"] = target_token_ids
|
sample["target_token_ids"] = target_token_ids
|
||||||
|
|||||||
@@ -186,3 +186,70 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
|||||||
features["decoder_input_ids"] = decoder_input_ids
|
features["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
|
||||||
|
"""
|
||||||
|
Collator for multipack (batch of sub-batches) specifically for KD.
|
||||||
|
Adapts DataCollatorForKD so it can pack multiple sequences in a single batch item.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, features, return_tensors=None):
|
||||||
|
"""
|
||||||
|
Expects that `features` could be either:
|
||||||
|
- a single list of dicts, OR
|
||||||
|
- a list of lists of dicts (the "sub-batches" to be packed).
|
||||||
|
"""
|
||||||
|
# 1) If we are *not* dealing with multiple sequences per batch element,
|
||||||
|
# just pass straight to parent.
|
||||||
|
if not isinstance(features[0], list):
|
||||||
|
return super().__call__(features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
# 2) Otherwise, we *are* dealing with multiple sequences in each batch item.
|
||||||
|
# We want to produce a single "merged" feature dict for each sub-batch.
|
||||||
|
out_features = [{} for _ in features]
|
||||||
|
|
||||||
|
for i, sub_features in enumerate(features):
|
||||||
|
# sub_features is a list of dicts, each dict = one sequence’s features
|
||||||
|
# We'll merge them into out_features[i].
|
||||||
|
#
|
||||||
|
# NOTE: You can customize how you combine fields as needed (e.g. summation
|
||||||
|
# or offset for attention_mask). Below is a straightforward concatenation/extension.
|
||||||
|
|
||||||
|
for field_name in sub_features[0].keys():
|
||||||
|
# Some fields you might want to skip or treat specially:
|
||||||
|
if field_name == "length":
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If it’s a KD field that’s a list-of-lists (e.g. target_logprobs),
|
||||||
|
# you typically just want to flatten them by extending.
|
||||||
|
if field_name in ["target_logprobs", "target_token_ids", "target_mask"]:
|
||||||
|
combined = []
|
||||||
|
for feat in sub_features:
|
||||||
|
combined.extend(feat[field_name])
|
||||||
|
out_features[i][field_name] = combined
|
||||||
|
|
||||||
|
elif field_name == "attention_mask":
|
||||||
|
# Here we apply the (j+1) factor to differentiate each sub-sample
|
||||||
|
# within this merged batch item.
|
||||||
|
arrays = []
|
||||||
|
for j, feat in enumerate(sub_features):
|
||||||
|
if field_name in feat:
|
||||||
|
arrays.append((j + 1) * np.array(feat[field_name]))
|
||||||
|
out_features[i][field_name] = np.concatenate(arrays)
|
||||||
|
else:
|
||||||
|
# By default, just concatenate them if they are arrays
|
||||||
|
# or extend them if they are lists.
|
||||||
|
# For example, input_ids or labels are often arrays.
|
||||||
|
arrays = []
|
||||||
|
for feat in sub_features:
|
||||||
|
if field_name in feat:
|
||||||
|
arr = np.array(feat[field_name])
|
||||||
|
arrays.append(arr)
|
||||||
|
out_features[i][field_name] = np.concatenate(arrays)
|
||||||
|
|
||||||
|
# 3) Now call the parent collator, which will do:
|
||||||
|
# - padding of labels/position_ids
|
||||||
|
# - KD-specific padding for target_logprobs, target_token_ids, etc.
|
||||||
|
# - final conversion to return_tensors
|
||||||
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|||||||
Reference in New Issue
Block a user