EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip * fixes * more fixeS * add missing strided module * ebft fixes for multi-turn * make ebft work with async * add example for ebft w qwen3.5 * fix for split thinking and update yaml for lora over linear attention only * enforce_eager for vllm arg in schema * fix sync weights * fix multi-gpu * handle updated sig for mm * ddp fixes * improve multi-gpu handling, don't calculate logits, adaptive completion length * chore: lint * chore: lint * support completion_mean * Address corereview feedback * clamp min IS ratio * Address PR code review * more fixes identified * address code review * Fix property from rebase conflict
This commit is contained in:
294
tests/test_ebft_kernels.py
Normal file
294
tests/test_ebft_kernels.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Correctness tests for fused EBFT Triton kernels."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.core.trainers.ebft.kernels import (
|
||||
fused_cosine_similarity,
|
||||
fused_diversity_penalty,
|
||||
fused_log_softmax_gather,
|
||||
fused_reinforce_loss,
|
||||
)
|
||||
|
||||
# Skip all tests if CUDA not available
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="CUDA required for Triton kernels"
|
||||
)
|
||||
|
||||
DEVICE = "cuda"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. fused_log_softmax_gather
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestFusedLogSoftmaxGather:
|
||||
def _reference(self, logits, labels):
|
||||
"""PyTorch reference: log_softmax + gather."""
|
||||
lp = F.log_softmax(logits.float(), dim=-1)
|
||||
return lp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
|
||||
|
||||
def test_basic_correctness(self):
|
||||
B, S, V = 2, 16, 1024
|
||||
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
|
||||
labels = torch.randint(0, V, (B, S), device=DEVICE)
|
||||
|
||||
ref = self._reference(logits, labels)
|
||||
out = fused_log_softmax_gather(logits, labels)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_large_vocab(self):
|
||||
"""Test with realistic vocab size (128K)."""
|
||||
B, S, V = 1, 8, 128256
|
||||
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
|
||||
labels = torch.randint(0, V, (B, S), device=DEVICE)
|
||||
|
||||
ref = self._reference(logits, labels)
|
||||
out = fused_log_softmax_gather(logits, labels)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_fp32_input(self):
|
||||
B, S, V = 2, 8, 512
|
||||
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
|
||||
labels = torch.randint(0, V, (B, S), device=DEVICE)
|
||||
|
||||
ref = self._reference(logits, labels)
|
||||
out = fused_log_softmax_gather(logits, labels)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_output_is_negative(self):
|
||||
"""log_softmax values should always be <= 0."""
|
||||
B, S, V = 4, 32, 2048
|
||||
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
|
||||
labels = torch.randint(0, V, (B, S), device=DEVICE)
|
||||
|
||||
out = fused_log_softmax_gather(logits, labels)
|
||||
assert (out <= 1e-5).all(), "log_softmax values should be <= 0"
|
||||
|
||||
def test_extreme_logits(self):
|
||||
"""Test numerical stability with very large/small logits."""
|
||||
B, S, V = 1, 4, 256
|
||||
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
|
||||
logits[:, 0, :] = 1000.0 # very large
|
||||
logits[:, 1, :] = -1000.0 # very small
|
||||
logits[:, 2, 0] = 1000.0 # one hot-ish
|
||||
labels = torch.zeros(B, S, device=DEVICE, dtype=torch.long)
|
||||
|
||||
ref = self._reference(logits, labels)
|
||||
out = fused_log_softmax_gather(logits, labels)
|
||||
|
||||
assert torch.isfinite(out).all(), "Should handle extreme values"
|
||||
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_2d_input(self):
|
||||
"""Test with pre-flattened (N, V) input."""
|
||||
N, V = 64, 4096
|
||||
logits = torch.randn(N, V, device=DEVICE, dtype=torch.bfloat16)
|
||||
labels = torch.randint(0, V, (N,), device=DEVICE)
|
||||
|
||||
ref = self._reference(logits.unsqueeze(0), labels.unsqueeze(0)).squeeze(0)
|
||||
out = fused_log_softmax_gather(logits, labels)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. fused_reinforce_loss
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestFusedReinforceLoss:
|
||||
def _reference(self, logps, advantages, mask):
|
||||
"""PyTorch reference implementation."""
|
||||
loss_per_token = -logps * advantages
|
||||
return (loss_per_token * mask.float()).sum() / mask.float().sum().clamp(min=1)
|
||||
|
||||
def test_basic_correctness(self):
|
||||
N = 1024
|
||||
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
|
||||
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
|
||||
mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
|
||||
|
||||
ref = self._reference(logps, advs, mask)
|
||||
out = fused_reinforce_loss(logps, advs, mask)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_2d_input(self):
|
||||
"""Test with (B, S) shaped inputs."""
|
||||
B, S = 4, 256
|
||||
logps = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
|
||||
advs = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
|
||||
mask = torch.randint(0, 2, (B, S), device=DEVICE, dtype=torch.bool)
|
||||
|
||||
ref = self._reference(logps, advs, mask)
|
||||
out = fused_reinforce_loss(logps, advs, mask)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_all_masked(self):
|
||||
"""All-zero mask should return 0."""
|
||||
N = 512
|
||||
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
|
||||
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
|
||||
mask = torch.zeros(N, device=DEVICE, dtype=torch.bool)
|
||||
|
||||
out = fused_reinforce_loss(logps, advs, mask)
|
||||
assert out.item() == 0.0
|
||||
|
||||
def test_all_unmasked(self):
|
||||
N = 512
|
||||
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
|
||||
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
|
||||
mask = torch.ones(N, device=DEVICE, dtype=torch.bool)
|
||||
|
||||
ref = self._reference(logps, advs, mask)
|
||||
out = fused_reinforce_loss(logps, advs, mask)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_large(self):
|
||||
"""Test with realistic size (4 * 3000 tokens)."""
|
||||
N = 12000
|
||||
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
|
||||
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
|
||||
mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
|
||||
|
||||
ref = self._reference(logps, advs, mask)
|
||||
out = fused_reinforce_loss(logps, advs, mask)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. fused_cosine_similarity
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestFusedCosineSimilarity:
|
||||
def test_basic_correctness(self):
|
||||
N, D = 64, 256
|
||||
a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
|
||||
b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
|
||||
|
||||
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
|
||||
out = fused_cosine_similarity(a, b)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_batched(self):
|
||||
"""Test with (B, N, NB, D) shaped input."""
|
||||
B, N, NB, D = 2, 4, 16, 512
|
||||
a = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
|
||||
b = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
|
||||
|
||||
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
|
||||
out = fused_cosine_similarity(a, b)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_identical_vectors(self):
|
||||
"""Identical vectors should give similarity = 1."""
|
||||
N, D = 16, 128
|
||||
a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
|
||||
|
||||
out = fused_cosine_similarity(a, a)
|
||||
torch.testing.assert_close(
|
||||
out,
|
||||
torch.ones(N, device=DEVICE, dtype=torch.float32),
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
)
|
||||
|
||||
def test_orthogonal_vectors(self):
|
||||
"""Orthogonal vectors should give similarity = 0."""
|
||||
D = 128
|
||||
a = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
|
||||
b = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
|
||||
a[0, 0] = 1.0
|
||||
b[0, 1] = 1.0
|
||||
|
||||
out = fused_cosine_similarity(a, b)
|
||||
assert abs(out.item()) < 1e-5
|
||||
|
||||
def test_opposite_vectors(self):
|
||||
"""Opposite vectors should give similarity = -1."""
|
||||
N, D = 8, 64
|
||||
a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
|
||||
out = fused_cosine_similarity(a, -a)
|
||||
torch.testing.assert_close(
|
||||
out,
|
||||
-torch.ones(N, device=DEVICE, dtype=torch.float32),
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
)
|
||||
|
||||
def test_large_dimension(self):
|
||||
"""Test with large feature dimension (multi-layer concatenated features)."""
|
||||
N, D = 32, 4608 # 3 layers * 1536 hidden
|
||||
a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
|
||||
b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
|
||||
|
||||
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
|
||||
out = fused_cosine_similarity(a, b)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. fused_diversity_penalty
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestFusedDiversityPenalty:
|
||||
def _reference(self, embeddings):
|
||||
"""PyTorch reference: bmm + mask diagonal + mean."""
|
||||
B, N, D = embeddings.shape
|
||||
sims = torch.bmm(embeddings.float(), embeddings.float().transpose(1, 2))
|
||||
eye = torch.eye(N, device=embeddings.device, dtype=torch.bool)
|
||||
sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
|
||||
return sims.sum(dim=-1) / (N - 1)
|
||||
|
||||
def test_basic_correctness(self):
|
||||
B, N, D = 4, 4, 256
|
||||
emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
|
||||
|
||||
ref = self._reference(emb)
|
||||
out = fused_diversity_penalty(emb)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_two_samples(self):
|
||||
"""With n=2, diversity = dot(a, b) for each."""
|
||||
B, D = 3, 128
|
||||
emb = torch.randn(B, 2, D, device=DEVICE, dtype=torch.float32)
|
||||
|
||||
ref = self._reference(emb)
|
||||
out = fused_diversity_penalty(emb)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_identical_samples(self):
|
||||
"""All identical samples should give max diversity."""
|
||||
B, N, D = 2, 4, 64
|
||||
vec = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
|
||||
emb = vec.expand(B, N, D).contiguous()
|
||||
|
||||
out = fused_diversity_penalty(emb)
|
||||
# Should be ||vec||^2 for each (self-excluded mean of identical dot products)
|
||||
expected = (vec.squeeze(1) ** 2).sum(dim=-1, keepdim=True).expand(B, N)
|
||||
torch.testing.assert_close(out, expected, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_large(self):
|
||||
"""Test with realistic EBFT dimensions."""
|
||||
B, N, D = 1, 4, 4608 # multi-layer features
|
||||
emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
|
||||
|
||||
ref = self._reference(emb)
|
||||
out = fused_diversity_penalty(emb)
|
||||
|
||||
torch.testing.assert_close(out, ref, atol=5e-2, rtol=5e-2)
|
||||
|
||||
def test_single_sample_returns_zeros(self):
|
||||
"""N=1 should return zeros (no pairs), not garbage from uninitialized memory."""
|
||||
B, D = 3, 128
|
||||
emb = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
|
||||
out = fused_diversity_penalty(emb)
|
||||
assert (out == 0).all(), "N=1 diversity should be exactly zero"
|
||||
363
tests/test_ebft_strided_structured.py
Normal file
363
tests/test_ebft_strided_structured.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Tests for the EBFT strided structured dataset transform and data loading."""
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from tokenizers import Tokenizer, models, pre_tokenizers
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from axolotl.prompt_strategies.ebft import load as load_ebft
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer():
|
||||
"""Create a simple word-level tokenizer — no network access needed."""
|
||||
# Build a tiny vocab covering common test words
|
||||
vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3}
|
||||
words = (
|
||||
"what is 2 + the answer 4 hello world goodbye bye hi short prompt "
|
||||
"x write code print test some string metadata noise ok good python "
|
||||
"sampling abc 123 def solve return this that"
|
||||
).split()
|
||||
for w in words:
|
||||
if w not in vocab:
|
||||
vocab[w] = len(vocab)
|
||||
|
||||
backend = Tokenizer(models.WordLevel(vocab=vocab, unk_token="[UNK]"))
|
||||
backend.pre_tokenizer = pre_tokenizers.Whitespace()
|
||||
|
||||
tok = PreTrainedTokenizerFast(
|
||||
tokenizer_object=backend,
|
||||
bos_token="[BOS]",
|
||||
eos_token="[EOS]",
|
||||
pad_token="[PAD]",
|
||||
unk_token="[UNK]",
|
||||
)
|
||||
return tok
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cfg():
|
||||
return DictDefault({"sequence_len": 64})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transform_fn_and_kwargs(cfg):
|
||||
result = load_ebft("ebft_strided_structured.transform", cfg)
|
||||
assert result is not None, "Failed to load ebft_strided_structured transform"
|
||||
transform_fn, map_kwargs = result
|
||||
return transform_fn, map_kwargs
|
||||
|
||||
|
||||
class TestEBFTStridedStructuredTransform:
|
||||
"""Tests for the dataset transform function itself."""
|
||||
|
||||
def test_transform_loads(self, transform_fn_and_kwargs):
|
||||
transform_fn, map_kwargs = transform_fn_and_kwargs
|
||||
assert callable(transform_fn)
|
||||
assert "remove_columns" in map_kwargs
|
||||
|
||||
def test_remove_columns_sentinel(self, transform_fn_and_kwargs):
|
||||
"""Transform should signal removal of all original columns."""
|
||||
_, map_kwargs = transform_fn_and_kwargs
|
||||
assert map_kwargs["remove_columns"] == "__all__"
|
||||
|
||||
def test_prompt_completion_tokenization(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""Prompt tokens get labels=-100, completion tokens get real labels."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
example = {"input": "what is 2 + 2", "output": "the answer is 4"}
|
||||
result = transform_fn(example, tokenizer=tokenizer)
|
||||
|
||||
assert "input_ids" in result
|
||||
assert "labels" in result
|
||||
assert "attention_mask" in result
|
||||
assert "prompt_length" in result
|
||||
|
||||
prompt_length = result["prompt_length"]
|
||||
labels = result["labels"]
|
||||
seq_len = len(result["input_ids"])
|
||||
|
||||
assert seq_len == 64, "Should be padded to sequence_len"
|
||||
assert len(labels) == seq_len
|
||||
assert prompt_length > 0
|
||||
|
||||
# Prompt tokens should be masked
|
||||
for i in range(prompt_length):
|
||||
assert labels[i] == -100, f"Prompt token at {i} should be -100"
|
||||
|
||||
# At least one completion token should have a real label
|
||||
completion_labels = [lab for lab in labels[prompt_length:] if lab != -100]
|
||||
assert len(completion_labels) > 0, "Should have non-masked completion tokens"
|
||||
|
||||
def test_prompt_length_matches_boundary(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""prompt_length should be the exact boundary between -100 and real labels."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
example = {"input": "hello world", "output": "goodbye world"}
|
||||
result = transform_fn(example, tokenizer=tokenizer)
|
||||
|
||||
prompt_length = result["prompt_length"]
|
||||
labels = result["labels"]
|
||||
|
||||
assert labels[prompt_length - 1] == -100, "Last prompt token should be masked"
|
||||
assert labels[prompt_length] != -100, (
|
||||
"First completion token should not be masked"
|
||||
)
|
||||
|
||||
def test_padding_tokens_masked(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""Padding tokens should have labels=-100 and attention_mask=0."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
example = {"input": "hi", "output": "bye"}
|
||||
result = transform_fn(example, tokenizer=tokenizer)
|
||||
|
||||
attention_mask = result["attention_mask"]
|
||||
labels = result["labels"]
|
||||
|
||||
real_len = sum(attention_mask)
|
||||
assert real_len < 64, "Short example should have padding"
|
||||
|
||||
for i in range(real_len, 64):
|
||||
assert attention_mask[i] == 0, (
|
||||
f"Pad position {i} should have attention_mask=0"
|
||||
)
|
||||
assert labels[i] == -100, f"Pad position {i} should have labels=-100"
|
||||
|
||||
def test_truncation_long_completion(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""Long completions should be truncated to fit sequence_len."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
example = {
|
||||
"input": "short prompt",
|
||||
"output": "x " * 500,
|
||||
}
|
||||
result = transform_fn(example, tokenizer=tokenizer)
|
||||
assert len(result["input_ids"]) == 64
|
||||
|
||||
def test_alternative_field_names(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""Transform should handle different field name conventions."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
|
||||
result = transform_fn(
|
||||
{"prompt": "what", "completion": "this"}, tokenizer=tokenizer
|
||||
)
|
||||
assert result["prompt_length"] > 0
|
||||
|
||||
result = transform_fn(
|
||||
{"question": "what", "answer": "this"}, tokenizer=tokenizer
|
||||
)
|
||||
assert result["prompt_length"] > 0
|
||||
|
||||
def test_without_tokenizer_returns_prompt(self, transform_fn_and_kwargs):
|
||||
"""Without tokenizer, should return a prompt string."""
|
||||
transform_fn, _ = transform_fn_and_kwargs
|
||||
result = transform_fn({"input": "hello", "output": "world"})
|
||||
assert "prompt" in result
|
||||
assert result["prompt"] == "hello"
|
||||
|
||||
|
||||
class TestEBFTColumnRemoval:
|
||||
"""Tests for the __all__ column removal logic in the RL data path."""
|
||||
|
||||
def _filter_remove_columns(self, map_kwargs, dataset):
|
||||
"""Reproduce the filtering logic from rl.py _load_split."""
|
||||
if "remove_columns" in map_kwargs:
|
||||
ds_columns = dataset.column_names
|
||||
if map_kwargs["remove_columns"] == "__all__":
|
||||
map_kwargs["remove_columns"] = list(ds_columns)
|
||||
else:
|
||||
map_kwargs["remove_columns"] = [
|
||||
c for c in map_kwargs["remove_columns"] if c in ds_columns
|
||||
]
|
||||
return map_kwargs
|
||||
|
||||
def test_all_original_columns_removed(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""After mapping, only tokenized columns should remain."""
|
||||
transform_fn, map_kwargs = transform_fn_and_kwargs
|
||||
map_kwargs = dict(map_kwargs) # copy
|
||||
|
||||
ds = Dataset.from_list(
|
||||
[
|
||||
{"input": "what is 2 + 2", "output": "4", "extra_field": "noise"},
|
||||
]
|
||||
)
|
||||
|
||||
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
|
||||
assert "input" in map_kwargs["remove_columns"]
|
||||
assert "output" in map_kwargs["remove_columns"]
|
||||
assert "extra_field" in map_kwargs["remove_columns"]
|
||||
|
||||
from functools import partial
|
||||
|
||||
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
|
||||
assert "input_ids" in mapped.column_names
|
||||
assert "labels" in mapped.column_names
|
||||
assert "prompt_length" in mapped.column_names
|
||||
assert "input" not in mapped.column_names
|
||||
assert "output" not in mapped.column_names
|
||||
assert "extra_field" not in mapped.column_names
|
||||
|
||||
def test_extra_metadata_columns_removed(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""Datasets with many extra metadata columns should all be cleaned up."""
|
||||
transform_fn, map_kwargs = transform_fn_and_kwargs
|
||||
map_kwargs = dict(map_kwargs)
|
||||
|
||||
ds = Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"input": "write hello world",
|
||||
"output": "print hello",
|
||||
"id": "abc 123",
|
||||
"domain": "python",
|
||||
"generation_algorithm": "sampling",
|
||||
"llm_judgement": "good",
|
||||
"unit_tests": "test",
|
||||
"tests_execution_status": "ok",
|
||||
"average_test_score": 0.95,
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
|
||||
assert len(map_kwargs["remove_columns"]) == 9
|
||||
|
||||
from functools import partial
|
||||
|
||||
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
|
||||
|
||||
expected_columns = {"input_ids", "attention_mask", "labels", "prompt_length"}
|
||||
assert set(mapped.column_names) == expected_columns
|
||||
|
||||
def test_no_string_columns_remain(self, transform_fn_and_kwargs, tokenizer):
|
||||
"""No string-typed columns should remain (would crash the DataLoader)."""
|
||||
transform_fn, map_kwargs = transform_fn_and_kwargs
|
||||
map_kwargs = dict(map_kwargs)
|
||||
|
||||
ds = Dataset.from_list(
|
||||
[
|
||||
{"input": "test", "output": "test", "notes": "some string metadata"},
|
||||
]
|
||||
)
|
||||
|
||||
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
|
||||
|
||||
from functools import partial
|
||||
|
||||
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
|
||||
|
||||
for col in mapped.column_names:
|
||||
feat = mapped.features[col]
|
||||
assert str(feat) != "string", (
|
||||
f"Column '{col}' is still a string — would crash DataLoader"
|
||||
)
|
||||
|
||||
def test_filter_preserves_explicit_list(self):
|
||||
"""When remove_columns is an explicit list, only existing columns are kept."""
|
||||
ds = Dataset.from_list([{"a": 1, "b": "text", "c": 3}])
|
||||
map_kwargs = {"remove_columns": ["a", "b", "missing_col"]}
|
||||
|
||||
ds_columns = ds.column_names
|
||||
map_kwargs["remove_columns"] = [
|
||||
c for c in map_kwargs["remove_columns"] if c in ds_columns
|
||||
]
|
||||
|
||||
assert map_kwargs["remove_columns"] == ["a", "b"]
|
||||
assert "missing_col" not in map_kwargs["remove_columns"]
|
||||
|
||||
|
||||
class TestMultiTurnSeparators:
|
||||
"""Verify multi-turn transforms and trainer-side GT reconstruction."""
|
||||
|
||||
def test_multiturn_transform_splits_turns(self):
|
||||
"""Transform should store first turn as GT and remaining turns separately."""
|
||||
from axolotl.prompt_strategies.ebft import load as load_ebft
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
cfg = DictDefault({"sequence_len": 512})
|
||||
fn, _ = load_ebft("ebft_chat_multiturn.transform", cfg)
|
||||
out = fn(
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Q1"},
|
||||
{"role": "assistant", "content": "A1"},
|
||||
{"role": "user", "content": "Q2"},
|
||||
{"role": "assistant", "content": "A2"},
|
||||
]
|
||||
}
|
||||
)
|
||||
# ground_truth is only the first assistant turn
|
||||
assert out["ground_truth"] == "A1"
|
||||
# remaining_turns carries the rest for trainer-side reconstruction
|
||||
assert out["remaining_turns"] == [
|
||||
{"role": "user", "content": "Q2"},
|
||||
{"role": "assistant", "content": "A2"},
|
||||
]
|
||||
|
||||
def test_multiturn_gt_reconstruction_via_chat_template(self):
|
||||
"""Trainer-side GT reconstruction should insert role markers between turns.
|
||||
|
||||
This tests the logic from trainer.py:284-299 that reconstructs multi-turn
|
||||
GT using apply_chat_template, ensuring assistant turns are separated by
|
||||
role markers rather than concatenated as raw text.
|
||||
"""
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
|
||||
)
|
||||
|
||||
# Simulate the transform output
|
||||
prompt_msgs = [{"role": "user", "content": "Q1"}]
|
||||
gt = "A1"
|
||||
remaining_turns = [
|
||||
{"role": "user", "content": "Q2"},
|
||||
{"role": "assistant", "content": "A2"},
|
||||
]
|
||||
|
||||
# --- Reproduce the trainer-side reconstruction (trainer.py:284-299) ---
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
prompt_msgs, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
gt_conv = list(prompt_msgs) + [{"role": "assistant", "content": gt}]
|
||||
gt_conv.extend(remaining_turns)
|
||||
full_gt_text = tokenizer.apply_chat_template(
|
||||
gt_conv, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
|
||||
# The full GT text should contain both assistant turns with role markers
|
||||
assert "A1" in full_gt_text
|
||||
assert "A2" in full_gt_text
|
||||
# Raw concatenation "A1A2" should NOT appear — role markers separate them
|
||||
assert "A1A2" not in full_gt_text, (
|
||||
"GT reconstruction should have role markers between turns, not raw concatenation"
|
||||
)
|
||||
# The user turn Q2 should appear between A1 and A2
|
||||
a1_pos = full_gt_text.index("A1")
|
||||
a2_pos = full_gt_text.index("A2")
|
||||
q2_pos = full_gt_text.index("Q2")
|
||||
assert a1_pos < q2_pos < a2_pos, (
|
||||
"Turn order should be A1 -> Q2 -> A2 in rendered GT"
|
||||
)
|
||||
# The GT should start with the prompt
|
||||
assert full_gt_text.startswith(prompt_text), (
|
||||
"Full GT should start with the rendered prompt"
|
||||
)
|
||||
|
||||
def test_multiturn_gt_reconstruction_fallback_single_turn(self):
|
||||
"""Single-turn prompts in a multi-turn dataset should use raw concatenation."""
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
|
||||
)
|
||||
|
||||
prompt_msgs = [{"role": "user", "content": "Q1"}]
|
||||
gt = "A1"
|
||||
# remaining_turns would be [] for single-turn prompts
|
||||
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
prompt_msgs, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# With empty remaining_turns, trainer falls through to raw concat
|
||||
# (trainer.py:302: gt_texts.append(prompt_text + gt))
|
||||
gt_text = prompt_text + gt
|
||||
assert gt_text.endswith("A1")
|
||||
assert prompt_text in gt_text
|
||||
158
tests/test_http_weight_sync.py
Normal file
158
tests/test_http_weight_sync.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Tests for HTTP weight sync serialization round-trip (bf16/fp16/fp32).
|
||||
|
||||
Exercises the encode/decode helpers in axolotl.utils.weight_serde that handle
|
||||
the three-stage weight transfer: trainer → serve endpoint → vLLM worker.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.utils.weight_serde import (
|
||||
decode_from_http,
|
||||
decode_from_ipc,
|
||||
encode_for_http,
|
||||
encode_for_ipc,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stage 1: trainer → serve endpoint (HTTP with base64)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHttpEncodeRoundTrip:
|
||||
"""Test encode_for_http / decode_from_http."""
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
def test_round_trip_dtype(self, dtype):
|
||||
original = torch.randn(32, 64, dtype=dtype)
|
||||
entry = encode_for_http("layer.weight", original)
|
||||
name, decoded = decode_from_http(entry)
|
||||
|
||||
assert name == "layer.weight"
|
||||
assert decoded.dtype == dtype
|
||||
assert decoded.shape == original.shape
|
||||
if dtype == torch.bfloat16:
|
||||
# bf16→fp16→bf16 loses some precision
|
||||
torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
|
||||
else:
|
||||
torch.testing.assert_close(decoded, original, atol=0, rtol=0)
|
||||
|
||||
def test_bfloat16_wire_format_is_fp16(self):
|
||||
"""bf16 tensors should be sent as fp16 on the wire."""
|
||||
import base64
|
||||
|
||||
original = torch.randn(8, 16, dtype=torch.bfloat16)
|
||||
entry = encode_for_http("w", original)
|
||||
raw = base64.b64decode(entry["data"])
|
||||
# 8*16 elements * 2 bytes/elem (fp16) = 256 bytes
|
||||
assert len(raw) == 8 * 16 * 2
|
||||
# dtype field should preserve original dtype for reconstruction
|
||||
assert entry["dtype"] == "torch.bfloat16"
|
||||
|
||||
def test_multidimensional_shapes(self):
|
||||
for shape in [(128,), (4, 32), (2, 3, 16), (2, 2, 2, 8)]:
|
||||
original = torch.randn(*shape, dtype=torch.bfloat16)
|
||||
entry = encode_for_http("w", original)
|
||||
_, decoded = decode_from_http(entry)
|
||||
assert decoded.shape == original.shape
|
||||
assert decoded.dtype == torch.bfloat16
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stage 2: serve endpoint → vLLM worker (IPC with raw bytes)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIpcEncodeRoundTrip:
|
||||
"""Test encode_for_ipc / decode_from_ipc."""
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
def test_round_trip_dtype(self, dtype):
|
||||
original = torch.randn(32, 64, dtype=dtype)
|
||||
entry = encode_for_ipc("layer.weight", original)
|
||||
name, decoded = decode_from_ipc(entry)
|
||||
|
||||
assert name == "layer.weight"
|
||||
assert decoded.dtype == dtype
|
||||
assert decoded.shape == original.shape
|
||||
if dtype == torch.bfloat16:
|
||||
torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
|
||||
else:
|
||||
torch.testing.assert_close(decoded, original, atol=0, rtol=0)
|
||||
|
||||
def test_bfloat16_ipc_wire_is_fp16(self):
|
||||
"""bf16 tensors should be serialized as fp16 bytes in IPC."""
|
||||
original = torch.randn(4, 8, dtype=torch.bfloat16)
|
||||
entry = encode_for_ipc("w", original)
|
||||
assert entry["dtype"] == "float16"
|
||||
assert entry["target_dtype"] == "bfloat16"
|
||||
assert len(entry["data"]) == 4 * 8 * 2 # fp16 bytes
|
||||
|
||||
def test_fp32_has_no_target_dtype_mismatch(self):
|
||||
original = torch.randn(4, 8, dtype=torch.float32)
|
||||
entry = encode_for_ipc("w", original)
|
||||
assert entry["dtype"] == "float32"
|
||||
assert entry["target_dtype"] == "float32"
|
||||
|
||||
def test_worker_handles_missing_target_dtype(self):
|
||||
"""Backward compat: older serve code may not send target_dtype."""
|
||||
entry = {
|
||||
"name": "w",
|
||||
"data": torch.randn(4, 8, dtype=torch.float32).numpy().tobytes(),
|
||||
"dtype": "float32",
|
||||
"shape": [4, 8],
|
||||
# no target_dtype key
|
||||
}
|
||||
name, decoded = decode_from_ipc(entry)
|
||||
assert decoded.dtype == torch.float32
|
||||
assert decoded.shape == (4, 8)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full pipeline: trainer → serve → worker
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFullPipelineRoundTrip:
|
||||
"""End-to-end: trainer → serve → worker."""
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
def test_three_stage_round_trip(self, dtype):
|
||||
"""Tensor survives trainer→serve→worker with correct dtype and values."""
|
||||
original = torch.randn(16, 32, dtype=dtype)
|
||||
|
||||
# Stage 1: trainer encodes for HTTP
|
||||
http_entry = encode_for_http("model.layers.0.weight", original)
|
||||
|
||||
# Stage 2: serve decodes HTTP, re-encodes for IPC
|
||||
name, at_serve = decode_from_http(http_entry)
|
||||
ipc_entry = encode_for_ipc(name, at_serve)
|
||||
|
||||
# Stage 3: worker decodes IPC
|
||||
_, at_worker = decode_from_ipc(ipc_entry)
|
||||
|
||||
assert at_worker.dtype == dtype
|
||||
assert at_worker.shape == original.shape
|
||||
if dtype == torch.bfloat16:
|
||||
# Two bf16→fp16→bf16 hops compound precision loss slightly
|
||||
torch.testing.assert_close(at_worker, original, atol=2e-2, rtol=2e-2)
|
||||
else:
|
||||
torch.testing.assert_close(at_worker, original, atol=0, rtol=0)
|
||||
|
||||
def test_bfloat16_precision_loss_is_bounded(self):
|
||||
"""bf16→fp16→bf16 round-trip error should be small."""
|
||||
original = torch.randn(256, 256, dtype=torch.bfloat16)
|
||||
http_entry = encode_for_http("w", original)
|
||||
_, at_serve = decode_from_http(http_entry)
|
||||
ipc_entry = encode_for_ipc("w", at_serve)
|
||||
_, at_worker = decode_from_ipc(ipc_entry)
|
||||
|
||||
max_err = (at_worker.float() - original.float()).abs().max().item()
|
||||
# bf16 has ~8e-3 precision, fp16 has ~1e-3; round-trip error bounded
|
||||
assert max_err < 0.05, f"Max error {max_err} exceeds bound"
|
||||
|
||||
def test_bfloat16_numpy_would_crash_without_fix(self):
|
||||
"""Verify that calling .numpy() on bf16 raises, confirming the fix is needed."""
|
||||
t = torch.randn(4, 4, dtype=torch.bfloat16)
|
||||
with pytest.raises((RuntimeError, TypeError)):
|
||||
t.numpy()
|
||||
Reference in New Issue
Block a user