diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index e5bc21762..a3af94ca7 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -476,6 +476,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ) ): collator = V2BatchSamplerDataCollatorForSeq2Seq + if self.cfg.squash_position_ids: + kwargs["squash_position_ids"] = True else: collator = BatchSamplerDataCollatorForSeq2Seq else: diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index a607b3dca..21083ed87 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -459,6 +459,12 @@ class AxolotlInputConfig( "description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'" }, ) + squash_position_ids: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to squash position_ids for packing, effectively extending context length." + }, + ) eval_sample_packing: bool | None = Field( default=None, json_schema_extra={