"""Tests for HTTP weight sync serialization round-trip (bf16/fp16/fp32).
|
|
|
|
Exercises the encode/decode helpers in axolotl.utils.weight_serde that handle
|
|
the three-stage weight transfer: trainer → serve endpoint → vLLM worker.
|
|
"""
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from axolotl.utils.weight_serde import (
|
|
decode_from_http,
|
|
decode_from_ipc,
|
|
encode_for_http,
|
|
encode_for_ipc,
|
|
)
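
# NOTE: the entry layouts summarized here are inferred from the assertions in
# the tests below, not taken from the canonical schema in weight_serde.
#   HTTP entry: "data" is a base64 string of fp16 bytes; "dtype" records the
#       original dtype (e.g. "torch.bfloat16") so the receiver can restore it.
#   IPC entry: "name", "data" (raw bytes), "dtype" (wire dtype), "shape", and
#       "target_dtype" (original dtype; older senders may omit it).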

# ---------------------------------------------------------------------------
# Stage 1: trainer → serve endpoint (HTTP with base64)
# ---------------------------------------------------------------------------


class TestHttpEncodeRoundTrip:
    """Test encode_for_http / decode_from_http."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_round_trip_dtype(self, dtype):
        original = torch.randn(32, 64, dtype=dtype)
        entry = encode_for_http("layer.weight", original)
        name, decoded = decode_from_http(entry)

        assert name == "layer.weight"
        assert decoded.dtype == dtype
        assert decoded.shape == original.shape
        if dtype == torch.bfloat16:
            # bf16→fp16→bf16 loses some precision
            torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
        else:
            torch.testing.assert_close(decoded, original, atol=0, rtol=0)

    def test_bfloat16_wire_format_is_fp16(self):
        """bf16 tensors should be sent as fp16 on the wire."""
        import base64

        original = torch.randn(8, 16, dtype=torch.bfloat16)
        entry = encode_for_http("w", original)
        raw = base64.b64decode(entry["data"])
        # 8*16 elements * 2 bytes/elem (fp16) = 256 bytes
        assert len(raw) == 8 * 16 * 2
        # dtype field should preserve original dtype for reconstruction
        assert entry["dtype"] == "torch.bfloat16"

    def test_multidimensional_shapes(self):
        for shape in [(128,), (4, 32), (2, 3, 16), (2, 2, 2, 8)]:
            original = torch.randn(*shape, dtype=torch.bfloat16)
            entry = encode_for_http("w", original)
            _, decoded = decode_from_http(entry)
            assert decoded.shape == original.shape
            assert decoded.dtype == torch.bfloat16

# ---------------------------------------------------------------------------
# Stage 2: serve endpoint → vLLM worker (IPC with raw bytes)
# ---------------------------------------------------------------------------


class TestIpcEncodeRoundTrip:
    """Test encode_for_ipc / decode_from_ipc."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_round_trip_dtype(self, dtype):
        original = torch.randn(32, 64, dtype=dtype)
        entry = encode_for_ipc("layer.weight", original)
        name, decoded = decode_from_ipc(entry)

        assert name == "layer.weight"
        assert decoded.dtype == dtype
        assert decoded.shape == original.shape
        if dtype == torch.bfloat16:
            torch.testing.assert_close(decoded, original, atol=1e-2, rtol=1e-2)
        else:
            torch.testing.assert_close(decoded, original, atol=0, rtol=0)

    def test_bfloat16_ipc_wire_is_fp16(self):
        """bf16 tensors should be serialized as fp16 bytes in IPC."""
        original = torch.randn(4, 8, dtype=torch.bfloat16)
        entry = encode_for_ipc("w", original)
        assert entry["dtype"] == "float16"
        assert entry["target_dtype"] == "bfloat16"
        assert len(entry["data"]) == 4 * 8 * 2  # fp16 bytes

    def test_fp32_has_no_target_dtype_mismatch(self):
        original = torch.randn(4, 8, dtype=torch.float32)
        entry = encode_for_ipc("w", original)
        assert entry["dtype"] == "float32"
        assert entry["target_dtype"] == "float32"

    def test_worker_handles_missing_target_dtype(self):
        """Backward compat: older serve code may not send target_dtype."""
        entry = {
            "name": "w",
            "data": torch.randn(4, 8, dtype=torch.float32).numpy().tobytes(),
            "dtype": "float32",
            "shape": [4, 8],
            # no target_dtype key
        }
        name, decoded = decode_from_ipc(entry)
        assert decoded.dtype == torch.float32
        assert decoded.shape == (4, 8)

# ---------------------------------------------------------------------------
# Full pipeline: trainer → serve → worker
# ---------------------------------------------------------------------------


class TestFullPipelineRoundTrip:
    """End-to-end: trainer → serve → worker."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_three_stage_round_trip(self, dtype):
        """Tensor survives trainer→serve→worker with correct dtype and values."""
        original = torch.randn(16, 32, dtype=dtype)

        # Stage 1: trainer encodes for HTTP
        http_entry = encode_for_http("model.layers.0.weight", original)

        # Stage 2: serve decodes HTTP, re-encodes for IPC
        name, at_serve = decode_from_http(http_entry)
        ipc_entry = encode_for_ipc(name, at_serve)

        # Stage 3: worker decodes IPC
        _, at_worker = decode_from_ipc(ipc_entry)

        assert at_worker.dtype == dtype
        assert at_worker.shape == original.shape
        if dtype == torch.bfloat16:
            # Two bf16→fp16→bf16 hops compound precision loss slightly
            torch.testing.assert_close(at_worker, original, atol=2e-2, rtol=2e-2)
        else:
            torch.testing.assert_close(at_worker, original, atol=0, rtol=0)

    def test_bfloat16_precision_loss_is_bounded(self):
        """bf16→fp16→bf16 round-trip error should be small."""
        original = torch.randn(256, 256, dtype=torch.bfloat16)
        http_entry = encode_for_http("w", original)
        _, at_serve = decode_from_http(http_entry)
        ipc_entry = encode_for_ipc("w", at_serve)
        _, at_worker = decode_from_ipc(ipc_entry)

        max_err = (at_worker.float() - original.float()).abs().max().item()
        # bf16 has ~8e-3 precision, fp16 has ~1e-3; round-trip error bounded
        assert max_err < 0.05, f"Max error {max_err} exceeds bound"

    def test_bfloat16_numpy_would_crash_without_fix(self):
        """Verify that calling .numpy() on bf16 raises, confirming the fix is needed."""
        t = torch.randn(4, 4, dtype=torch.bfloat16)
        with pytest.raises((RuntimeError, TypeError)):
            t.numpy()
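

# A minimal sketch of the bf16 -> fp16-bytes -> bf16 hop the helpers above are
# expected to perform (an illustrative assumption, not the weight_serde
# implementation): numpy has no bfloat16 dtype, so a bf16 tensor is downcast to
# fp16 before .numpy().tobytes() and upcast back on the receiving side.
def _example_bf16_roundtrip_via_fp16_bytes(t: torch.Tensor) -> torch.Tensor:
    raw = t.to(torch.float16).numpy().tobytes()  # fp16 bytes on the wire
    restored = torch.frombuffer(bytearray(raw), dtype=torch.float16)
    return restored.reshape(t.shape).to(torch.bfloat16)  # restore target dtype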
|