EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip * fixes * more fixeS * add missing strided module * ebft fixes for multi-turn * make ebft work with async * add example for ebft w qwen3.5 * fix for split thinking and update yaml for lora over linear attention only * enforce_eager for vllm arg in schema * fix sync weights * fix multi-gpu * handle updated sig for mm * ddp fixes * improve multi-gpu handling, don't calculate logits, adaptive completion length * chore: lint * chore: lint * support completion_mean * Address corereview feedback * clamp min IS ratio * Address PR code review * more fixes identified * address code review * Fix property from rebase conflict
This commit is contained in:
80
examples/ebft/ebft_strided_structured.py
Normal file
80
examples/ebft/ebft_strided_structured.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Dataset transform for structured (prompt, completion) data with strided EBFT.
|
||||
|
||||
Tokenizes prompt and completion separately, concatenates into a single
|
||||
input_ids sequence, and marks prompt tokens with labels=-100 so the
|
||||
strided trainer knows where to place anchors (completion span only).
|
||||
|
||||
Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct).
|
||||
"""
|
||||
|
||||
|
||||
def transform(cfg, *args, **kwargs):
|
||||
seq_len = cfg.sequence_len
|
||||
|
||||
def transform_fn(example, tokenizer=None):
|
||||
# Extract prompt and completion from the example
|
||||
prompt_text = example.get(
|
||||
"input", example.get("prompt", example.get("question", ""))
|
||||
)
|
||||
completion_text = example.get(
|
||||
"output", example.get("completion", example.get("answer", ""))
|
||||
)
|
||||
|
||||
if tokenizer is None:
|
||||
return {"prompt": prompt_text}
|
||||
|
||||
pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
|
||||
|
||||
# Tokenize prompt and completion separately
|
||||
prompt_enc = tokenizer(
|
||||
prompt_text,
|
||||
truncation=False,
|
||||
add_special_tokens=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
completion_enc = tokenizer(
|
||||
completion_text,
|
||||
truncation=False,
|
||||
add_special_tokens=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
|
||||
prompt_ids = prompt_enc["input_ids"]
|
||||
completion_ids = completion_enc["input_ids"]
|
||||
|
||||
# Truncate to fit within seq_len (prioritize keeping prompt + some completion)
|
||||
total_len = len(prompt_ids) + len(completion_ids)
|
||||
if total_len > seq_len:
|
||||
# Truncate completion first, then prompt if needed
|
||||
max_completion = seq_len - len(prompt_ids)
|
||||
if max_completion < 1:
|
||||
# Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token
|
||||
prompt_ids = prompt_ids[: seq_len - 1]
|
||||
completion_ids = completion_ids[:1]
|
||||
else:
|
||||
completion_ids = completion_ids[:max_completion]
|
||||
|
||||
input_ids = prompt_ids + completion_ids
|
||||
prompt_length = len(prompt_ids)
|
||||
|
||||
# Labels: -100 for prompt tokens, input_ids for completion tokens
|
||||
labels = [-100] * prompt_length + completion_ids
|
||||
|
||||
# Pad to seq_len
|
||||
pad_len = seq_len - len(input_ids)
|
||||
attention_mask = [1] * len(input_ids) + [0] * pad_len
|
||||
labels = labels + [-100] * pad_len
|
||||
input_ids = input_ids + [pad_id] * pad_len
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
"prompt_length": prompt_length,
|
||||
}
|
||||
|
||||
# Signal to remove all original columns (filtered to existing ones at map time)
|
||||
return transform_fn, {
|
||||
"remove_columns": "__all__",
|
||||
}
|
||||
Reference in New Issue
Block a user