EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip * fixes * more fixes * add missing strided module * ebft fixes for multi-turn * make ebft work with async * add example for ebft w qwen3.5 * fix for split thinking and update yaml for lora over linear attention only * enforce_eager for vllm arg in schema * fix sync weights * fix multi-gpu * handle updated sig for mm * ddp fixes * improve multi-gpu handling, don't calculate logits, adaptive completion length * chore: lint * chore: lint * support completion_mean * Address core review feedback * clamp min IS ratio * Address PR code review * more fixes identified * address code review * Fix property from rebase conflict
This commit is contained in:
363
tests/test_ebft_strided_structured.py
Normal file
363
tests/test_ebft_strided_structured.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Tests for the EBFT strided structured dataset transform and data loading."""
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from tokenizers import Tokenizer, models, pre_tokenizers
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from axolotl.prompt_strategies.ebft import load as load_ebft
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture
def tokenizer():
    """Create a simple word-level tokenizer — no network access needed."""
    # Special tokens get fixed ids; test vocabulary words follow.
    vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3}
    test_words = (
        "what is 2 + the answer 4 hello world goodbye bye hi short prompt "
        "x write code print test some string metadata noise ok good python "
        "sampling abc 123 def solve return this that"
    ).split()
    for word in test_words:
        # setdefault keeps the first id assigned to a repeated word.
        vocab.setdefault(word, len(vocab))

    core = Tokenizer(models.WordLevel(vocab=vocab, unk_token="[UNK]"))
    core.pre_tokenizer = pre_tokenizers.Whitespace()

    return PreTrainedTokenizerFast(
        tokenizer_object=core,
        bos_token="[BOS]",
        eos_token="[EOS]",
        pad_token="[PAD]",
        unk_token="[UNK]",
    )

@pytest.fixture
def cfg():
    """Minimal config: only ``sequence_len`` matters for these tests."""
    return DictDefault({"sequence_len": 64})

@pytest.fixture
def transform_fn_and_kwargs(cfg):
    """Load the strided structured EBFT transform plus its dataset.map kwargs."""
    loaded = load_ebft("ebft_strided_structured.transform", cfg)
    assert loaded is not None, "Failed to load ebft_strided_structured transform"
    fn, map_kwargs = loaded
    return fn, map_kwargs

class TestEBFTStridedStructuredTransform:
    """Tests for the dataset transform function itself."""

    def test_transform_loads(self, transform_fn_and_kwargs):
        fn, kwargs = transform_fn_and_kwargs
        assert callable(fn)
        assert "remove_columns" in kwargs

    def test_remove_columns_sentinel(self, transform_fn_and_kwargs):
        """Transform should signal removal of all original columns."""
        _, kwargs = transform_fn_and_kwargs
        assert kwargs["remove_columns"] == "__all__"

    def test_prompt_completion_tokenization(self, transform_fn_and_kwargs, tokenizer):
        """Prompt tokens get labels=-100, completion tokens get real labels."""
        fn, _ = transform_fn_and_kwargs
        row = {"input": "what is 2 + 2", "output": "the answer is 4"}
        out = fn(row, tokenizer=tokenizer)

        for key in ("input_ids", "labels", "attention_mask", "prompt_length"):
            assert key in out

        p_len = out["prompt_length"]
        labels = out["labels"]
        total_len = len(out["input_ids"])

        assert total_len == 64, "Should be padded to sequence_len"
        assert len(labels) == total_len
        assert p_len > 0

        # Every prompt position must carry the ignore index.
        for i in range(p_len):
            assert labels[i] == -100, f"Prompt token at {i} should be -100"

        # The completion region must contribute at least one real label.
        unmasked = [lab for lab in labels[p_len:] if lab != -100]
        assert len(unmasked) > 0, "Should have non-masked completion tokens"

    def test_prompt_length_matches_boundary(self, transform_fn_and_kwargs, tokenizer):
        """prompt_length should be the exact boundary between -100 and real labels."""
        fn, _ = transform_fn_and_kwargs
        out = fn(
            {"input": "hello world", "output": "goodbye world"}, tokenizer=tokenizer
        )

        boundary = out["prompt_length"]
        labels = out["labels"]

        assert labels[boundary - 1] == -100, "Last prompt token should be masked"
        assert labels[boundary] != -100, (
            "First completion token should not be masked"
        )

    def test_padding_tokens_masked(self, transform_fn_and_kwargs, tokenizer):
        """Padding tokens should have labels=-100 and attention_mask=0."""
        fn, _ = transform_fn_and_kwargs
        out = fn({"input": "hi", "output": "bye"}, tokenizer=tokenizer)

        mask = out["attention_mask"]
        labels = out["labels"]

        used = sum(mask)
        assert used < 64, "Short example should have padding"

        # Everything past the real tokens is padding: masked out of both
        # attention and loss.
        for i in range(used, 64):
            assert mask[i] == 0, (
                f"Pad position {i} should have attention_mask=0"
            )
            assert labels[i] == -100, f"Pad position {i} should have labels=-100"

    def test_truncation_long_completion(self, transform_fn_and_kwargs, tokenizer):
        """Long completions should be truncated to fit sequence_len."""
        fn, _ = transform_fn_and_kwargs
        out = fn(
            {"input": "short prompt", "output": "x " * 500},
            tokenizer=tokenizer,
        )
        assert len(out["input_ids"]) == 64

    def test_alternative_field_names(self, transform_fn_and_kwargs, tokenizer):
        """Transform should handle different field name conventions."""
        fn, _ = transform_fn_and_kwargs

        for row in (
            {"prompt": "what", "completion": "this"},
            {"question": "what", "answer": "this"},
        ):
            out = fn(row, tokenizer=tokenizer)
            assert out["prompt_length"] > 0

    def test_without_tokenizer_returns_prompt(self, transform_fn_and_kwargs):
        """Without tokenizer, should return a prompt string."""
        fn, _ = transform_fn_and_kwargs
        out = fn({"input": "hello", "output": "world"})
        assert "prompt" in out
        assert out["prompt"] == "hello"

class TestEBFTColumnRemoval:
|
||||
"""Tests for the __all__ column removal logic in the RL data path."""
|
||||
|
||||
def _filter_remove_columns(self, map_kwargs, dataset):
|
||||
"""Reproduce the filtering logic from rl.py _load_split."""
|
||||
if "remove_columns" in map_kwargs:
|
||||
ds_columns = dataset.column_names
|
||||
if map_kwargs["remove_columns"] == "__all__":
|
||||
map_kwargs["remove_columns"] = list(ds_columns)
|
||||
else:
|
||||
map_kwargs["remove_columns"] = [
|
||||
c for c in map_kwargs["remove_columns"] if c in ds_columns
|
||||
]
|
||||
return map_kwargs
|
||||
|
||||
def test_all_original_columns_removed(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""After mapping, only tokenized columns should remain."""
|
||||
transform_fn, map_kwargs = transform_fn_and_kwargs
|
||||
map_kwargs = dict(map_kwargs) # copy
|
||||
|
||||
ds = Dataset.from_list(
|
||||
[
|
||||
{"input": "what is 2 + 2", "output": "4", "extra_field": "noise"},
|
||||
]
|
||||
)
|
||||
|
||||
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
|
||||
assert "input" in map_kwargs["remove_columns"]
|
||||
assert "output" in map_kwargs["remove_columns"]
|
||||
assert "extra_field" in map_kwargs["remove_columns"]
|
||||
|
||||
from functools import partial
|
||||
|
||||
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
|
||||
assert "input_ids" in mapped.column_names
|
||||
assert "labels" in mapped.column_names
|
||||
assert "prompt_length" in mapped.column_names
|
||||
assert "input" not in mapped.column_names
|
||||
assert "output" not in mapped.column_names
|
||||
assert "extra_field" not in mapped.column_names
|
||||
|
||||
def test_extra_metadata_columns_removed(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""Datasets with many extra metadata columns should all be cleaned up."""
|
||||
transform_fn, map_kwargs = transform_fn_and_kwargs
|
||||
map_kwargs = dict(map_kwargs)
|
||||
|
||||
ds = Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"input": "write hello world",
|
||||
"output": "print hello",
|
||||
"id": "abc 123",
|
||||
"domain": "python",
|
||||
"generation_algorithm": "sampling",
|
||||
"llm_judgement": "good",
|
||||
"unit_tests": "test",
|
||||
"tests_execution_status": "ok",
|
||||
"average_test_score": 0.95,
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
|
||||
assert len(map_kwargs["remove_columns"]) == 9
|
||||
|
||||
from functools import partial
|
||||
|
||||
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
|
||||
|
||||
expected_columns = {"input_ids", "attention_mask", "labels", "prompt_length"}
|
||||
assert set(mapped.column_names) == expected_columns
|
||||
|
||||
def test_no_string_columns_remain(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""No string-typed columns should remain (would crash the DataLoader)."""
|
||||
transform_fn, map_kwargs = transform_fn_and_kwargs
|
||||
map_kwargs = dict(map_kwargs)
|
||||
|
||||
ds = Dataset.from_list(
|
||||
[
|
||||
{"input": "test", "output": "test", "notes": "some string metadata"},
|
||||
]
|
||||
)
|
||||
|
||||
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
|
||||
|
||||
from functools import partial
|
||||
|
||||
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
|
||||
|
||||
for col in mapped.column_names:
|
||||
feat = mapped.features[col]
|
||||
assert str(feat) != "string", (
|
||||
f"Column '{col}' is still a string — would crash DataLoader"
|
||||
)
|
||||
|
||||
def test_filter_preserves_explicit_list(self):
|
||||
"""When remove_columns is an explicit list, only existing columns are kept."""
|
||||
ds = Dataset.from_list([{"a": 1, "b": "text", "c": 3}])
|
||||
map_kwargs = {"remove_columns": ["a", "b", "missing_col"]}
|
||||
|
||||
ds_columns = ds.column_names
|
||||
map_kwargs["remove_columns"] = [
|
||||
c for c in map_kwargs["remove_columns"] if c in ds_columns
|
||||
]
|
||||
|
||||
assert map_kwargs["remove_columns"] == ["a", "b"]
|
||||
assert "missing_col" not in map_kwargs["remove_columns"]
|
||||
|
||||
|
||||
class TestMultiTurnSeparators:
    """Verify multi-turn transforms and trainer-side GT reconstruction."""

    def test_multiturn_transform_splits_turns(self):
        """Transform should store first turn as GT and remaining turns separately."""
        from axolotl.prompt_strategies.ebft import load as load_ebft
        from axolotl.utils.dict import DictDefault

        cfg = DictDefault({"sequence_len": 512})
        fn, _ = load_ebft("ebft_chat_multiturn.transform", cfg)

        conversation = [
            {"role": "user", "content": "Q1"},
            {"role": "assistant", "content": "A1"},
            {"role": "user", "content": "Q2"},
            {"role": "assistant", "content": "A2"},
        ]
        out = fn({"messages": conversation})

        # ground_truth is only the first assistant turn
        assert out["ground_truth"] == "A1"
        # remaining_turns carries the rest for trainer-side reconstruction
        assert out["remaining_turns"] == conversation[2:]

    def test_multiturn_gt_reconstruction_via_chat_template(self):
        """Trainer-side GT reconstruction should insert role markers between turns.

        This tests the logic from trainer.py:284-299 that reconstructs multi-turn
        GT using apply_chat_template, ensuring assistant turns are separated by
        role markers rather than concatenated as raw text.
        """
        # NOTE(review): downloads a tokenizer from the Hub — requires network
        # access, unlike the offline `tokenizer` fixture above. Confirm CI allows it.
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
        )

        # Simulate the transform output
        prompt_msgs = [{"role": "user", "content": "Q1"}]
        gt = "A1"
        remaining_turns = [
            {"role": "user", "content": "Q2"},
            {"role": "assistant", "content": "A2"},
        ]

        # --- Reproduce the trainer-side reconstruction (trainer.py:284-299) ---
        prompt_text = tokenizer.apply_chat_template(
            prompt_msgs, tokenize=False, add_generation_prompt=True
        )
        gt_conv = [
            *prompt_msgs,
            {"role": "assistant", "content": gt},
            *remaining_turns,
        ]
        full_gt_text = tokenizer.apply_chat_template(
            gt_conv, tokenize=False, add_generation_prompt=False
        )

        # The full GT text should contain both assistant turns with role markers
        assert "A1" in full_gt_text
        assert "A2" in full_gt_text
        # Raw concatenation "A1A2" should NOT appear — role markers separate them
        assert "A1A2" not in full_gt_text, (
            "GT reconstruction should have role markers between turns, not raw concatenation"
        )
        # The user turn Q2 should appear between A1 and A2
        positions = {turn: full_gt_text.index(turn) for turn in ("A1", "Q2", "A2")}
        assert positions["A1"] < positions["Q2"] < positions["A2"], (
            "Turn order should be A1 -> Q2 -> A2 in rendered GT"
        )
        # The GT should start with the prompt
        assert full_gt_text.startswith(prompt_text), (
            "Full GT should start with the rendered prompt"
        )

    def test_multiturn_gt_reconstruction_fallback_single_turn(self):
        """Single-turn prompts in a multi-turn dataset should use raw concatenation."""
        # NOTE(review): also network-dependent (Hub download) — see above.
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
        )

        prompt_msgs = [{"role": "user", "content": "Q1"}]
        gt = "A1"
        # remaining_turns would be [] for single-turn prompts

        prompt_text = tokenizer.apply_chat_template(
            prompt_msgs, tokenize=False, add_generation_prompt=True
        )

        # With empty remaining_turns, trainer falls through to raw concat
        # (trainer.py:302: gt_texts.append(prompt_text + gt))
        gt_text = prompt_text + gt
        assert gt_text.endswith("A1")
        assert prompt_text in gt_text

Reference in New Issue
Block a user