Compare commits
1 Commits
fix/diffus
...
squash_pos
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
21ba1cd3f1 |
@@ -476,6 +476,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
|
if self.cfg.squash_position_ids:
|
||||||
|
kwargs["squash_position_ids"] = True
|
||||||
else:
|
else:
|
||||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -459,6 +459,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
|
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
squash_position_ids: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to squash position_ids for packing, effectively extending context length."
|
||||||
|
},
|
||||||
|
)
|
||||||
eval_sample_packing: bool | None = Field(
|
eval_sample_packing: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
Reference in New Issue
Block a user