Files
axolotl/tests/test_http_weight_sync.py
Wing Lian c50c4acbf4 EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip

* fixes

* more fixes

* add missing strided module

* ebft fixes for multi-turn

* make ebft work with async

* add example for ebft w qwen3.5

* fix for split thinking and update yaml for lora over linear attention only

* enforce_eager for vllm arg in schema

* fix sync weights

* fix multi-gpu

* handle updated sig for mm

* ddp fixes

* improve multi-gpu handling, don't calculate logits, adaptive completion length

* chore: lint

* chore: lint

* support completion_mean

* Address code review feedback

* clamp min IS ratio

* Address PR code review

* more fixes identified

* address code review

* Fix property from rebase conflict
2026-03-24 18:43:46 -04:00

159 lines
6.3 KiB
Python

"""Tests for HTTP weight sync serialization round-trip (bf16/fp16/fp32).
Exercises the encode/decode helpers in axolotl.utils.weight_serde that handle
the three-stage weight transfer: trainer → serve endpoint → vLLM worker.
"""
import pytest
import torch
from axolotl.utils.weight_serde import (
decode_from_http,
decode_from_ipc,
encode_for_http,
encode_for_ipc,
)
# ---------------------------------------------------------------------------
# Stage 1: trainer → serve endpoint (HTTP with base64)
# ---------------------------------------------------------------------------
class TestHttpEncodeRoundTrip:
    """Round-trip tests for ``encode_for_http`` / ``decode_from_http``."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_round_trip_dtype(self, dtype):
        """Name, dtype, shape and values survive an HTTP encode/decode cycle."""
        original = torch.randn(32, 64, dtype=dtype)
        entry = encode_for_http("layer.weight", original)
        name, decoded = decode_from_http(entry)

        assert name == "layer.weight"
        assert decoded.dtype == dtype
        assert decoded.shape == original.shape
        if dtype == torch.bfloat16:
            # bf16 is expected to travel as fp16 (see the wire-format test in
            # this class), so compare with a small tolerance instead of exactly.
            torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
        else:
            # fp32 / fp16 are expected to come back bit-for-bit identical.
            torch.testing.assert_close(decoded, original, atol=0, rtol=0)

    def test_bfloat16_wire_format_is_fp16(self):
        """bf16 tensors should be sent as fp16 on the wire."""
        import base64

        original = torch.randn(8, 16, dtype=torch.bfloat16)
        entry = encode_for_http("w", original)
        raw = base64.b64decode(entry["data"])
        # 8*16 elements * 2 bytes/elem (fp16) = 256 bytes
        assert len(raw) == 8 * 16 * 2
        # dtype field should preserve original dtype for reconstruction
        assert entry["dtype"] == "torch.bfloat16"

    @pytest.mark.parametrize("shape", [(128,), (4, 32), (2, 3, 16), (2, 2, 2, 8)])
    def test_multidimensional_shapes(self, shape):
        """bf16 tensors of rank 1 through 4 keep their shape and dtype.

        Parametrized (rather than looped) so that every shape is reported as
        its own test case and one failing shape does not hide the others.
        """
        original = torch.randn(*shape, dtype=torch.bfloat16)
        entry = encode_for_http("w", original)
        _, decoded = decode_from_http(entry)
        assert decoded.shape == original.shape
        assert decoded.dtype == torch.bfloat16
# ---------------------------------------------------------------------------
# Stage 2: serve endpoint → vLLM worker (IPC with raw bytes)
# ---------------------------------------------------------------------------
class TestIpcEncodeRoundTrip:
    """Round-trip tests for ``encode_for_ipc`` / ``decode_from_ipc``."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_round_trip_dtype(self, dtype):
        """Name, dtype, shape and values survive an IPC encode/decode cycle."""
        original = torch.randn(32, 64, dtype=dtype)
        entry = encode_for_ipc("layer.weight", original)
        name, decoded = decode_from_ipc(entry)

        assert name == "layer.weight"
        assert decoded.dtype == dtype
        assert decoded.shape == original.shape
        if dtype == torch.bfloat16:
            # bf16 is expected to be serialized as fp16 (see the wire-format
            # test in this class), so compare with a small tolerance.
            torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
        else:
            # fp32 / fp16 are expected to come back bit-for-bit identical.
            torch.testing.assert_close(decoded, original, atol=0, rtol=0)

    def test_bfloat16_ipc_wire_is_fp16(self):
        """bf16 tensors should be serialized as fp16 bytes in IPC."""
        original = torch.randn(4, 8, dtype=torch.bfloat16)
        entry = encode_for_ipc("w", original)
        assert entry["dtype"] == "float16"
        assert entry["target_dtype"] == "bfloat16"
        assert len(entry["data"]) == 4 * 8 * 2  # fp16 bytes

    def test_fp32_has_no_target_dtype_mismatch(self):
        """For fp32 the wire dtype and the target dtype should be the same."""
        original = torch.randn(4, 8, dtype=torch.float32)
        entry = encode_for_ipc("w", original)
        assert entry["dtype"] == "float32"
        assert entry["target_dtype"] == "float32"

    def test_worker_handles_missing_target_dtype(self):
        """Backward compat: older serve code may not send target_dtype."""
        entry = {
            "name": "w",
            "data": torch.randn(4, 8, dtype=torch.float32).numpy().tobytes(),
            "dtype": "float32",
            "shape": [4, 8],
            # no target_dtype key
        }
        name, decoded = decode_from_ipc(entry)
        # The name must be checked on the legacy path too; previously it was
        # unpacked but never asserted.
        assert name == "w"
        assert decoded.dtype == torch.float32
        assert decoded.shape == (4, 8)
# ---------------------------------------------------------------------------
# Full pipeline: trainer → serve → worker
# ---------------------------------------------------------------------------
class TestFullPipelineRoundTrip:
    """Chain all three hops (trainer → serve → worker) and inspect the result."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_three_stage_round_trip(self, dtype):
        """Worker-side tensor matches the trainer-side one in dtype, shape and values."""
        source = torch.randn(16, 32, dtype=dtype)

        # Trainer side: pack the tensor for the HTTP hop.
        http_payload = encode_for_http("model.layers.0.weight", source)
        # Serve side: unpack the HTTP payload, then repack it for IPC.
        param_name, serve_tensor = decode_from_http(http_payload)
        ipc_payload = encode_for_ipc(param_name, serve_tensor)
        # Worker side: unpack the IPC payload.
        _, worker_tensor = decode_from_ipc(ipc_payload)

        assert worker_tensor.dtype == dtype
        assert worker_tensor.shape == source.shape
        # bf16 crosses two encode/decode hops, so it gets a looser bound than
        # the single-hop tests use; every other dtype must match exactly.
        tolerance = 2e-2 if dtype == torch.bfloat16 else 0
        torch.testing.assert_close(
            worker_tensor, source, atol=tolerance, rtol=tolerance
        )

    def test_bfloat16_precision_loss_is_bounded(self):
        """Largest absolute bf16 error after the full pipeline stays below 0.05."""
        source = torch.randn(256, 256, dtype=torch.bfloat16)

        http_payload = encode_for_http("w", source)
        _, serve_tensor = decode_from_http(http_payload)
        ipc_payload = encode_for_ipc("w", serve_tensor)
        _, worker_tensor = decode_from_ipc(ipc_payload)

        # Compare in fp32 so the subtraction itself adds no bf16 rounding.
        abs_diff = (worker_tensor.float() - source.float()).abs()
        max_err = abs_diff.max().item()
        # Generous bound relative to the ~8e-3 (bf16) and ~1e-3 (fp16)
        # precision figures this check was written against.
        assert max_err < 0.05, f"Max error {max_err} exceeds bound"

    def test_bfloat16_numpy_would_crash_without_fix(self):
        """Calling ``.numpy()`` directly on a bf16 tensor is expected to raise."""
        bf16_tensor = torch.randn(4, 4, dtype=torch.bfloat16)
        with pytest.raises((RuntimeError, TypeError)):
            bf16_tensor.numpy()