Finetuning LLMs to output audio

In this example, we finetune canopylabs/orpheus-tts-0.1-pretrained (a LLaMA 3.2 3B model) to output audio.

With the current settings, finetune.yml will run on any Nvidia GPU with 45GB of VRAM or more. If you lower the batch size, it can easily run on GPUs with under 24GB.

Dataset pre-processing for pre-training

If you are adding another voice in English, please skip ahead to the Finetune pre-processing section.

For this to work, we need to preprocess our dataset. Since we expect the model to output audio, we need to add audio tokens to the tokenizer's vocabulary.

The script below downloads the SNAC model, converts each audio clip into the corresponding tokens, and uploads the final dataset.

import os

import torch
import torchaudio.transforms as T
from datasets import load_dataset
from huggingface_hub import snapshot_download
from snac import SNAC
from transformers import AutoTokenizer

my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"

dsn = my_original_dataset_name

snapshot_download(
    repo_id=dsn,
    repo_type="dataset",
    revision="main",
    max_workers=64,
)


ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]

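# Use a single device for both the SNAC model and the waveforms.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to(device)

def tokenise_audio(waveform):
  waveform = torch.from_numpy(waveform).unsqueeze(0)
  waveform = waveform.to(dtype=torch.float32)
  # SNAC expects 24 kHz audio, so resample from the dataset's rate.
  resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
  waveform = resample_transform(waveform)

  waveform = waveform.unsqueeze(0).to(device)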

  #generate the codes from snac
  with torch.inference_mode():
    codes = model.encode(waveform)

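  # Interleave the three SNAC code levels into 7 tokens per frame.
  # 128266 == tokeniser_length + 10 (audio_tokens_start); each of the
  # 7 slots in a frame is shifted by a further multiple of 4096.
  all_codes = []
  for i in range(codes[0].shape[1]):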
    all_codes.append(codes[0][0][i].item()+128266)
    all_codes.append(codes[1][0][2*i].item()+128266+4096)
    all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
    all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
    all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
    all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
    all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))


  return all_codes

def add_codes(example):
    # Always initialize codes_list to None
    codes_list = None

    try:
        answer_audio = example.get("audio")
        # If there's a valid audio array, tokenise it
        if answer_audio and "array" in answer_audio:
            audio_array = answer_audio["array"]
            codes_list = tokenise_audio(audio_array)
    except Exception as e:
        print(f"Skipping row due to error: {e}")
        # Keep codes_list as None if we fail
    example["codes_list"] = codes_list

    return example

ds = ds.map(add_codes, remove_columns=["audio"])

#@title Load Tokenizer
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009

start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2

start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4

start_of_ai = tokeniser_length + 5
end_of_ai =  tokeniser_length + 6
pad_token = tokeniser_length + 7

audio_tokens_start = tokeniser_length + 10

tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"


tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# Leave two cores free, but never drop below one worker.
num_proc = max(1, (os.cpu_count() or 1) - 2)

ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)

#@title Create Input Ids
def remove_duplicate_frames(example):
    vals = example["codes_list"]
    if len(vals) % 7 != 0:
        raise ValueError("Input list length must be divisible by 7")

    result = vals[:7]

    removed_frames = 0

    for i in range(7, len(vals), 7):
        current_first = vals[i]
        previous_first = result[-7]

        if current_first != previous_first:
            result.extend(vals[i:i+7])
        else:
            removed_frames += 1

    example["codes_list"] = result

    return example

ds = ds.map(remove_duplicate_frames, num_proc=num_proc)


def create_input_ids(example):
    text_ids = tokenizer.encode(example["text"], add_special_tokens=True)
    text_ids.append(end_of_text)
    example["text_tokens"] = text_ids
    input_ids = (
        [start_of_human]
        + example["text_tokens"]
        + [end_of_human]
        + [start_of_ai]
        + [start_of_speech]
        + example["codes_list"]
        + [end_of_speech]
        + [end_of_ai]
    )
    example["input_ids"] = input_ids
    example["labels"] = input_ids
    example["attention_mask"] = [1] * len(input_ids)

    return example

ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])

#@title Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]

ds = ds.remove_columns(columns_to_remove)

ds.push_to_hub(name_to_push_dataset_to)
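
The frame layout produced by tokenise_audio above can be illustrated without loading SNAC. The following is a minimal sketch in plain Python: interleave_frame and deinterleave_frame are hypothetical helper names that are not part of the script, and the stride of 4096 is assumed to be the size of each SNAC codebook. It only mirrors the offsets used in the loop above for a single frame (one coarse code, two mid-level codes, four fine codes).

```python
# Hypothetical helpers illustrating the 7-token frame layout; not part of the script above.
AUDIO_TOKENS_START = 128256 + 10  # same value as the 128266 offset in tokenise_audio
SLOT_STRIDE = 4096                # assumed to be the SNAC codebook size


def interleave_frame(level0, level1, level2):
    """Build the 7 token IDs for one frame from 1 + 2 + 4 codes."""
    # Slot order used by tokenise_audio: L0, L1[0], L2[0], L2[1], L1[1], L2[2], L2[3]
    ordered = [
        level0[0],
        level1[0],
        level2[0],
        level2[1],
        level1[1],
        level2[2],
        level2[3],
    ]
    return [
        code + AUDIO_TOKENS_START + slot * SLOT_STRIDE
        for slot, code in enumerate(ordered)
    ]


def deinterleave_frame(tokens):
    """Invert interleave_frame: recover the three code levels from 7 token IDs."""
    codes = [
        tok - AUDIO_TOKENS_START - slot * SLOT_STRIDE
        for slot, tok in enumerate(tokens)
    ]
    return [codes[0]], [codes[1], codes[4]], [codes[2], codes[3], codes[5], codes[6]]


frame = interleave_frame([5], [10, 11], [20, 21, 22, 23])
print(frame)
# [128271, 132372, 136478, 140575, 144661, 148768, 152865]
print(deinterleave_frame(frame))
# ([5], [10, 11], [20, 21, 22, 23])
```

Because every slot has its own offset range, the position of a code within the frame can be recovered from the token ID alone, which is what makes the inverse mapping possible.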

Finetune pre-processing

Use this code to add a new voice.

import os

import torch
import torchaudio.transforms as T
from datasets import load_dataset
from huggingface_hub import snapshot_download
from snac import SNAC
from transformers import AutoTokenizer

my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"

dsn = my_original_dataset_name

snapshot_download(
    repo_id=dsn,
    repo_type="dataset",
    revision="main",
    max_workers=64,
)


ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]

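# Use a single device for both the SNAC model and the waveforms.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to(device)

def tokenise_audio(waveform):
  waveform = torch.from_numpy(waveform).unsqueeze(0)
  waveform = waveform.to(dtype=torch.float32)
  # SNAC expects 24 kHz audio, so resample from the dataset's rate.
  resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
  waveform = resample_transform(waveform)

  waveform = waveform.unsqueeze(0).to(device)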

  #generate the codes from snac
  with torch.inference_mode():
    codes = model.encode(waveform)

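  # Interleave the three SNAC code levels into 7 tokens per frame.
  # 128266 == tokeniser_length + 10 (audio_tokens_start); each of the
  # 7 slots in a frame is shifted by a further multiple of 4096.
  all_codes = []
  for i in range(codes[0].shape[1]):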
    all_codes.append(codes[0][0][i].item()+128266)
    all_codes.append(codes[1][0][2*i].item()+128266+4096)
    all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
    all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
    all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
    all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
    all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))


  return all_codes

def add_codes(example):
    # Always initialize codes_list to None
    codes_list = None

    try:
        answer_audio = example.get("audio")
        # If there's a valid audio array, tokenise it
        if answer_audio and "array" in answer_audio:
            audio_array = answer_audio["array"]
            codes_list = tokenise_audio(audio_array)
    except Exception as e:
        print(f"Skipping row due to error: {e}")
        # Keep codes_list as None if we fail
    example["codes_list"] = codes_list

    return example

ds = ds.map(add_codes, remove_columns=["audio"])

#@title Load Tokenizer
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009

start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2

start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4

start_of_ai = tokeniser_length + 5
end_of_ai =  tokeniser_length + 6
pad_token = tokeniser_length + 7

audio_tokens_start = tokeniser_length + 10

tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"


tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# Leave two cores free, but never drop below one worker.
num_proc = max(1, (os.cpu_count() or 1) - 2)

ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)

#@title Create Input Ids
def remove_duplicate_frames(example):
    vals = example["codes_list"]
    if len(vals) % 7 != 0:
        raise ValueError("Input list length must be divisible by 7")

    result = vals[:7]

    removed_frames = 0

    for i in range(7, len(vals), 7):
        current_first = vals[i]
        previous_first = result[-7]

        if current_first != previous_first:
            result.extend(vals[i:i+7])
        else:
            removed_frames += 1

    example["codes_list"] = result

    return example

ds = ds.map(remove_duplicate_frames, num_proc=num_proc)

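tok_info = '''*** HERE you can modify the text prompt
i.e. for a multispeaker model like canopylabs/orpheus-3b-0.1-ft, prefix the text with the speaker,
e.g. f"{example['speaker_id']}: {example['text']}", as is done in create_input_ids below.
Change 'speaker_id' to whichever column holds the speaker name in your dataset.
'''
print(tok_info)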

def create_input_ids(example):
    text_ids = tokenizer.encode(f"{example['speaker_id']}: {example['text']}",  add_special_tokens=True)
    text_ids.append(end_of_text)
    example["text_tokens"] = text_ids
    input_ids = (
        [start_of_human]
        + example["text_tokens"]
        + [end_of_human]
        + [start_of_ai]
        + [start_of_speech]
        + example["codes_list"]
        + [end_of_speech]
        + [end_of_ai]
    )
    example["input_ids"] = input_ids
    example["labels"] = input_ids
    example["attention_mask"] = [1] * len(input_ids)

    return example

ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])

#@title Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]

ds = ds.remove_columns(columns_to_remove)

ds.push_to_hub(name_to_push_dataset_to)
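
To make the behaviour of remove_duplicate_frames concrete, here is a small standalone sketch of the same logic applied to a plain list. The function name drop_repeated_frames and the toy codes are illustrative only. Note that, as in the script, only the first code of each 7-token frame is compared with the first code of the last kept frame; the other six codes are not inspected.

```python
def drop_repeated_frames(codes, frame_size=7):
    """Drop any frame whose first code equals the first code of the last kept frame."""
    if len(codes) % frame_size != 0:
        raise ValueError("Input list length must be divisible by 7")

    result = list(codes[:frame_size])
    for i in range(frame_size, len(codes), frame_size):
        # result[-frame_size] is the first code of the most recently kept frame.
        if codes[i] != result[-frame_size]:
            result.extend(codes[i:i + frame_size])
    return result


toy_codes = [
    1, 2, 3, 4, 5, 6, 7,  # frame 0: always kept
    1, 9, 9, 9, 9, 9, 9,  # frame 1: first code repeats frame 0, so it is dropped
    2, 0, 0, 0, 0, 0, 0,  # frame 2: first code differs, so it is kept
]
print(drop_repeated_frames(toy_codes))
# [1, 2, 3, 4, 5, 6, 7, 2, 0, 0, 0, 0, 0, 0]
```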

Training

After preprocessing is done, fill in the blanks in finetune.yml and run axolotl train finetune.yml.
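
Before launching a run, it may be worth checking that a few rows of the preprocessed dataset have the layout that create_input_ids builds. The sketch below is one possible check, not part of the example itself: split_example is a hypothetical helper, and the sample row uses placeholder token IDs. The special-token values are the ones defined in the preprocessing scripts.

```python
TOKENISER_LENGTH = 128256
START_OF_SPEECH = TOKENISER_LENGTH + 1
END_OF_SPEECH = TOKENISER_LENGTH + 2
START_OF_HUMAN = TOKENISER_LENGTH + 3
END_OF_HUMAN = TOKENISER_LENGTH + 4
START_OF_AI = TOKENISER_LENGTH + 5
END_OF_AI = TOKENISER_LENGTH + 6


def split_example(input_ids):
    """Split one row into (text_tokens, audio_tokens); raise ValueError if the framing is off."""
    if len(input_ids) < 6:
        raise ValueError("sequence too short to contain all markers")
    if input_ids[0] != START_OF_HUMAN or input_ids[-2:] != [END_OF_SPEECH, END_OF_AI]:
        raise ValueError("unexpected start or end markers")

    # list.index raises ValueError itself if END_OF_HUMAN is missing.
    boundary = input_ids.index(END_OF_HUMAN)
    if input_ids[boundary + 1:boundary + 3] != [START_OF_AI, START_OF_SPEECH]:
        raise ValueError("speech section does not follow the human turn")

    text_tokens = input_ids[1:boundary]
    audio_tokens = input_ids[boundary + 3:-2]
    if len(audio_tokens) % 7 != 0:
        raise ValueError("audio section is not a whole number of 7-token frames")
    return text_tokens, audio_tokens


# Placeholder row: three text token IDs followed by a single audio frame.
row = (
    [START_OF_HUMAN, 128000, 9906, 128009, END_OF_HUMAN, START_OF_AI, START_OF_SPEECH]
    + [128271, 132372, 136478, 140575, 144661, 148768, 152865]
    + [END_OF_SPEECH, END_OF_AI]
)
text_tokens, audio_tokens = split_example(row)
print(len(text_tokens), len(audio_tokens) // 7)
# 3 1
```

With the pushed dataset loaded through load_dataset, the same function could be called on row["input_ids"] for a handful of rows.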

Inference

For inference, please refer to the original Orpheus GitHub repository.