fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests

This commit is contained in:
Wing Lian
2026-04-25 08:53:28 +00:00
parent aeca18a8b0
commit 6886def92c
7 changed files with 28 additions and 18 deletions

View File

@@ -47,7 +47,6 @@ learning_rate: 2e-5
bf16: true
tf32: true
attn_implementation: flash_attention_2
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true

View File

@@ -43,7 +43,6 @@ learning_rate: 2e-5
bf16: true
tf32: true
attn_implementation: flash_attention_2
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true

View File

@@ -44,7 +44,6 @@ learning_rate: 2e-5
bf16: true
tf32: true
attn_implementation: flash_attention_2
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true

View File

@@ -43,7 +43,6 @@ learning_rate: 2e-5
bf16: true
tf32: true
attn_implementation: flash_attention_2
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true

View File

@@ -56,7 +56,6 @@ learning_rate: 2e-4
bf16: true
tf32: true
attn_implementation: flash_attention_2
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true

View File

@@ -56,7 +56,6 @@ learning_rate: 2e-4
bf16: true
tf32: true
attn_implementation: flash_attention_2
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true

View File

@@ -9,6 +9,9 @@ Covers the Phase 1 contract:
Plus Phase 2 gap fixes and full-model validation behaviour.
"""
import logging
from contextlib import contextmanager
import pytest
from axolotl.utils.config import validate_config
@@ -22,6 +25,25 @@ from axolotl.utils.schemas.enums import (
)
@contextmanager
def _capture_axolotl_warnings(caplog):
"""Capture WARNINGs from `axolotl.*` loggers via caplog.
`axolotl.cli` calls `configure_logging()` at import time, which sets
`propagate=False` on the `axolotl` logger so records do not reach the root
logger that pytest's `caplog` hooks. This helper temporarily re-enables
propagation for the duration of the block.
"""
ax_logger = logging.getLogger("axolotl")
old_propagate = ax_logger.propagate
ax_logger.propagate = True
try:
with caplog.at_level(logging.WARNING, logger="axolotl"):
yield
finally:
ax_logger.propagate = old_propagate
class TestNormalizerLegacyMapping:
"""Legacy boolean flags map to canonical attn_implementation."""
@@ -361,41 +383,35 @@ class TestPhase2GapFixes:
"""Regression tests for the validator gaps closed in Phase 2."""
def test_sample_packing_with_eager_warns(self, min_base_cfg, caplog):
import logging
cfg = min_base_cfg | DictDefault(
attn_implementation="eager", sample_packing=True
)
with caplog.at_level(logging.WARNING):
with _capture_axolotl_warnings(caplog):
validate_config(cfg)
assert any(
"does not handle cross-sample decontamination" in r.message
"does not handle cross-sample decontamination" in r.getMessage()
for r in caplog.records
)
def test_sample_packing_with_sdpa_warns(self, min_base_cfg, caplog):
import logging
cfg = min_base_cfg | DictDefault(
attn_implementation="sdpa", sample_packing=True
)
with caplog.at_level(logging.WARNING):
with _capture_axolotl_warnings(caplog):
validate_config(cfg)
assert any(
"does not handle cross-sample decontamination" in r.message
"does not handle cross-sample decontamination" in r.getMessage()
for r in caplog.records
)
def test_sample_packing_with_flash_does_not_warn(self, min_base_cfg, caplog):
import logging
cfg = min_base_cfg | DictDefault(
attn_implementation="flash_attention_2", sample_packing=True
)
with caplog.at_level(logging.WARNING):
with _capture_axolotl_warnings(caplog):
validate_config(cfg)
assert not any(
"does not handle cross-sample decontamination" in r.message
"does not handle cross-sample decontamination" in r.getMessage()
for r in caplog.records
)