fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests
This commit is contained in:
@@ -47,7 +47,6 @@ learning_rate: 2e-5
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
attn_implementation: flash_attention_2
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -43,7 +43,6 @@ learning_rate: 2e-5
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
attn_implementation: flash_attention_2
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -44,7 +44,6 @@ learning_rate: 2e-5
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
attn_implementation: flash_attention_2
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -43,7 +43,6 @@ learning_rate: 2e-5
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
attn_implementation: flash_attention_2
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -56,7 +56,6 @@ learning_rate: 2e-4
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
attn_implementation: flash_attention_2
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -56,7 +56,6 @@ learning_rate: 2e-4
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
attn_implementation: flash_attention_2
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -9,6 +9,9 @@ Covers the Phase 1 contract:
|
||||
Plus Phase 2 gap fixes and full-model validation behaviour.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
@@ -22,6 +25,25 @@ from axolotl.utils.schemas.enums import (
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _capture_axolotl_warnings(caplog):
|
||||
"""Capture WARNINGs from `axolotl.*` loggers via caplog.
|
||||
|
||||
`axolotl.cli` calls `configure_logging()` at import time, which sets
|
||||
`propagate=False` on the `axolotl` logger so records do not reach the root
|
||||
logger that pytest's `caplog` hooks. This helper temporarily re-enables
|
||||
propagation for the duration of the block.
|
||||
"""
|
||||
ax_logger = logging.getLogger("axolotl")
|
||||
old_propagate = ax_logger.propagate
|
||||
ax_logger.propagate = True
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING, logger="axolotl"):
|
||||
yield
|
||||
finally:
|
||||
ax_logger.propagate = old_propagate
|
||||
|
||||
|
||||
class TestNormalizerLegacyMapping:
|
||||
"""Legacy boolean flags map to canonical attn_implementation."""
|
||||
|
||||
@@ -361,41 +383,35 @@ class TestPhase2GapFixes:
|
||||
"""Regression tests for the validator gaps closed in Phase 2."""
|
||||
|
||||
def test_sample_packing_with_eager_warns(self, min_base_cfg, caplog):
|
||||
import logging
|
||||
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
attn_implementation="eager", sample_packing=True
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
with _capture_axolotl_warnings(caplog):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"does not handle cross-sample decontamination" in r.message
|
||||
"does not handle cross-sample decontamination" in r.getMessage()
|
||||
for r in caplog.records
|
||||
)
|
||||
|
||||
def test_sample_packing_with_sdpa_warns(self, min_base_cfg, caplog):
|
||||
import logging
|
||||
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
attn_implementation="sdpa", sample_packing=True
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
with _capture_axolotl_warnings(caplog):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"does not handle cross-sample decontamination" in r.message
|
||||
"does not handle cross-sample decontamination" in r.getMessage()
|
||||
for r in caplog.records
|
||||
)
|
||||
|
||||
def test_sample_packing_with_flash_does_not_warn(self, min_base_cfg, caplog):
|
||||
import logging
|
||||
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
attn_implementation="flash_attention_2", sample_packing=True
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
with _capture_axolotl_warnings(caplog):
|
||||
validate_config(cfg)
|
||||
assert not any(
|
||||
"does not handle cross-sample decontamination" in r.message
|
||||
"does not handle cross-sample decontamination" in r.getMessage()
|
||||
for r in caplog.records
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user