wire up squash_position_ids
This commit is contained in:
@@ -476,6 +476,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
):
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
if self.cfg.squash_position_ids:
|
||||
kwargs["squash_position_ids"] = True
|
||||
else:
|
||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
|
||||
@@ -459,6 +459,12 @@ class AxolotlInputConfig(
|
||||
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
|
||||
},
|
||||
)
|
||||
squash_position_ids: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to squash position_ids for packing, effectively extending context length."
|
||||
},
|
||||
)
|
||||
eval_sample_packing: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
Reference in New Issue
Block a user