Compare commits
3 Commits
smol-ci
...
torch-211-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
98e18d59d2 | ||
|
|
462135acfb | ||
|
|
901f2356bc |
16
.github/workflows/base.yml
vendored
16
.github/workflows/base.yml
vendored
@@ -224,6 +224,22 @@ jobs:
|
|||||||
torch_cuda_arch_list: "9.0+PTX"
|
torch_cuda_arch_list: "9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.12"
|
||||||
|
pytorch: 2.11.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-uv-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: "130"
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.12"
|
||||||
|
pytorch: 2.11.0
|
||||||
|
torch_cuda_arch_list: "9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-uv-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -370,7 +370,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
"padding": True, # True/"longest" is the default
|
"padding": True, # True/"longest" is the default
|
||||||
}
|
}
|
||||||
multiple = 64
|
multiple = getattr(self.cfg, "pad_to_multiple_of", None) or 64
|
||||||
if self.cfg.pad_to_sequence_len:
|
if self.cfg.pad_to_sequence_len:
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||||
self.cfg.sequence_len / multiple
|
self.cfg.sequence_len / multiple
|
||||||
|
|||||||
@@ -228,9 +228,47 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
return training_args, trainer_kwargs
|
return training_args, trainer_kwargs
|
||||||
|
|
||||||
|
def build_collator(self, **kwargs):
    """Return a padding-aware preference collator, or ``None``.

    ``None`` is returned in two situations, leaving the trainer to fall
    back to whatever collator it builds by default:

    * ``cfg.rl`` is not one of DPO, IPO, ORPO or SIMPO;
    * ``cfg.pad_to_multiple_of`` is missing or falsy.

    Otherwise an ``AxolotlDPODataCollatorWithPadding`` is constructed from
    the tokenizer's pad token id, the model config's ``is_encoder_decoder``
    flag (``False`` when the attribute is absent) and the configured
    multiple.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the
            collator constructor.
    """
    supported_rl_types = (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO)
    if self.cfg.rl not in supported_rl_types:
        return None

    pad_to_multiple_of = getattr(self.cfg, "pad_to_multiple_of", None)
    if not pad_to_multiple_of:
        return None

    # Imported only once we know the collator is actually going to be built.
    from axolotl.utils.collators.dpo import AxolotlDPODataCollatorWithPadding

    LOG.info(
        f"Using AxolotlDPODataCollatorWithPadding with pad_to_multiple_of="
        f"{pad_to_multiple_of}"
    )

    is_enc_dec = getattr(self.model.config, "is_encoder_decoder", False)
    collator_kwargs = {
        "pad_token_id": self.tokenizer.pad_token_id,
        "is_encoder_decoder": is_enc_dec,
        "pad_to_multiple_of": pad_to_multiple_of,
    }
    return AxolotlDPODataCollatorWithPadding(**collator_kwargs, **kwargs)
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
|
training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
|
||||||
|
|
||||||
|
if (data_collator := self.build_collator()) is not None:
|
||||||
|
trainer_kwargs["data_collator"] = data_collator
|
||||||
|
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -407,7 +407,10 @@ def selective_log_softmax(logits, index) -> torch.Tensor:
|
|||||||
K = index.shape[-1]
|
K = index.shape[-1]
|
||||||
original_index_shape = index.shape
|
original_index_shape = index.shape
|
||||||
|
|
||||||
flat_logits = logits.reshape(-1, V).contiguous()
|
try:
|
||||||
|
flat_logits = logits.view(-1, V)
|
||||||
|
except RuntimeError:
|
||||||
|
flat_logits = logits.reshape(-1, V).contiguous()
|
||||||
flat_index = index.reshape(-1, K).contiguous()
|
flat_index = index.reshape(-1, K).contiguous()
|
||||||
|
|
||||||
BLOCK_V = 4096
|
BLOCK_V = 4096
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from .batching import (
|
|||||||
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
|
from .dpo import AxolotlDPODataCollatorWithPadding
|
||||||
from .mamba import MambaDataCollator
|
from .mamba import MambaDataCollator
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -13,5 +14,6 @@ __all__ = [
|
|||||||
"BatchSamplerDataCollatorForSeq2Seq",
|
"BatchSamplerDataCollatorForSeq2Seq",
|
||||||
"V2BatchSamplerDataCollatorForSeq2Seq",
|
"V2BatchSamplerDataCollatorForSeq2Seq",
|
||||||
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
|
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
|
||||||
|
"AxolotlDPODataCollatorWithPadding",
|
||||||
"MambaDataCollator",
|
"MambaDataCollator",
|
||||||
]
|
]
|
||||||
|
|||||||
128
src/axolotl/utils/collators/dpo.py
Normal file
128
src/axolotl/utils/collators/dpo.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""DPO/IPO/ORPO/SIMPO data collator with pad_to_multiple_of support.
|
||||||
|
|
||||||
|
Extends TRL's DPODataCollatorWithPadding to round padded sequence lengths
|
||||||
|
up to a fixed multiple. This stabilizes Triton autotune caches for kernels
|
||||||
|
that key on sequence length (e.g. fla's linear attention kernels used by
|
||||||
|
Qwen3.5), which otherwise re-autotune on every distinct batch length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from trl.experimental.utils import DPODataCollatorWithPadding
|
||||||
|
from trl.trainer.utils import pad
|
||||||
|
|
||||||
|
|
||||||
|
def _round_up(length: int, multiple: int) -> int:
|
||||||
|
return ((length + multiple - 1) // multiple) * multiple
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AxolotlDPODataCollatorWithPadding(DPODataCollatorWithPadding):
    """DPO data collator that pads to a multiple of ``pad_to_multiple_of``.

    Keys ending in ``_input_ids``, ``_attention_mask``, ``_labels`` or
    ``_pixel_values`` are padded into tensors; keys ending in ``_logps`` are
    stacked into a tensor; every other key is passed through as a plain list.

    Args:
        pad_token_id: Tokenizer pad token id (inherited).
        is_encoder_decoder: Whether the model is encoder-decoder (inherited).
        pad_to_multiple_of: If set, padded lengths are rounded up to this
            multiple. Helps stabilize Triton autotune caches.

    Label-like keys are padded with the instance's ``label_pad_token_id``
    when that attribute exists, and with ``-100`` otherwise.
    NOTE(review): this assumes the TRL base class exposes
    ``label_pad_token_id`` as a field -- confirm against the pinned TRL
    version; the ``-100`` fallback keeps the previous behavior either way.
    """

    pad_to_multiple_of: int | None = None

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        """Collate ``features`` into a batch dictionary.

        Args:
            features: Per-example dictionaries; the keys of the first example
                decide which keys are collated.

        Returns:
            A dictionary with one entry per key of ``features[0]``.

        Raises:
            ValueError: If an ``*_input_ids`` key needs padding but
                ``pad_token_id`` is ``None``, or if a paddable key matches
                none of the known padding rules.
        """
        padded_batch: dict[str, Any] = {}
        for k in features[0]:
            if k.endswith(
                ("_input_ids", "_attention_mask", "_labels", "_pixel_values")
            ):
                if self.is_encoder_decoder:
                    padded_batch[k] = self._pad_encoder_decoder(k, features)
                else:
                    padded_batch[k] = self._pad_decoder_only(k, features)
            elif k.endswith("_logps"):
                padded_batch[k] = torch.tensor([ex[k] for ex in features])
            else:
                padded_batch[k] = [ex[k] for ex in features]

        return padded_batch

    def _require_pad_token_id(self) -> int:
        """Return ``pad_token_id``, raising ``ValueError`` when it is ``None``."""
        if self.pad_token_id is None:
            raise ValueError(
                "Padding is enabled, but the tokenizer is not configured with a padding token."
            )
        return self.pad_token_id

    def _label_pad_value(self) -> int:
        """Return the label padding value (``-100`` unless the instance overrides it)."""
        return getattr(self, "label_pad_token_id", -100)

    def _pad_encoder_decoder(
        self, key: str, features: list[dict[str, Any]]
    ) -> torch.Tensor:
        """Right-pad ``key`` across ``features`` for an encoder-decoder model."""
        if key.endswith("_pixel_values"):
            to_pad = [torch.tensor(ex[key], dtype=torch.float32) for ex in features]
        else:
            to_pad = [torch.LongTensor(ex[key]) for ex in features]

        if key.startswith("prompt") and key.endswith("input_ids"):
            padding_value = self._require_pad_token_id()
        elif key.endswith(("_attention_mask", "_pixel_values")):
            padding_value = 0
        elif key.startswith(("chosen", "rejected", "completion")) or "decoder" in key:
            padding_value = self._label_pad_value()
        else:
            raise ValueError(f"Unexpected key in batch '{key}'")

        padded = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)

        pad_to_mult = self.pad_to_multiple_of
        if not pad_to_mult:
            return padded

        # pad_sequence only pads to the longest example, so extend dim 1
        # with the same padding value up to the next multiple.
        cur = padded.shape[1]
        target = _round_up(cur, pad_to_mult)
        if target <= cur:
            return padded

        pad_shape = list(padded.shape)
        pad_shape[1] = target - cur
        filler = torch.full(
            pad_shape,
            padding_value,
            dtype=padded.dtype,
            device=padded.device,
        )
        return torch.cat([padded, filler], dim=1)

    def _pad_decoder_only(
        self, key: str, features: list[dict[str, Any]]
    ) -> torch.Tensor:
        """Pad ``key`` across ``features`` for a decoder-only model."""
        if key.endswith("_input_ids"):
            padding_value = self._require_pad_token_id()
        elif key.endswith("_labels"):
            padding_value = self._label_pad_value()
        elif key.endswith(("_attention_mask", "_pixel_values")):
            padding_value = 0
        else:
            raise ValueError(f"Unexpected key in batch '{key}'")

        # Prompt tensors are padded on the left; everything else on the right.
        if key in ("prompt_input_ids", "prompt_attention_mask"):
            padding_side = "left"
        else:
            padding_side = "right"

        dtype = torch.float32 if key.endswith("_pixel_values") else torch.int64
        to_pad = [torch.tensor(ex[key], dtype=dtype) for ex in features]

        # trl.pad() natively supports pad_to_multiple_of
        return pad(
            to_pad,
            padding_value=padding_value,
            padding_side=padding_side,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
|
||||||
@@ -673,6 +673,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to True if `sample_packing` enabled"
|
"description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to True if `sample_packing` enabled"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
pad_to_multiple_of: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": ("Pad each batch to a multiple of this value.")
|
||||||
|
},
|
||||||
|
)
|
||||||
curriculum_sampling: bool | None = Field(
|
curriculum_sampling: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
Reference in New Issue
Block a user