diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index f26ef8969..fe832dd45 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -370,7 +370,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         data_collator_kwargs = {
             "padding": True,  # True/"longest" is the default
         }
-        multiple = 64
+        multiple = getattr(self.cfg, "pad_to_multiple_of", None) or 64
         if self.cfg.pad_to_sequence_len:
             data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
                 self.cfg.sequence_len / multiple
diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py
index 447b64eb8..98296a201 100644
--- a/src/axolotl/core/builders/rl.py
+++ b/src/axolotl/core/builders/rl.py
@@ -228,9 +228,47 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 
         return training_args, trainer_kwargs
 
+    def build_collator(self, **kwargs):
+        """Build a data collator for preference-tuning trainers.
+
+        Returns None for RL types that provide their own collator (e.g. GRPO,
+        KTO), letting the trainer construct its default. For DPO/IPO/ORPO/SIMPO
+        returns an ``AxolotlDPODataCollatorWithPadding`` when
+        ``pad_to_multiple_of`` is set, otherwise None (so the trainer
+        falls back to the TRL default).
+        """
+        if self.cfg.rl not in (
+            RLType.DPO,
+            RLType.IPO,
+            RLType.ORPO,
+            RLType.SIMPO,
+        ):
+            return None
+
+        pad_to_multiple_of = getattr(self.cfg, "pad_to_multiple_of", None)
+        if not pad_to_multiple_of:
+            return None
+
+        from axolotl.utils.collators.dpo import AxolotlDPODataCollatorWithPadding
+
+        LOG.info(
+            f"Using AxolotlDPODataCollatorWithPadding with pad_to_multiple_of="
+            f"{pad_to_multiple_of}"
+        )
+        is_enc_dec = getattr(self.model.config, "is_encoder_decoder", False)
+        return AxolotlDPODataCollatorWithPadding(
+            pad_token_id=self.tokenizer.pad_token_id,
+            is_encoder_decoder=is_enc_dec,
+            pad_to_multiple_of=pad_to_multiple_of,
+            **kwargs,
+        )
+
     def build(self, total_num_steps):
         training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
 
+        if (data_collator := self.build_collator()) is not None:
+            trainer_kwargs["data_collator"] = data_collator
+
         if self.eval_dataset:
             trainer_kwargs["eval_dataset"] = self.eval_dataset
         if (
diff --git a/src/axolotl/monkeypatch/trainer/utils.py b/src/axolotl/monkeypatch/trainer/utils.py
index 467f50a5a..051eda270 100644
--- a/src/axolotl/monkeypatch/trainer/utils.py
+++ b/src/axolotl/monkeypatch/trainer/utils.py
@@ -407,6 +407,9 @@ def selective_log_softmax(logits, index) -> torch.Tensor:
     K = index.shape[-1]
     original_index_shape = index.shape
 
+    # reshape() is already a zero-copy view when strides allow, and contiguous()
+    # is a no-op on contiguous input. A bare view() can succeed on row-strided
+    # input (e.g. a vocab slice), dropping the contiguity enforced here.
     flat_logits = logits.reshape(-1, V).contiguous()
     flat_index = index.reshape(-1, K).contiguous()
 
diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py
index d5e6ad17d..fdc030a2c 100644
--- a/src/axolotl/utils/collators/__init__.py
+++ b/src/axolotl/utils/collators/__init__.py
@@ -6,6 +6,7 @@ from .batching import (
     PretrainingBatchSamplerDataCollatorForSeq2Seq,
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
+from .dpo import AxolotlDPODataCollatorWithPadding
 from .mamba import MambaDataCollator
 
 __all__ = [
@@ -13,5 +14,6 @@ __all__ = [
     "BatchSamplerDataCollatorForSeq2Seq",
     "V2BatchSamplerDataCollatorForSeq2Seq",
     "PretrainingBatchSamplerDataCollatorForSeq2Seq",
+    "AxolotlDPODataCollatorWithPadding",
     "MambaDataCollator",
 ]
diff --git a/src/axolotl/utils/collators/dpo.py b/src/axolotl/utils/collators/dpo.py
new file mode 100644
index 000000000..6f10188c0
--- /dev/null
+++ b/src/axolotl/utils/collators/dpo.py
@@ -0,0 +1,136 @@
+"""DPO/IPO/ORPO/SIMPO data collator with pad_to_multiple_of support.
+
+Extends TRL's DPODataCollatorWithPadding to round padded sequence lengths
+up to a fixed multiple. This stabilizes Triton autotune caches for kernels
+that key on sequence length (e.g. fla's linear attention kernels used by
+Qwen3.5), which otherwise re-autotune on every distinct batch length.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from trl.experimental.utils import DPODataCollatorWithPadding
+from trl.trainer.utils import pad
+
+
+def _round_up(length: int, multiple: int) -> int:
+    """Return the smallest multiple of a positive ``multiple`` >= ``length``."""
+    return ((length + multiple - 1) // multiple) * multiple
+
+
+@dataclass
+class AxolotlDPODataCollatorWithPadding(DPODataCollatorWithPadding):
+    """DPO data collator that pads to a multiple of ``pad_to_multiple_of``.
+
+    Args:
+        pad_token_id: Tokenizer pad token id (inherited).
+        is_encoder_decoder: Whether the model is encoder-decoder (inherited).
+        pad_to_multiple_of: If set, padded lengths are rounded up to this
+            multiple. Helps stabilize Triton autotune caches.
+    """
+
+    pad_to_multiple_of: int | None = None
+
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
+        """Collate ``features`` into one batch, padding tensor-like fields.
+
+        Keys ending in ``_input_ids``, ``_attention_mask``, ``_labels`` or
+        ``_pixel_values`` are padded (and rounded up to ``pad_to_multiple_of``
+        when set); ``*_logps`` values are converted to a tensor; every other
+        key is passed through as a plain list.
+        """
+        pad_to_mult = self.pad_to_multiple_of
+
+        padded_batch: dict[str, Any] = {}
+        for k in features[0].keys():
+            if k.endswith(
+                ("_input_ids", "_attention_mask", "_labels", "_pixel_values")
+            ):
+                if self.is_encoder_decoder:
+                    if k.endswith("_pixel_values"):
+                        to_pad = [
+                            torch.tensor(ex[k], dtype=torch.float32) for ex in features
+                        ]
+                    else:
+                        to_pad = [torch.LongTensor(ex[k]) for ex in features]
+
+                    if k.startswith("prompt") and k.endswith("input_ids"):
+                        if self.pad_token_id is None:
+                            raise ValueError(
+                                "Padding is enabled, but the tokenizer is not configured with a padding token."
+                            )
+                        padding_value = self.pad_token_id
+                    elif k.endswith("_attention_mask"):
+                        padding_value = 0
+                    elif k.endswith("_pixel_values"):
+                        padding_value = 0
+                    elif (
+                        k.startswith(("chosen", "rejected", "completion"))
+                        or "decoder" in k
+                    ):
+                        padding_value = -100
+                    else:
+                        raise ValueError(f"Unexpected key in batch '{k}'")
+
+                    padded = pad_sequence(
+                        to_pad, batch_first=True, padding_value=padding_value
+                    )
+                    if pad_to_mult:
+                        cur = padded.shape[1]
+                        target = _round_up(cur, pad_to_mult)
+                        if target > cur:
+                            extra = target - cur
+                            pad_shape = list(padded.shape)
+                            pad_shape[1] = extra
+                            filler = torch.full(
+                                pad_shape,
+                                padding_value,
+                                dtype=padded.dtype,
+                                device=padded.device,
+                            )
+                            padded = torch.cat([padded, filler], dim=1)
+                    padded_batch[k] = padded
+                else:
+                    if k.endswith("_input_ids"):
+                        if self.pad_token_id is None:
+                            raise ValueError(
+                                "Padding is enabled, but the tokenizer is not configured with a padding token."
+                            )
+                        padding_value = self.pad_token_id
+                    elif k.endswith("_labels"):
+                        padding_value = -100
+                    elif k.endswith("_attention_mask"):
+                        padding_value = 0
+                    elif k.endswith("_pixel_values"):
+                        padding_value = 0
+                    else:
+                        raise ValueError(f"Unexpected key in batch '{k}'")
+
+                    padding_side = (
+                        "left"
+                        if k in ("prompt_input_ids", "prompt_attention_mask")
+                        else "right"
+                    )
+
+                    dtype = (
+                        torch.float32 if k.endswith("_pixel_values") else torch.int64
+                    )
+                    to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features]
+
+                    # trl.pad() natively supports pad_to_multiple_of
+                    padded_batch[k] = pad(
+                        to_pad,
+                        padding_value=padding_value,
+                        padding_side=padding_side,
+                        pad_to_multiple_of=pad_to_mult,
+                    )
+            elif k.endswith("_logps"):
+                padded_batch[k] = torch.tensor([ex[k] for ex in features])
+            else:
+                padded_batch[k] = [ex[k] for ex in features]
+
+        return padded_batch
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 3211b7c36..6ee672c8c 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -673,6 +673,12 @@ class AxolotlInputConfig(
             "description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to True if `sample_packing` enabled"
         },
     )
+    pad_to_multiple_of: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": ("Pad each batch to a multiple of this value.")
+        },
+    )
     curriculum_sampling: bool | None = Field(
         default=None,
         json_schema_extra={