Files
axolotl/tests/test_ebft_strided_structured.py
Wing Lian c50c4acbf4 EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip

* fixes

* more fixeS

* add missing strided module

* ebft fixes for multi-turn

* make ebft work with async

* add example for ebft w qwen3.5

* fix for split thinking and update yaml for lora over linear attention only

* enforce_eager for vllm arg in schema

* fix sync weights

* fix multi-gpu

* handle updated sig for mm

* ddp fixes

* improve multi-gpu handling, don't calculate logits, adaptive completion length

* chore: lint

* chore: lint

* support completion_mean

* Address corereview feedback

* clamp min IS ratio

* Address PR code review

* more fixes identified

* address code review

* Fix property from rebase conflict
2026-03-24 18:43:46 -04:00

364 lines
14 KiB
Python

"""Tests for the EBFT strided structured dataset transform and data loading."""
import pytest
from datasets import Dataset
from tokenizers import Tokenizer, models, pre_tokenizers
from transformers import PreTrainedTokenizerFast
from axolotl.prompt_strategies.ebft import load as load_ebft
from axolotl.utils.dict import DictDefault
@pytest.fixture
def tokenizer():
"""Create a simple word-level tokenizer — no network access needed."""
# Build a tiny vocab covering common test words
vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3}
words = (
"what is 2 + the answer 4 hello world goodbye bye hi short prompt "
"x write code print test some string metadata noise ok good python "
"sampling abc 123 def solve return this that"
).split()
for w in words:
if w not in vocab:
vocab[w] = len(vocab)
backend = Tokenizer(models.WordLevel(vocab=vocab, unk_token="[UNK]"))
backend.pre_tokenizer = pre_tokenizers.Whitespace()
tok = PreTrainedTokenizerFast(
tokenizer_object=backend,
bos_token="[BOS]",
eos_token="[EOS]",
pad_token="[PAD]",
unk_token="[UNK]",
)
return tok
@pytest.fixture
def cfg():
return DictDefault({"sequence_len": 64})
@pytest.fixture
def transform_fn_and_kwargs(cfg):
result = load_ebft("ebft_strided_structured.transform", cfg)
assert result is not None, "Failed to load ebft_strided_structured transform"
transform_fn, map_kwargs = result
return transform_fn, map_kwargs
class TestEBFTStridedStructuredTransform:
"""Tests for the dataset transform function itself."""
def test_transform_loads(self, transform_fn_and_kwargs):
transform_fn, map_kwargs = transform_fn_and_kwargs
assert callable(transform_fn)
assert "remove_columns" in map_kwargs
def test_remove_columns_sentinel(self, transform_fn_and_kwargs):
"""Transform should signal removal of all original columns."""
_, map_kwargs = transform_fn_and_kwargs
assert map_kwargs["remove_columns"] == "__all__"
def test_prompt_completion_tokenization(self, transform_fn_and_kwargs, tokenizer):
"""Prompt tokens get labels=-100, completion tokens get real labels."""
transform_fn, _ = transform_fn_and_kwargs
example = {"input": "what is 2 + 2", "output": "the answer is 4"}
result = transform_fn(example, tokenizer=tokenizer)
assert "input_ids" in result
assert "labels" in result
assert "attention_mask" in result
assert "prompt_length" in result
prompt_length = result["prompt_length"]
labels = result["labels"]
seq_len = len(result["input_ids"])
assert seq_len == 64, "Should be padded to sequence_len"
assert len(labels) == seq_len
assert prompt_length > 0
# Prompt tokens should be masked
for i in range(prompt_length):
assert labels[i] == -100, f"Prompt token at {i} should be -100"
# At least one completion token should have a real label
completion_labels = [lab for lab in labels[prompt_length:] if lab != -100]
assert len(completion_labels) > 0, "Should have non-masked completion tokens"
def test_prompt_length_matches_boundary(self, transform_fn_and_kwargs, tokenizer):
"""prompt_length should be the exact boundary between -100 and real labels."""
transform_fn, _ = transform_fn_and_kwargs
example = {"input": "hello world", "output": "goodbye world"}
result = transform_fn(example, tokenizer=tokenizer)
prompt_length = result["prompt_length"]
labels = result["labels"]
assert labels[prompt_length - 1] == -100, "Last prompt token should be masked"
assert labels[prompt_length] != -100, (
"First completion token should not be masked"
)
def test_padding_tokens_masked(self, transform_fn_and_kwargs, tokenizer):
"""Padding tokens should have labels=-100 and attention_mask=0."""
transform_fn, _ = transform_fn_and_kwargs
example = {"input": "hi", "output": "bye"}
result = transform_fn(example, tokenizer=tokenizer)
attention_mask = result["attention_mask"]
labels = result["labels"]
real_len = sum(attention_mask)
assert real_len < 64, "Short example should have padding"
for i in range(real_len, 64):
assert attention_mask[i] == 0, (
f"Pad position {i} should have attention_mask=0"
)
assert labels[i] == -100, f"Pad position {i} should have labels=-100"
def test_truncation_long_completion(self, transform_fn_and_kwargs, tokenizer):
"""Long completions should be truncated to fit sequence_len."""
transform_fn, _ = transform_fn_and_kwargs
example = {
"input": "short prompt",
"output": "x " * 500,
}
result = transform_fn(example, tokenizer=tokenizer)
assert len(result["input_ids"]) == 64
def test_alternative_field_names(self, transform_fn_and_kwargs, tokenizer):
"""Transform should handle different field name conventions."""
transform_fn, _ = transform_fn_and_kwargs
result = transform_fn(
{"prompt": "what", "completion": "this"}, tokenizer=tokenizer
)
assert result["prompt_length"] > 0
result = transform_fn(
{"question": "what", "answer": "this"}, tokenizer=tokenizer
)
assert result["prompt_length"] > 0
def test_without_tokenizer_returns_prompt(self, transform_fn_and_kwargs):
"""Without tokenizer, should return a prompt string."""
transform_fn, _ = transform_fn_and_kwargs
result = transform_fn({"input": "hello", "output": "world"})
assert "prompt" in result
assert result["prompt"] == "hello"
class TestEBFTColumnRemoval:
"""Tests for the __all__ column removal logic in the RL data path."""
def _filter_remove_columns(self, map_kwargs, dataset):
"""Reproduce the filtering logic from rl.py _load_split."""
if "remove_columns" in map_kwargs:
ds_columns = dataset.column_names
if map_kwargs["remove_columns"] == "__all__":
map_kwargs["remove_columns"] = list(ds_columns)
else:
map_kwargs["remove_columns"] = [
c for c in map_kwargs["remove_columns"] if c in ds_columns
]
return map_kwargs
def test_all_original_columns_removed(self, transform_fn_and_kwargs, tokenizer):
"""After mapping, only tokenized columns should remain."""
transform_fn, map_kwargs = transform_fn_and_kwargs
map_kwargs = dict(map_kwargs) # copy
ds = Dataset.from_list(
[
{"input": "what is 2 + 2", "output": "4", "extra_field": "noise"},
]
)
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
assert "input" in map_kwargs["remove_columns"]
assert "output" in map_kwargs["remove_columns"]
assert "extra_field" in map_kwargs["remove_columns"]
from functools import partial
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
assert "input_ids" in mapped.column_names
assert "labels" in mapped.column_names
assert "prompt_length" in mapped.column_names
assert "input" not in mapped.column_names
assert "output" not in mapped.column_names
assert "extra_field" not in mapped.column_names
def test_extra_metadata_columns_removed(self, transform_fn_and_kwargs, tokenizer):
"""Datasets with many extra metadata columns should all be cleaned up."""
transform_fn, map_kwargs = transform_fn_and_kwargs
map_kwargs = dict(map_kwargs)
ds = Dataset.from_list(
[
{
"input": "write hello world",
"output": "print hello",
"id": "abc 123",
"domain": "python",
"generation_algorithm": "sampling",
"llm_judgement": "good",
"unit_tests": "test",
"tests_execution_status": "ok",
"average_test_score": 0.95,
},
]
)
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
assert len(map_kwargs["remove_columns"]) == 9
from functools import partial
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
expected_columns = {"input_ids", "attention_mask", "labels", "prompt_length"}
assert set(mapped.column_names) == expected_columns
def test_no_string_columns_remain(self, transform_fn_and_kwargs, tokenizer):
"""No string-typed columns should remain (would crash the DataLoader)."""
transform_fn, map_kwargs = transform_fn_and_kwargs
map_kwargs = dict(map_kwargs)
ds = Dataset.from_list(
[
{"input": "test", "output": "test", "notes": "some string metadata"},
]
)
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
from functools import partial
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
for col in mapped.column_names:
feat = mapped.features[col]
assert str(feat) != "string", (
f"Column '{col}' is still a string — would crash DataLoader"
)
def test_filter_preserves_explicit_list(self):
"""When remove_columns is an explicit list, only existing columns are kept."""
ds = Dataset.from_list([{"a": 1, "b": "text", "c": 3}])
map_kwargs = {"remove_columns": ["a", "b", "missing_col"]}
ds_columns = ds.column_names
map_kwargs["remove_columns"] = [
c for c in map_kwargs["remove_columns"] if c in ds_columns
]
assert map_kwargs["remove_columns"] == ["a", "b"]
assert "missing_col" not in map_kwargs["remove_columns"]
class TestMultiTurnSeparators:
"""Verify multi-turn transforms and trainer-side GT reconstruction."""
def test_multiturn_transform_splits_turns(self):
"""Transform should store first turn as GT and remaining turns separately."""
from axolotl.prompt_strategies.ebft import load as load_ebft
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"sequence_len": 512})
fn, _ = load_ebft("ebft_chat_multiturn.transform", cfg)
out = fn(
{
"messages": [
{"role": "user", "content": "Q1"},
{"role": "assistant", "content": "A1"},
{"role": "user", "content": "Q2"},
{"role": "assistant", "content": "A2"},
]
}
)
# ground_truth is only the first assistant turn
assert out["ground_truth"] == "A1"
# remaining_turns carries the rest for trainer-side reconstruction
assert out["remaining_turns"] == [
{"role": "user", "content": "Q2"},
{"role": "assistant", "content": "A2"},
]
def test_multiturn_gt_reconstruction_via_chat_template(self):
"""Trainer-side GT reconstruction should insert role markers between turns.
This tests the logic from trainer.py:284-299 that reconstructs multi-turn
GT using apply_chat_template, ensuring assistant turns are separated by
role markers rather than concatenated as raw text.
"""
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
)
# Simulate the transform output
prompt_msgs = [{"role": "user", "content": "Q1"}]
gt = "A1"
remaining_turns = [
{"role": "user", "content": "Q2"},
{"role": "assistant", "content": "A2"},
]
# --- Reproduce the trainer-side reconstruction (trainer.py:284-299) ---
prompt_text = tokenizer.apply_chat_template(
prompt_msgs, tokenize=False, add_generation_prompt=True
)
gt_conv = list(prompt_msgs) + [{"role": "assistant", "content": gt}]
gt_conv.extend(remaining_turns)
full_gt_text = tokenizer.apply_chat_template(
gt_conv, tokenize=False, add_generation_prompt=False
)
# The full GT text should contain both assistant turns with role markers
assert "A1" in full_gt_text
assert "A2" in full_gt_text
# Raw concatenation "A1A2" should NOT appear — role markers separate them
assert "A1A2" not in full_gt_text, (
"GT reconstruction should have role markers between turns, not raw concatenation"
)
# The user turn Q2 should appear between A1 and A2
a1_pos = full_gt_text.index("A1")
a2_pos = full_gt_text.index("A2")
q2_pos = full_gt_text.index("Q2")
assert a1_pos < q2_pos < a2_pos, (
"Turn order should be A1 -> Q2 -> A2 in rendered GT"
)
# The GT should start with the prompt
assert full_gt_text.startswith(prompt_text), (
"Full GT should start with the rendered prompt"
)
def test_multiturn_gt_reconstruction_fallback_single_turn(self):
"""Single-turn prompts in a multi-turn dataset should use raw concatenation."""
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
)
prompt_msgs = [{"role": "user", "content": "Q1"}]
gt = "A1"
# remaining_turns would be [] for single-turn prompts
prompt_text = tokenizer.apply_chat_template(
prompt_msgs, tokenize=False, add_generation_prompt=True
)
# With empty remaining_turns, trainer falls through to raw concat
# (trainer.py:302: gt_texts.append(prompt_text + gt))
gt_text = prompt_text + gt
assert gt_text.endswith("A1")
assert prompt_text in gt_text