EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]

* EBFT wip

* fixes

* more fixeS

* add missing strided module

* ebft fixes for multi-turn

* make ebft work with async

* add example for ebft w qwen3.5

* fix for split thinking and update yaml for lora over linear attention only

* enforce_eager for vllm arg in schema

* fix sync weights

* fix multi-gpu

* handle updated sig for mm

* ddp fixes

* improve multi-gpu handling, don't calculate logits, adaptive completion length

* chore: lint

* chore: lint

* support completion_mean

* Address corereview feedback

* clamp min IS ratio

* Address PR code review

* more fixes identified

* address code review

* Fix property from rebase conflict
This commit is contained in:
Wing Lian
2026-03-24 18:43:46 -04:00
committed by GitHub
parent e9883c91d4
commit c50c4acbf4
48 changed files with 5885 additions and 168 deletions

294
tests/test_ebft_kernels.py Normal file
View File

@@ -0,0 +1,294 @@
"""Correctness tests for fused EBFT Triton kernels."""
import pytest
import torch
import torch.nn.functional as F
from axolotl.core.trainers.ebft.kernels import (
fused_cosine_similarity,
fused_diversity_penalty,
fused_log_softmax_gather,
fused_reinforce_loss,
)
# Skip all tests if CUDA not available
pytestmark = pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA required for Triton kernels"
)
DEVICE = "cuda"
# ---------------------------------------------------------------------------
# 1. fused_log_softmax_gather
# ---------------------------------------------------------------------------
class TestFusedLogSoftmaxGather:
def _reference(self, logits, labels):
"""PyTorch reference: log_softmax + gather."""
lp = F.log_softmax(logits.float(), dim=-1)
return lp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
def test_basic_correctness(self):
B, S, V = 2, 16, 1024
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, S), device=DEVICE)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
def test_large_vocab(self):
"""Test with realistic vocab size (128K)."""
B, S, V = 1, 8, 128256
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, S), device=DEVICE)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
def test_fp32_input(self):
B, S, V = 2, 8, 512
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
labels = torch.randint(0, V, (B, S), device=DEVICE)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)
def test_output_is_negative(self):
"""log_softmax values should always be <= 0."""
B, S, V = 4, 32, 2048
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, S), device=DEVICE)
out = fused_log_softmax_gather(logits, labels)
assert (out <= 1e-5).all(), "log_softmax values should be <= 0"
def test_extreme_logits(self):
"""Test numerical stability with very large/small logits."""
B, S, V = 1, 4, 256
logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
logits[:, 0, :] = 1000.0 # very large
logits[:, 1, :] = -1000.0 # very small
logits[:, 2, 0] = 1000.0 # one hot-ish
labels = torch.zeros(B, S, device=DEVICE, dtype=torch.long)
ref = self._reference(logits, labels)
out = fused_log_softmax_gather(logits, labels)
assert torch.isfinite(out).all(), "Should handle extreme values"
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_2d_input(self):
"""Test with pre-flattened (N, V) input."""
N, V = 64, 4096
logits = torch.randn(N, V, device=DEVICE, dtype=torch.bfloat16)
labels = torch.randint(0, V, (N,), device=DEVICE)
ref = self._reference(logits.unsqueeze(0), labels.unsqueeze(0)).squeeze(0)
out = fused_log_softmax_gather(logits, labels)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
# ---------------------------------------------------------------------------
# 2. fused_reinforce_loss
# ---------------------------------------------------------------------------
class TestFusedReinforceLoss:
def _reference(self, logps, advantages, mask):
"""PyTorch reference implementation."""
loss_per_token = -logps * advantages
return (loss_per_token * mask.float()).sum() / mask.float().sum().clamp(min=1)
def test_basic_correctness(self):
N = 1024
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_2d_input(self):
"""Test with (B, S) shaped inputs."""
B, S = 4, 256
logps = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
advs = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
mask = torch.randint(0, 2, (B, S), device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_all_masked(self):
"""All-zero mask should return 0."""
N = 512
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.zeros(N, device=DEVICE, dtype=torch.bool)
out = fused_reinforce_loss(logps, advs, mask)
assert out.item() == 0.0
def test_all_unmasked(self):
N = 512
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.ones(N, device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_large(self):
"""Test with realistic size (4 * 3000 tokens)."""
N = 12000
logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
ref = self._reference(logps, advs, mask)
out = fused_reinforce_loss(logps, advs, mask)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
# ---------------------------------------------------------------------------
# 3. fused_cosine_similarity
# ---------------------------------------------------------------------------
class TestFusedCosineSimilarity:
def test_basic_correctness(self):
N, D = 64, 256
a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
out = fused_cosine_similarity(a, b)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
def test_batched(self):
"""Test with (B, N, NB, D) shaped input."""
B, N, NB, D = 2, 4, 16, 512
a = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
b = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
out = fused_cosine_similarity(a, b)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
def test_identical_vectors(self):
"""Identical vectors should give similarity = 1."""
N, D = 16, 128
a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
out = fused_cosine_similarity(a, a)
torch.testing.assert_close(
out,
torch.ones(N, device=DEVICE, dtype=torch.float32),
atol=1e-5,
rtol=1e-5,
)
def test_orthogonal_vectors(self):
"""Orthogonal vectors should give similarity = 0."""
D = 128
a = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
b = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
a[0, 0] = 1.0
b[0, 1] = 1.0
out = fused_cosine_similarity(a, b)
assert abs(out.item()) < 1e-5
def test_opposite_vectors(self):
"""Opposite vectors should give similarity = -1."""
N, D = 8, 64
a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
out = fused_cosine_similarity(a, -a)
torch.testing.assert_close(
out,
-torch.ones(N, device=DEVICE, dtype=torch.float32),
atol=1e-5,
rtol=1e-5,
)
def test_large_dimension(self):
"""Test with large feature dimension (multi-layer concatenated features)."""
N, D = 32, 4608 # 3 layers * 1536 hidden
a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
out = fused_cosine_similarity(a, b)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
# ---------------------------------------------------------------------------
# 4. fused_diversity_penalty
# ---------------------------------------------------------------------------
class TestFusedDiversityPenalty:
def _reference(self, embeddings):
"""PyTorch reference: bmm + mask diagonal + mean."""
B, N, D = embeddings.shape
sims = torch.bmm(embeddings.float(), embeddings.float().transpose(1, 2))
eye = torch.eye(N, device=embeddings.device, dtype=torch.bool)
sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
return sims.sum(dim=-1) / (N - 1)
def test_basic_correctness(self):
B, N, D = 4, 4, 256
emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
ref = self._reference(emb)
out = fused_diversity_penalty(emb)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
def test_two_samples(self):
"""With n=2, diversity = dot(a, b) for each."""
B, D = 3, 128
emb = torch.randn(B, 2, D, device=DEVICE, dtype=torch.float32)
ref = self._reference(emb)
out = fused_diversity_penalty(emb)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
def test_identical_samples(self):
"""All identical samples should give max diversity."""
B, N, D = 2, 4, 64
vec = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
emb = vec.expand(B, N, D).contiguous()
out = fused_diversity_penalty(emb)
# Should be ||vec||^2 for each (self-excluded mean of identical dot products)
expected = (vec.squeeze(1) ** 2).sum(dim=-1, keepdim=True).expand(B, N)
torch.testing.assert_close(out, expected, atol=1e-4, rtol=1e-4)
def test_large(self):
"""Test with realistic EBFT dimensions."""
B, N, D = 1, 4, 4608 # multi-layer features
emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
ref = self._reference(emb)
out = fused_diversity_penalty(emb)
torch.testing.assert_close(out, ref, atol=5e-2, rtol=5e-2)
def test_single_sample_returns_zeros(self):
"""N=1 should return zeros (no pairs), not garbage from uninitialized memory."""
B, D = 3, 128
emb = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
out = fused_diversity_penalty(emb)
assert (out == 0).all(), "N=1 diversity should be exactly zero"

View File

@@ -0,0 +1,363 @@
"""Tests for the EBFT strided structured dataset transform and data loading."""
import pytest
from datasets import Dataset
from tokenizers import Tokenizer, models, pre_tokenizers
from transformers import PreTrainedTokenizerFast
from axolotl.prompt_strategies.ebft import load as load_ebft
from axolotl.utils.dict import DictDefault
@pytest.fixture
def tokenizer():
"""Create a simple word-level tokenizer — no network access needed."""
# Build a tiny vocab covering common test words
vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3}
words = (
"what is 2 + the answer 4 hello world goodbye bye hi short prompt "
"x write code print test some string metadata noise ok good python "
"sampling abc 123 def solve return this that"
).split()
for w in words:
if w not in vocab:
vocab[w] = len(vocab)
backend = Tokenizer(models.WordLevel(vocab=vocab, unk_token="[UNK]"))
backend.pre_tokenizer = pre_tokenizers.Whitespace()
tok = PreTrainedTokenizerFast(
tokenizer_object=backend,
bos_token="[BOS]",
eos_token="[EOS]",
pad_token="[PAD]",
unk_token="[UNK]",
)
return tok
@pytest.fixture
def cfg():
return DictDefault({"sequence_len": 64})
@pytest.fixture
def transform_fn_and_kwargs(cfg):
result = load_ebft("ebft_strided_structured.transform", cfg)
assert result is not None, "Failed to load ebft_strided_structured transform"
transform_fn, map_kwargs = result
return transform_fn, map_kwargs
class TestEBFTStridedStructuredTransform:
"""Tests for the dataset transform function itself."""
def test_transform_loads(self, transform_fn_and_kwargs):
transform_fn, map_kwargs = transform_fn_and_kwargs
assert callable(transform_fn)
assert "remove_columns" in map_kwargs
def test_remove_columns_sentinel(self, transform_fn_and_kwargs):
"""Transform should signal removal of all original columns."""
_, map_kwargs = transform_fn_and_kwargs
assert map_kwargs["remove_columns"] == "__all__"
def test_prompt_completion_tokenization(self, transform_fn_and_kwargs, tokenizer):
"""Prompt tokens get labels=-100, completion tokens get real labels."""
transform_fn, _ = transform_fn_and_kwargs
example = {"input": "what is 2 + 2", "output": "the answer is 4"}
result = transform_fn(example, tokenizer=tokenizer)
assert "input_ids" in result
assert "labels" in result
assert "attention_mask" in result
assert "prompt_length" in result
prompt_length = result["prompt_length"]
labels = result["labels"]
seq_len = len(result["input_ids"])
assert seq_len == 64, "Should be padded to sequence_len"
assert len(labels) == seq_len
assert prompt_length > 0
# Prompt tokens should be masked
for i in range(prompt_length):
assert labels[i] == -100, f"Prompt token at {i} should be -100"
# At least one completion token should have a real label
completion_labels = [lab for lab in labels[prompt_length:] if lab != -100]
assert len(completion_labels) > 0, "Should have non-masked completion tokens"
def test_prompt_length_matches_boundary(self, transform_fn_and_kwargs, tokenizer):
"""prompt_length should be the exact boundary between -100 and real labels."""
transform_fn, _ = transform_fn_and_kwargs
example = {"input": "hello world", "output": "goodbye world"}
result = transform_fn(example, tokenizer=tokenizer)
prompt_length = result["prompt_length"]
labels = result["labels"]
assert labels[prompt_length - 1] == -100, "Last prompt token should be masked"
assert labels[prompt_length] != -100, (
"First completion token should not be masked"
)
def test_padding_tokens_masked(self, transform_fn_and_kwargs, tokenizer):
"""Padding tokens should have labels=-100 and attention_mask=0."""
transform_fn, _ = transform_fn_and_kwargs
example = {"input": "hi", "output": "bye"}
result = transform_fn(example, tokenizer=tokenizer)
attention_mask = result["attention_mask"]
labels = result["labels"]
real_len = sum(attention_mask)
assert real_len < 64, "Short example should have padding"
for i in range(real_len, 64):
assert attention_mask[i] == 0, (
f"Pad position {i} should have attention_mask=0"
)
assert labels[i] == -100, f"Pad position {i} should have labels=-100"
def test_truncation_long_completion(self, transform_fn_and_kwargs, tokenizer):
"""Long completions should be truncated to fit sequence_len."""
transform_fn, _ = transform_fn_and_kwargs
example = {
"input": "short prompt",
"output": "x " * 500,
}
result = transform_fn(example, tokenizer=tokenizer)
assert len(result["input_ids"]) == 64
def test_alternative_field_names(self, transform_fn_and_kwargs, tokenizer):
"""Transform should handle different field name conventions."""
transform_fn, _ = transform_fn_and_kwargs
result = transform_fn(
{"prompt": "what", "completion": "this"}, tokenizer=tokenizer
)
assert result["prompt_length"] > 0
result = transform_fn(
{"question": "what", "answer": "this"}, tokenizer=tokenizer
)
assert result["prompt_length"] > 0
def test_without_tokenizer_returns_prompt(self, transform_fn_and_kwargs):
"""Without tokenizer, should return a prompt string."""
transform_fn, _ = transform_fn_and_kwargs
result = transform_fn({"input": "hello", "output": "world"})
assert "prompt" in result
assert result["prompt"] == "hello"
class TestEBFTColumnRemoval:
"""Tests for the __all__ column removal logic in the RL data path."""
def _filter_remove_columns(self, map_kwargs, dataset):
"""Reproduce the filtering logic from rl.py _load_split."""
if "remove_columns" in map_kwargs:
ds_columns = dataset.column_names
if map_kwargs["remove_columns"] == "__all__":
map_kwargs["remove_columns"] = list(ds_columns)
else:
map_kwargs["remove_columns"] = [
c for c in map_kwargs["remove_columns"] if c in ds_columns
]
return map_kwargs
def test_all_original_columns_removed(self, transform_fn_and_kwargs, tokenizer):
"""After mapping, only tokenized columns should remain."""
transform_fn, map_kwargs = transform_fn_and_kwargs
map_kwargs = dict(map_kwargs) # copy
ds = Dataset.from_list(
[
{"input": "what is 2 + 2", "output": "4", "extra_field": "noise"},
]
)
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
assert "input" in map_kwargs["remove_columns"]
assert "output" in map_kwargs["remove_columns"]
assert "extra_field" in map_kwargs["remove_columns"]
from functools import partial
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
assert "input_ids" in mapped.column_names
assert "labels" in mapped.column_names
assert "prompt_length" in mapped.column_names
assert "input" not in mapped.column_names
assert "output" not in mapped.column_names
assert "extra_field" not in mapped.column_names
def test_extra_metadata_columns_removed(self, transform_fn_and_kwargs, tokenizer):
"""Datasets with many extra metadata columns should all be cleaned up."""
transform_fn, map_kwargs = transform_fn_and_kwargs
map_kwargs = dict(map_kwargs)
ds = Dataset.from_list(
[
{
"input": "write hello world",
"output": "print hello",
"id": "abc 123",
"domain": "python",
"generation_algorithm": "sampling",
"llm_judgement": "good",
"unit_tests": "test",
"tests_execution_status": "ok",
"average_test_score": 0.95,
},
]
)
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
assert len(map_kwargs["remove_columns"]) == 9
from functools import partial
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
expected_columns = {"input_ids", "attention_mask", "labels", "prompt_length"}
assert set(mapped.column_names) == expected_columns
def test_no_string_columns_remain(self, transform_fn_and_kwargs, tokenizer):
"""No string-typed columns should remain (would crash the DataLoader)."""
transform_fn, map_kwargs = transform_fn_and_kwargs
map_kwargs = dict(map_kwargs)
ds = Dataset.from_list(
[
{"input": "test", "output": "test", "notes": "some string metadata"},
]
)
map_kwargs = self._filter_remove_columns(map_kwargs, ds)
from functools import partial
mapped = ds.map(partial(transform_fn, tokenizer=tokenizer), **map_kwargs)
for col in mapped.column_names:
feat = mapped.features[col]
assert str(feat) != "string", (
f"Column '{col}' is still a string — would crash DataLoader"
)
def test_filter_preserves_explicit_list(self):
"""When remove_columns is an explicit list, only existing columns are kept."""
ds = Dataset.from_list([{"a": 1, "b": "text", "c": 3}])
map_kwargs = {"remove_columns": ["a", "b", "missing_col"]}
ds_columns = ds.column_names
map_kwargs["remove_columns"] = [
c for c in map_kwargs["remove_columns"] if c in ds_columns
]
assert map_kwargs["remove_columns"] == ["a", "b"]
assert "missing_col" not in map_kwargs["remove_columns"]
class TestMultiTurnSeparators:
"""Verify multi-turn transforms and trainer-side GT reconstruction."""
def test_multiturn_transform_splits_turns(self):
"""Transform should store first turn as GT and remaining turns separately."""
from axolotl.prompt_strategies.ebft import load as load_ebft
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"sequence_len": 512})
fn, _ = load_ebft("ebft_chat_multiturn.transform", cfg)
out = fn(
{
"messages": [
{"role": "user", "content": "Q1"},
{"role": "assistant", "content": "A1"},
{"role": "user", "content": "Q2"},
{"role": "assistant", "content": "A2"},
]
}
)
# ground_truth is only the first assistant turn
assert out["ground_truth"] == "A1"
# remaining_turns carries the rest for trainer-side reconstruction
assert out["remaining_turns"] == [
{"role": "user", "content": "Q2"},
{"role": "assistant", "content": "A2"},
]
def test_multiturn_gt_reconstruction_via_chat_template(self):
"""Trainer-side GT reconstruction should insert role markers between turns.
This tests the logic from trainer.py:284-299 that reconstructs multi-turn
GT using apply_chat_template, ensuring assistant turns are separated by
role markers rather than concatenated as raw text.
"""
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
)
# Simulate the transform output
prompt_msgs = [{"role": "user", "content": "Q1"}]
gt = "A1"
remaining_turns = [
{"role": "user", "content": "Q2"},
{"role": "assistant", "content": "A2"},
]
# --- Reproduce the trainer-side reconstruction (trainer.py:284-299) ---
prompt_text = tokenizer.apply_chat_template(
prompt_msgs, tokenize=False, add_generation_prompt=True
)
gt_conv = list(prompt_msgs) + [{"role": "assistant", "content": gt}]
gt_conv.extend(remaining_turns)
full_gt_text = tokenizer.apply_chat_template(
gt_conv, tokenize=False, add_generation_prompt=False
)
# The full GT text should contain both assistant turns with role markers
assert "A1" in full_gt_text
assert "A2" in full_gt_text
# Raw concatenation "A1A2" should NOT appear — role markers separate them
assert "A1A2" not in full_gt_text, (
"GT reconstruction should have role markers between turns, not raw concatenation"
)
# The user turn Q2 should appear between A1 and A2
a1_pos = full_gt_text.index("A1")
a2_pos = full_gt_text.index("A2")
q2_pos = full_gt_text.index("Q2")
assert a1_pos < q2_pos < a2_pos, (
"Turn order should be A1 -> Q2 -> A2 in rendered GT"
)
# The GT should start with the prompt
assert full_gt_text.startswith(prompt_text), (
"Full GT should start with the rendered prompt"
)
def test_multiturn_gt_reconstruction_fallback_single_turn(self):
"""Single-turn prompts in a multi-turn dataset should use raw concatenation."""
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True
)
prompt_msgs = [{"role": "user", "content": "Q1"}]
gt = "A1"
# remaining_turns would be [] for single-turn prompts
prompt_text = tokenizer.apply_chat_template(
prompt_msgs, tokenize=False, add_generation_prompt=True
)
# With empty remaining_turns, trainer falls through to raw concat
# (trainer.py:302: gt_texts.append(prompt_text + gt))
gt_text = prompt_text + gt
assert gt_text.endswith("A1")
assert prompt_text in gt_text

View File

@@ -0,0 +1,158 @@
"""Tests for HTTP weight sync serialization round-trip (bf16/fp16/fp32).
Exercises the encode/decode helpers in axolotl.utils.weight_serde that handle
the three-stage weight transfer: trainer → serve endpoint → vLLM worker.
"""
import pytest
import torch
from axolotl.utils.weight_serde import (
decode_from_http,
decode_from_ipc,
encode_for_http,
encode_for_ipc,
)
# ---------------------------------------------------------------------------
# Stage 1: trainer → serve endpoint (HTTP with base64)
# ---------------------------------------------------------------------------
class TestHttpEncodeRoundTrip:
"""Test encode_for_http / decode_from_http."""
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_round_trip_dtype(self, dtype):
original = torch.randn(32, 64, dtype=dtype)
entry = encode_for_http("layer.weight", original)
name, decoded = decode_from_http(entry)
assert name == "layer.weight"
assert decoded.dtype == dtype
assert decoded.shape == original.shape
if dtype == torch.bfloat16:
# bf16→fp16→bf16 loses some precision
torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
else:
torch.testing.assert_close(decoded, original, atol=0, rtol=0)
def test_bfloat16_wire_format_is_fp16(self):
"""bf16 tensors should be sent as fp16 on the wire."""
import base64
original = torch.randn(8, 16, dtype=torch.bfloat16)
entry = encode_for_http("w", original)
raw = base64.b64decode(entry["data"])
# 8*16 elements * 2 bytes/elem (fp16) = 256 bytes
assert len(raw) == 8 * 16 * 2
# dtype field should preserve original dtype for reconstruction
assert entry["dtype"] == "torch.bfloat16"
def test_multidimensional_shapes(self):
for shape in [(128,), (4, 32), (2, 3, 16), (2, 2, 2, 8)]:
original = torch.randn(*shape, dtype=torch.bfloat16)
entry = encode_for_http("w", original)
_, decoded = decode_from_http(entry)
assert decoded.shape == original.shape
assert decoded.dtype == torch.bfloat16
# ---------------------------------------------------------------------------
# Stage 2: serve endpoint → vLLM worker (IPC with raw bytes)
# ---------------------------------------------------------------------------
class TestIpcEncodeRoundTrip:
"""Test encode_for_ipc / decode_from_ipc."""
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_round_trip_dtype(self, dtype):
original = torch.randn(32, 64, dtype=dtype)
entry = encode_for_ipc("layer.weight", original)
name, decoded = decode_from_ipc(entry)
assert name == "layer.weight"
assert decoded.dtype == dtype
assert decoded.shape == original.shape
if dtype == torch.bfloat16:
torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
else:
torch.testing.assert_close(decoded, original, atol=0, rtol=0)
def test_bfloat16_ipc_wire_is_fp16(self):
"""bf16 tensors should be serialized as fp16 bytes in IPC."""
original = torch.randn(4, 8, dtype=torch.bfloat16)
entry = encode_for_ipc("w", original)
assert entry["dtype"] == "float16"
assert entry["target_dtype"] == "bfloat16"
assert len(entry["data"]) == 4 * 8 * 2 # fp16 bytes
def test_fp32_has_no_target_dtype_mismatch(self):
original = torch.randn(4, 8, dtype=torch.float32)
entry = encode_for_ipc("w", original)
assert entry["dtype"] == "float32"
assert entry["target_dtype"] == "float32"
def test_worker_handles_missing_target_dtype(self):
"""Backward compat: older serve code may not send target_dtype."""
entry = {
"name": "w",
"data": torch.randn(4, 8, dtype=torch.float32).numpy().tobytes(),
"dtype": "float32",
"shape": [4, 8],
# no target_dtype key
}
name, decoded = decode_from_ipc(entry)
assert decoded.dtype == torch.float32
assert decoded.shape == (4, 8)
# ---------------------------------------------------------------------------
# Full pipeline: trainer → serve → worker
# ---------------------------------------------------------------------------
class TestFullPipelineRoundTrip:
"""End-to-end: trainer → serve → worker."""
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_three_stage_round_trip(self, dtype):
"""Tensor survives trainer→serve→worker with correct dtype and values."""
original = torch.randn(16, 32, dtype=dtype)
# Stage 1: trainer encodes for HTTP
http_entry = encode_for_http("model.layers.0.weight", original)
# Stage 2: serve decodes HTTP, re-encodes for IPC
name, at_serve = decode_from_http(http_entry)
ipc_entry = encode_for_ipc(name, at_serve)
# Stage 3: worker decodes IPC
_, at_worker = decode_from_ipc(ipc_entry)
assert at_worker.dtype == dtype
assert at_worker.shape == original.shape
if dtype == torch.bfloat16:
# Two bf16→fp16→bf16 hops compound precision loss slightly
torch.testing.assert_close(at_worker, original, atol=2e-2, rtol=2e-2)
else:
torch.testing.assert_close(at_worker, original, atol=0, rtol=0)
def test_bfloat16_precision_loss_is_bounded(self):
"""bf16→fp16→bf16 round-trip error should be small."""
original = torch.randn(256, 256, dtype=torch.bfloat16)
http_entry = encode_for_http("w", original)
_, at_serve = decode_from_http(http_entry)
ipc_entry = encode_for_ipc("w", at_serve)
_, at_worker = decode_from_ipc(ipc_entry)
max_err = (at_worker.float() - original.float()).abs().max().item()
# bf16 has ~8e-3 precision, fp16 has ~1e-3; round-trip error bounded
assert max_err < 0.05, f"Max error {max_err} exceeds bound"
def test_bfloat16_numpy_would_crash_without_fix(self):
"""Verify that calling .numpy() on bf16 raises, confirming the fix is needed."""
t = torch.randn(4, 4, dtype=torch.bfloat16)
with pytest.raises((RuntimeError, TypeError)):
t.numpy()