* EBFT wip * fixes * more fixeS * add missing strided module * ebft fixes for multi-turn * make ebft work with async * add example for ebft w qwen3.5 * fix for split thinking and update yaml for lora over linear attention only * enforce_eager for vllm arg in schema * fix sync weights * fix multi-gpu * handle updated sig for mm * ddp fixes * improve multi-gpu handling, don't calculate logits, adaptive completion length * chore: lint * chore: lint * support completion_mean * Address corereview feedback * clamp min IS ratio * Address PR code review * more fixes identified * address code review * Fix property from rebase conflict
364 lines
14 KiB
Python
364 lines
14 KiB
Python
"""Tests for the EBFT strided structured dataset transform and data loading."""
|
|
|
|
import pytest
|
|
from datasets import Dataset
|
|
from tokenizers import Tokenizer, models, pre_tokenizers
|
|
from transformers import PreTrainedTokenizerFast
|
|
|
|
from axolotl.prompt_strategies.ebft import load as load_ebft
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
@pytest.fixture
|
|
def tokenizer():
|
|
"""Create a simple word-level tokenizer — no network access needed."""
|
|
# Build a tiny vocab covering common test words
|
|
vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3}
|
|
words = (
|
|
"what is 2 + the answer 4 hello world goodbye bye hi short prompt "
|
|
"x write code print test some string metadata noise ok good python "
|
|
"sampling abc 123 def solve return this that"
|
|
).split()
|
|
for w in words:
|
|
if w not in vocab:
|
|
vocab[w] = len(vocab)
|
|
|
|
backend = Tokenizer(models.WordLevel(vocab=vocab, unk_token="[UNK]"))
|
|
backend.pre_tokenizer = pre_tokenizers.Whitespace()
|
|
|
|
tok = PreTrainedTokenizerFast(
|
|
tokenizer_object=backend,
|
|
bos_token="[BOS]",
|
|
eos_token="[EOS]",
|
|
pad_token="[PAD]",
|
|
unk_token="[UNK]",
|
|
)
|
|
return tok
|
|
|
|
|
|
@pytest.fixture
|
|
def cfg():
|
|
return DictDefault({"sequence_len": 64})
|
|
|
|
|
|
@pytest.fixture
|
|
def transform_fn_and_kwargs(cfg):
|
|
result = load_ebft("ebft_strided_structured.transform", cfg)
|
|
assert result is not None, "Failed to load ebft_strided_structured transform"
|
|
transform_fn, map_kwargs = result
|
|
return transform_fn, map_kwargs
|
|
|
|
|
|
class TestEBFTStridedStructuredTransform:
|
|
"""Tests for the dataset transform function itself."""
|
|
|
|
def test_transform_loads(self, transform_fn_and_kwargs):
|
|
transform_fn, map_kwargs = transform_fn_and_kwargs
|
|
assert callable(transform_fn)
|
|
assert "remove_columns" in map_kwargs
|
|
|
|
def test_remove_columns_sentinel(self, transform_fn_and_kwargs):
|
|
"""Transform should signal removal of all original columns."""
|
|
_, map_kwargs = transform_fn_and_kwargs
|
|
assert map_kwargs["remove_columns"] == "__all__"
|
|
|
|
def test_prompt_completion_tokenization(self, transform_fn_and_kwargs, tokenizer):
|
|
"""Prompt tokens get labels=-100, completion tokens get real labels."""
|
|
transform_fn, _ = transform_fn_and_kwargs
|
|
example = {"input": "what is 2 + 2", "output": "the answer is 4"}
|
|
result = transform_fn(example, tokenizer=tokenizer)
|
|
|
|
assert "input_ids" in result
|
|
assert "labels" in result
|
|
assert "attention_mask" in result
|
|
assert "prompt_length" in result
|
|
|
|
prompt_length = result["prompt_length"]
|
|
labels = result["labels"]
|
|
seq_len = len(result["input_ids"])
|
|
|
|
assert seq_len == 64, "Should be padded to sequence_len"
|
|
assert len(labels) == seq_len
|
|
assert prompt_length > 0
|
|
|
|
# Prompt tokens should be masked
|
|
for i in range(prompt_length):
|
|
assert labels[i] == -100, f"Prompt token at {i} should be -100"
|
|
|
|
# At least one completion token should have a real label
|
|
completion_labels = [lab for lab in labels[prompt_length:] if lab != -100]
|
|
assert len(completion_labels) > 0, "Should have non-masked completion tokens"
|
|
|
|
def test_prompt_length_matches_boundary(self, transform_fn_and_kwargs, tokenizer):
|
|
"""prompt_length should be the exact boundary between -100 and real labels."""
|
|
transform_fn, _ = transform_fn_and_kwargs
|
|
example = {"input": "hello world", "output": "goodbye world"}
|
|
result = transform_fn(example, tokenizer=tokenizer)
|
|
|
|
prompt_length = result["prompt_length"]
|
|
labels = result["labels"]
|
|
|
|
assert labels[prompt_length - 1] == -100, "Last prompt token should be masked"
|
|
assert labels[prompt_length] != -100, (
|
|
"First completion token should not be masked"
|
|
)
|
|
|
|
def test_padding_tokens_masked(self, transform_fn_and_kwargs, tokenizer):
|
|
"""Padding tokens should have labels=-100 and attention_mask=0."""
|
|
transform_fn, _ = transform_fn_and_kwargs
|
|
example = {"input": "hi", "output": "bye"}
|
|
result = transform_fn(example, tokenizer=tokenizer)
|
|
|
|
attention_mask = result["attention_mask"]
|
|
labels = result["labels"]
|
|
|
|
real_len = sum(attention_mask)
|
|
assert real_len < 64, "Short example should have padding"
|
|
|
|
for i in range(real_len, 64):
|
|
assert attention_mask[i] == 0, (
|
|
f"Pad position {i} should have attention_mask=0"
|
|
)
|
|
assert labels[i] == -100, f"Pad position {i} should have labels=-100"
|
|
|
|
def test_truncation_long_completion(self, transform_fn_and_kwargs, tokenizer):
|
|
"""Long completions should be truncated to fit sequence_len."""
|
|
transform_fn, _ = transform_fn_and_kwargs
|
|
example = {
|
|
"input": "short prompt",
|
|
"output": "x " * 500,
|
|
}
|
|
result = transform_fn(example, tokenizer=tokenizer)
|
|
assert len(result["input_ids"]) == 64
|
|
|
|
def test_alternative_field_names(self, transform_fn_and_kwargs, tokenizer):
|
|
"""Transform should handle different field name conventions."""
|
|
transform_fn, _ = transform_fn_and_kwargs
|
|
|
|
result = transform_fn(
|
|
{"prompt": "what", "completion": "this"}, tokenizer=tokenizer
|
|
)
|
|
assert result["prompt_length"] > 0
|
|
|
|
result = transform_fn(
|
|
{"question": "what", "answer": "this"}, tokenizer=tokenizer
|
|
)
|
|
assert result["prompt_length"] > 0
|
|
|
|
def test_without_tokenizer_returns_prompt(self, transform_fn_and_kwargs):
|
|
"""Without tokenizer, should return a prompt string."""
|
|
transform_fn, _ = transform_fn_and_kwargs
|
|
result = transform_fn({"input": "hello", "output": "world"})
|
|
assert "prompt" in result
|
|
assert result["prompt"] == "hello"
|
|
|
|
|
|
class TestEBFTColumnRemoval:
|
|
"""Tests for the __all__ column removal logic in the RL data path."""
|
|
|
|
def _filter_remove_columns(self, map_kwargs, dataset):
|
|
"""Reproduce the filtering logic from rl.py _load_split."""
|
|
if "remove_columns" in map_kwargs:
|
|
ds_columns = dataset.column_names
|
|
if map_kwargs["remove_columns"] == "__all__":
|
|
map_kwargs["remove_columns"] = list(ds_columns)
|
|
else:
|
|
map_kwargs["remove_columns"] = [
|
|
c for c in map_kwargs["remove_columns"] if c in ds_columns
|
|
]
|
|
return map_kwargs
|
|
|
|
def test_all_original_columns_removed(self, transform_fn_and_kwargs, tokenizer):
|
|
"""After mapping, only tokenized columns should remain."""
|
|
transform_fn, map_kwargs = transform_fn_and_kwargs
|
|
map_kwargs = dict(map_kwargs) # copy
|
|
|
|
ds = Dataset.from_list(
|
|
[
|
|
{"input": "what is 2 + 2", "output": "4", "extra_field": "noise"},
|
|
]
|
|
)
|
|
|
|
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
|
|
assert "input" in map_kwargs["remove_columns"]
|
|
assert "output" in map_kwargs["remove_columns"]
|
|
assert "extra_field" in map_kwargs["remove_columns"]
|
|
|
|
from functools import partial
|
|
|
|
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
|
|
assert "input_ids" in mapped.column_names
|
|
assert "labels" in mapped.column_names
|
|
assert "prompt_length" in mapped.column_names
|
|
assert "input" not in mapped.column_names
|
|
assert "output" not in mapped.column_names
|
|
assert "extra_field" not in mapped.column_names
|
|
|
|
def test_extra_metadata_columns_removed(self, transform_fn_and_kwargs, tokenizer):
|
|
"""Datasets with many extra metadata columns should all be cleaned up."""
|
|
transform_fn, map_kwargs = transform_fn_and_kwargs
|
|
map_kwargs = dict(map_kwargs)
|
|
|
|
ds = Dataset.from_list(
|
|
[
|
|
{
|
|
"input": "write hello world",
|
|
"output": "print hello",
|
|
"id": "abc 123",
|
|
"domain": "python",
|
|
"generation_algorithm": "sampling",
|
|
"llm_judgement": "good",
|
|
"unit_tests": "test",
|
|
"tests_execution_status": "ok",
|
|
"average_test_score": 0.95,
|
|
},
|
|
]
|
|
)
|
|
|
|
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
|
|
assert len(map_kwargs["remove_columns"]) == 9
|
|
|
|
from functools import partial
|
|
|
|
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
|
|
|
|
expected_columns = {"input_ids", "attention_mask", "labels", "prompt_length"}
|
|
assert set(mapped.column_names) == expected_columns
|
|
|
|
def test_no_string_columns_remain(self, transform_fn_and_kwargs, tokenizer):
|
|
"""No string-typed columns should remain (would crash the DataLoader)."""
|
|
transform_fn, map_kwargs = transform_fn_and_kwargs
|
|
map_kwargs = dict(map_kwargs)
|
|
|
|
ds = Dataset.from_list(
|
|
[
|
|
{"input": "test", "output": "test", "notes": "some string metadata"},
|
|
]
|
|
)
|
|
|
|
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
|
|
|
|
from functools import partial
|
|
|
|
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
|
|
|
|
for col in mapped.column_names:
|
|
feat = mapped.features[col]
|
|
assert str(feat) != "string", (
|
|
f"Column '{col}' is still a string — would crash DataLoader"
|
|
)
|
|
|
|
def test_filter_preserves_explicit_list(self):
|
|
"""When remove_columns is an explicit list, only existing columns are kept."""
|
|
ds = Dataset.from_list([{"a": 1, "b": "text", "c": 3}])
|
|
map_kwargs = {"remove_columns": ["a", "b", "missing_col"]}
|
|
|
|
ds_columns = ds.column_names
|
|
map_kwargs["remove_columns"] = [
|
|
c for c in map_kwargs["remove_columns"] if c in ds_columns
|
|
]
|
|
|
|
assert map_kwargs["remove_columns"] == ["a", "b"]
|
|
assert "missing_col" not in map_kwargs["remove_columns"]
|
|
|
|
|
|
class TestMultiTurnSeparators:
|
|
"""Verify multi-turn transforms and trainer-side GT reconstruction."""
|
|
|
|
def test_multiturn_transform_splits_turns(self):
|
|
"""Transform should store first turn as GT and remaining turns separately."""
|
|
from axolotl.prompt_strategies.ebft import load as load_ebft
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
cfg = DictDefault({"sequence_len": 512})
|
|
fn, _ = load_ebft("ebft_chat_multiturn.transform", cfg)
|
|
out = fn(
|
|
{
|
|
"messages": [
|
|
{"role": "user", "content": "Q1"},
|
|
{"role": "assistant", "content": "A1"},
|
|
{"role": "user", "content": "Q2"},
|
|
{"role": "assistant", "content": "A2"},
|
|
]
|
|
}
|
|
)
|
|
# ground_truth is only the first assistant turn
|
|
assert out["ground_truth"] == "A1"
|
|
# remaining_turns carries the rest for trainer-side reconstruction
|
|
assert out["remaining_turns"] == [
|
|
{"role": "user", "content": "Q2"},
|
|
{"role": "assistant", "content": "A2"},
|
|
]
|
|
|
|
def test_multiturn_gt_reconstruction_via_chat_template(self):
|
|
"""Trainer-side GT reconstruction should insert role markers between turns.
|
|
|
|
This tests the logic from trainer.py:284-299 that reconstructs multi-turn
|
|
GT using apply_chat_template, ensuring assistant turns are separated by
|
|
role markers rather than concatenated as raw text.
|
|
"""
|
|
from transformers import AutoTokenizer
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
"Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
|
|
)
|
|
|
|
# Simulate the transform output
|
|
prompt_msgs = [{"role": "user", "content": "Q1"}]
|
|
gt = "A1"
|
|
remaining_turns = [
|
|
{"role": "user", "content": "Q2"},
|
|
{"role": "assistant", "content": "A2"},
|
|
]
|
|
|
|
# --- Reproduce the trainer-side reconstruction (trainer.py:284-299) ---
|
|
prompt_text = tokenizer.apply_chat_template(
|
|
prompt_msgs, tokenize=False, add_generation_prompt=True
|
|
)
|
|
gt_conv = list(prompt_msgs) + [{"role": "assistant", "content": gt}]
|
|
gt_conv.extend(remaining_turns)
|
|
full_gt_text = tokenizer.apply_chat_template(
|
|
gt_conv, tokenize=False, add_generation_prompt=False
|
|
)
|
|
|
|
# The full GT text should contain both assistant turns with role markers
|
|
assert "A1" in full_gt_text
|
|
assert "A2" in full_gt_text
|
|
# Raw concatenation "A1A2" should NOT appear — role markers separate them
|
|
assert "A1A2" not in full_gt_text, (
|
|
"GT reconstruction should have role markers between turns, not raw concatenation"
|
|
)
|
|
# The user turn Q2 should appear between A1 and A2
|
|
a1_pos = full_gt_text.index("A1")
|
|
a2_pos = full_gt_text.index("A2")
|
|
q2_pos = full_gt_text.index("Q2")
|
|
assert a1_pos < q2_pos < a2_pos, (
|
|
"Turn order should be A1 -> Q2 -> A2 in rendered GT"
|
|
)
|
|
# The GT should start with the prompt
|
|
assert full_gt_text.startswith(prompt_text), (
|
|
"Full GT should start with the rendered prompt"
|
|
)
|
|
|
|
def test_multiturn_gt_reconstruction_fallback_single_turn(self):
|
|
"""Single-turn prompts in a multi-turn dataset should use raw concatenation."""
|
|
from transformers import AutoTokenizer
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
"Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
|
|
)
|
|
|
|
prompt_msgs = [{"role": "user", "content": "Q1"}]
|
|
gt = "A1"
|
|
# remaining_turns would be [] for single-turn prompts
|
|
|
|
prompt_text = tokenizer.apply_chat_template(
|
|
prompt_msgs, tokenize=False, add_generation_prompt=True
|
|
)
|
|
|
|
# With empty remaining_turns, trainer falls through to raw concat
|
|
# (trainer.py:302: gt_texts.append(prompt_text + gt))
|
|
gt_text = prompt_text + gt
|
|
assert gt_text.endswith("A1")
|
|
assert prompt_text in gt_text
|