axolotl/tests/test_ebft_kernels.py
Wing Lian c50c4acbf4 EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip

* fixes

* more fixes

* add missing strided module

* ebft fixes for multi-turn

* make ebft work with async

* add example for ebft w qwen3.5

* fix for split thinking and update yaml for lora over linear attention only

* enforce_eager for vllm arg in schema

* fix sync weights

* fix multi-gpu

* handle updated sig for mm

* ddp fixes

* improve multi-gpu handling, don't calculate logits, adaptive completion length

* chore: lint

* chore: lint

* support completion_mean

* Address code review feedback

* clamp min IS ratio

* Address PR code review

* more fixes identified

* address code review

* Fix property from rebase conflict
2026-03-24 18:43:46 -04:00

295 lines
11 KiB
Python

"""Correctness tests for fused EBFT Triton kernels."""
import pytest
import torch
import torch.nn.functional as F
from axolotl.core.trainers.ebft.kernels import (
fused_cosine_similarity,
fused_diversity_penalty,
fused_log_softmax_gather,
fused_reinforce_loss,
)
# Skip all tests if CUDA not available
pytestmark = pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA required for Triton kernels"
)
DEVICE = "cuda"
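
# bfloat16 keeps only 8 mantissa bits, so the bf16 tests below use looser
# tolerances (atol/rtol around 1e-3 to 1e-2) than the fp32 tests (~1e-5).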


# ---------------------------------------------------------------------------
# 1. fused_log_softmax_gather
# ---------------------------------------------------------------------------
class TestFusedLogSoftmaxGather:
    def _reference(self, logits, labels):
        """PyTorch reference: log_softmax + gather."""
        lp = F.log_softmax(logits.float(), dim=-1)
        return lp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

    def test_basic_correctness(self):
        B, S, V = 2, 16, 1024
        logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
        labels = torch.randint(0, V, (B, S), device=DEVICE)
        ref = self._reference(logits, labels)
        out = fused_log_softmax_gather(logits, labels)
        torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)

    def test_large_vocab(self):
        """Test with realistic vocab size (128K)."""
        B, S, V = 1, 8, 128256
        logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
        labels = torch.randint(0, V, (B, S), device=DEVICE)
        ref = self._reference(logits, labels)
        out = fused_log_softmax_gather(logits, labels)
        torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)

    def test_fp32_input(self):
        B, S, V = 2, 8, 512
        logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
        labels = torch.randint(0, V, (B, S), device=DEVICE)
        ref = self._reference(logits, labels)
        out = fused_log_softmax_gather(logits, labels)
        torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)

    def test_output_is_negative(self):
        """log_softmax values should always be <= 0."""
        B, S, V = 4, 32, 2048
        logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.bfloat16)
        labels = torch.randint(0, V, (B, S), device=DEVICE)
        out = fused_log_softmax_gather(logits, labels)
        assert (out <= 1e-5).all(), "log_softmax values should be <= 0"

    def test_extreme_logits(self):
        """Test numerical stability with very large/small logits."""
        B, S, V = 1, 4, 256
        logits = torch.randn(B, S, V, device=DEVICE, dtype=torch.float32)
        logits[:, 0, :] = 1000.0  # very large
        logits[:, 1, :] = -1000.0  # very small
        logits[:, 2, 0] = 1000.0  # one-hot-ish
        labels = torch.zeros(B, S, device=DEVICE, dtype=torch.long)
        ref = self._reference(logits, labels)
        out = fused_log_softmax_gather(logits, labels)
        assert torch.isfinite(out).all(), "Should handle extreme values"
        torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)

    def test_2d_input(self):
        """Test with pre-flattened (N, V) input."""
        N, V = 64, 4096
        logits = torch.randn(N, V, device=DEVICE, dtype=torch.bfloat16)
        labels = torch.randint(0, V, (N,), device=DEVICE)
        ref = self._reference(logits.unsqueeze(0), labels.unsqueeze(0)).squeeze(0)
        out = fused_log_softmax_gather(logits, labels)
        torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
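
    # Extra cross-check (a sketch, not part of the original suite): gathering
    # log-softmax at the label index is the negative per-token cross-entropy,
    # so the output should match -F.cross_entropy(..., reduction="none"),
    # assuming fused_log_softmax_gather accepts (N, V) inputs as tested above.
    def test_matches_cross_entropy(self):
        N, V = 32, 512
        logits = torch.randn(N, V, device=DEVICE, dtype=torch.float32)
        labels = torch.randint(0, V, (N,), device=DEVICE)
        nll = F.cross_entropy(logits, labels, reduction="none")  # per-token NLL
        out = fused_log_softmax_gather(logits, labels)
        torch.testing.assert_close(out, -nll, atol=1e-5, rtol=1e-5)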


# ---------------------------------------------------------------------------
# 2. fused_reinforce_loss
# ---------------------------------------------------------------------------
class TestFusedReinforceLoss:
    def _reference(self, logps, advantages, mask):
        """PyTorch reference implementation."""
        loss_per_token = -logps * advantages
        return (loss_per_token * mask.float()).sum() / mask.float().sum().clamp(min=1)

    def test_basic_correctness(self):
        N = 1024
        logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
        advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
        mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
        ref = self._reference(logps, advs, mask)
        out = fused_reinforce_loss(logps, advs, mask)
        torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)

    def test_2d_input(self):
        """Test with (B, S) shaped inputs."""
        B, S = 4, 256
        logps = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
        advs = torch.randn(B, S, device=DEVICE, dtype=torch.float32)
        mask = torch.randint(0, 2, (B, S), device=DEVICE, dtype=torch.bool)
        ref = self._reference(logps, advs, mask)
        out = fused_reinforce_loss(logps, advs, mask)
        torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)

    def test_all_masked(self):
        """All-zero mask should return 0."""
        N = 512
        logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
        advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
        mask = torch.zeros(N, device=DEVICE, dtype=torch.bool)
        out = fused_reinforce_loss(logps, advs, mask)
        assert out.item() == 0.0

    def test_all_unmasked(self):
        N = 512
        logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
        advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
        mask = torch.ones(N, device=DEVICE, dtype=torch.bool)
        ref = self._reference(logps, advs, mask)
        out = fused_reinforce_loss(logps, advs, mask)
        torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)

    def test_large(self):
        """Test with realistic size (4 * 3000 tokens)."""
        N = 12000
        logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
        advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
        mask = torch.randint(0, 2, (N,), device=DEVICE, dtype=torch.bool)
        ref = self._reference(logps, advs, mask)
        out = fused_reinforce_loss(logps, advs, mask)
        torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
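
    # Sanity identity (a sketch appended here, not from the original file):
    # with an all-ones mask the masked mean is a plain mean, so the loss
    # should equal -(logps * advantages).mean().
    def test_matches_plain_mean(self):
        N = 256
        logps = torch.randn(N, device=DEVICE, dtype=torch.float32)
        advs = torch.randn(N, device=DEVICE, dtype=torch.float32)
        mask = torch.ones(N, device=DEVICE, dtype=torch.bool)
        expected = -(logps * advs).mean()
        out = fused_reinforce_loss(logps, advs, mask)
        torch.testing.assert_close(out, expected, atol=1e-4, rtol=1e-4)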


# ---------------------------------------------------------------------------
# 3. fused_cosine_similarity
# ---------------------------------------------------------------------------
class TestFusedCosineSimilarity:
    def test_basic_correctness(self):
        N, D = 64, 256
        a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
        b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
        ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
        out = fused_cosine_similarity(a, b)
        torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)

    def test_batched(self):
        """Test with (B, N, NB, D) shaped input."""
        B, N, NB, D = 2, 4, 16, 512
        a = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
        b = torch.randn(B, N, NB, D, device=DEVICE, dtype=torch.bfloat16)
        ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
        out = fused_cosine_similarity(a, b)
        torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)

    def test_identical_vectors(self):
        """Identical vectors should give similarity = 1."""
        N, D = 16, 128
        a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
        out = fused_cosine_similarity(a, a)
        torch.testing.assert_close(
            out,
            torch.ones(N, device=DEVICE, dtype=torch.float32),
            atol=1e-5,
            rtol=1e-5,
        )

    def test_orthogonal_vectors(self):
        """Orthogonal vectors should give similarity = 0."""
        D = 128
        a = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
        b = torch.zeros(1, D, device=DEVICE, dtype=torch.float32)
        a[0, 0] = 1.0
        b[0, 1] = 1.0
        out = fused_cosine_similarity(a, b)
        assert abs(out.item()) < 1e-5

    def test_opposite_vectors(self):
        """Opposite vectors should give similarity = -1."""
        N, D = 8, 64
        a = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
        out = fused_cosine_similarity(a, -a)
        torch.testing.assert_close(
            out,
            -torch.ones(N, device=DEVICE, dtype=torch.float32),
            atol=1e-5,
            rtol=1e-5,
        )

    def test_large_dimension(self):
        """Test with large feature dimension (multi-layer concatenated features)."""
        N, D = 32, 4608  # 3 layers * 1536 hidden
        a = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
        b = torch.randn(N, D, device=DEVICE, dtype=torch.bfloat16)
        ref = F.cosine_similarity(a.float(), b.float(), dim=-1)
        out = fused_cosine_similarity(a, b)
        torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
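
    # The looser 1e-2 tolerance above reflects bf16 rounding error
    # accumulating across the length-D reductions (dot product and norms)
    # when D = 4608.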


# ---------------------------------------------------------------------------
# 4. fused_diversity_penalty
# ---------------------------------------------------------------------------
class TestFusedDiversityPenalty:
    def _reference(self, embeddings):
        """PyTorch reference: bmm + mask diagonal + mean."""
        B, N, D = embeddings.shape
        sims = torch.bmm(embeddings.float(), embeddings.float().transpose(1, 2))
        eye = torch.eye(N, device=embeddings.device, dtype=torch.bool)
        sims = sims.masked_fill(eye.unsqueeze(0), 0.0)
        return sims.sum(dim=-1) / (N - 1)
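
    # The reference computes, for each sample i in a group of N:
    #   penalty_i = (1 / (N - 1)) * sum_{j != i} <e_i, e_j>
    # i.e. the mean dot product against the other N - 1 samples, so more
    # similar samples yield a larger penalty.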
    def test_basic_correctness(self):
        B, N, D = 4, 4, 256
        emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
        ref = self._reference(emb)
        out = fused_diversity_penalty(emb)
        torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)

    def test_two_samples(self):
        """With N=2, each sample's penalty reduces to dot(a, b) with the other."""
        B, D = 3, 128
        emb = torch.randn(B, 2, D, device=DEVICE, dtype=torch.float32)
        ref = self._reference(emb)
        out = fused_diversity_penalty(emb)
        torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)

    def test_identical_samples(self):
        """All-identical samples should give the maximum diversity penalty."""
        B, N, D = 2, 4, 64
        vec = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
        emb = vec.expand(B, N, D).contiguous()
        out = fused_diversity_penalty(emb)
        # Should be ||vec||^2 for each (self-excluded mean of identical dot products)
        expected = (vec.squeeze(1) ** 2).sum(dim=-1, keepdim=True).expand(B, N)
        torch.testing.assert_close(out, expected, atol=1e-4, rtol=1e-4)

    def test_large(self):
        """Test with realistic EBFT dimensions."""
        B, N, D = 1, 4, 4608  # multi-layer features
        emb = torch.randn(B, N, D, device=DEVICE, dtype=torch.bfloat16)
        ref = self._reference(emb)
        out = fused_diversity_penalty(emb)
        torch.testing.assert_close(out, ref, atol=5e-2, rtol=5e-2)

    def test_single_sample_returns_zeros(self):
        """N=1 should return zeros (no pairs), not garbage from uninitialized memory."""
        B, D = 3, 128
        emb = torch.randn(B, 1, D, device=DEVICE, dtype=torch.float32)
        out = fused_diversity_penalty(emb)
        assert (out == 0).all(), "N=1 diversity should be exactly zero"
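

# To run only this module (assuming a CUDA-capable machine and an installed
# axolotl checkout):
#   pytest tests/test_ebft_kernels.py -v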