Files
axolotl/examples/ebft/ebft_strided_structured.py
Wing Lian c50c4acbf4 EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip

* fixes

* more fixeS

* add missing strided module

* ebft fixes for multi-turn

* make ebft work with async

* add example for ebft w qwen3.5

* fix for split thinking and update yaml for lora over linear attention only

* enforce_eager for vllm arg in schema

* fix sync weights

* fix multi-gpu

* handle updated sig for mm

* ddp fixes

* improve multi-gpu handling, don't calculate logits, adaptive completion length

* chore: lint

* chore: lint

* support completion_mean

* Address corereview feedback

* clamp min IS ratio

* Address PR code review

* more fixes identified

* address code review

* Fix property from rebase conflict
2026-03-24 18:43:46 -04:00

81 lines
2.8 KiB
Python

"""
Dataset transform for structured (prompt, completion) data with strided EBFT.
Tokenizes prompt and completion separately, concatenates into a single
input_ids sequence, and marks prompt tokens with labels=-100 so the
strided trainer knows where to place anchors (completion span only).
Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct).
"""
def transform(cfg, *args, **kwargs):
seq_len = cfg.sequence_len
def transform_fn(example, tokenizer=None):
# Extract prompt and completion from the example
prompt_text = example.get(
"input", example.get("prompt", example.get("question", ""))
)
completion_text = example.get(
"output", example.get("completion", example.get("answer", ""))
)
if tokenizer is None:
return {"prompt": prompt_text}
pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
# Tokenize prompt and completion separately
prompt_enc = tokenizer(
prompt_text,
truncation=False,
add_special_tokens=True,
return_tensors=None,
)
completion_enc = tokenizer(
completion_text,
truncation=False,
add_special_tokens=False,
return_tensors=None,
)
prompt_ids = prompt_enc["input_ids"]
completion_ids = completion_enc["input_ids"]
# Truncate to fit within seq_len (prioritize keeping prompt + some completion)
total_len = len(prompt_ids) + len(completion_ids)
if total_len > seq_len:
# Truncate completion first, then prompt if needed
max_completion = seq_len - len(prompt_ids)
if max_completion < 1:
# Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token
prompt_ids = prompt_ids[: seq_len - 1]
completion_ids = completion_ids[:1]
else:
completion_ids = completion_ids[:max_completion]
input_ids = prompt_ids + completion_ids
prompt_length = len(prompt_ids)
# Labels: -100 for prompt tokens, input_ids for completion tokens
labels = [-100] * prompt_length + completion_ids
# Pad to seq_len
pad_len = seq_len - len(input_ids)
attention_mask = [1] * len(input_ids) + [0] * pad_len
labels = labels + [-100] * pad_len
input_ids = input_ids + [pad_id] * pad_len
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"prompt_length": prompt_length,
}
# Signal to remove all original columns (filtered to existing ones at map time)
return transform_fn, {
"remove_columns": "__all__",
}