use DataCollatorWithFlattening when not sample packing (#2167)
This commit is contained in:
@@ -28,6 +28,7 @@ from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import (
|
||||
DataCollatorWithFlattening,
|
||||
EarlyStoppingCallback,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
@@ -1989,9 +1990,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
DataCollatorWithFlattening,
|
||||
RewardDataCollatorWithPadding,
|
||||
]
|
||||
]
|
||||
collator_args = [self.tokenizer]
|
||||
if self.cfg.reward_model:
|
||||
collator = RewardDataCollatorWithPadding
|
||||
if "max_length" in kwargs:
|
||||
@@ -2011,12 +2014,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
collator = MultiModalChatDataCollator
|
||||
kwargs["processor"] = self.processor
|
||||
kwargs["chat_template"] = training_args.chat_template
|
||||
elif self.cfg.batch_flattening:
|
||||
collator = DataCollatorWithFlattening
|
||||
collator_args.pop(0)
|
||||
kwargs.pop("pad_to_multiple_of", None)
|
||||
kwargs.pop("padding", None)
|
||||
else:
|
||||
collator = DataCollatorForSeq2Seq
|
||||
|
||||
kwargs["return_tensors"] = "pt"
|
||||
|
||||
return collator(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
*collator_args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -696,6 +696,8 @@ class AxolotlInputConfig(
|
||||
curriculum_sampling: Optional[bool] = None
|
||||
multipack_real_batches: Optional[bool] = None
|
||||
|
||||
batch_flattening: Optional[Union[Literal["auto"], bool]] = None
|
||||
|
||||
# for PoSE context length extension
|
||||
use_pose: Optional[bool] = None
|
||||
pose_split_on_token_ids: Optional[List[int]] = None
|
||||
@@ -924,6 +926,30 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_batch_flattening_fa(cls, data):
|
||||
if data.get("batch_flattening"):
|
||||
batch_flattening_auto = data.get("batch_flattening") == "auto"
|
||||
if not data.get("flash_attention") and not batch_flattening_auto:
|
||||
raise ValueError("batch_flattening requires flash attention")
|
||||
if data.get("sample_packing") and not batch_flattening_auto:
|
||||
raise ValueError("batch_flattening not compatible with sample_packing")
|
||||
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
|
||||
LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
|
||||
|
||||
if (
|
||||
batch_flattening_auto
|
||||
and data.get("flash_attention")
|
||||
and not data.get("sample_packing")
|
||||
and data.get("micro_batch_size") > 1
|
||||
):
|
||||
data["batch_flattening"] = True
|
||||
elif batch_flattening_auto:
|
||||
data["batch_flattening"] = False
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sample_packing_w_rl(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user