fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests
This commit is contained in:
@@ -47,7 +47,6 @@ learning_rate: 2e-5
|
|||||||
bf16: true
|
bf16: true
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
attn_implementation: flash_attention_2
|
|
||||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ learning_rate: 2e-5
|
|||||||
bf16: true
|
bf16: true
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
attn_implementation: flash_attention_2
|
|
||||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ learning_rate: 2e-5
|
|||||||
bf16: true
|
bf16: true
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
attn_implementation: flash_attention_2
|
|
||||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ learning_rate: 2e-5
|
|||||||
bf16: true
|
bf16: true
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
attn_implementation: flash_attention_2
|
|
||||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ learning_rate: 2e-4
|
|||||||
bf16: true
|
bf16: true
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
attn_implementation: flash_attention_2
|
|
||||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ learning_rate: 2e-4
|
|||||||
bf16: true
|
bf16: true
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
attn_implementation: flash_attention_2
|
|
||||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ Covers the Phase 1 contract:
|
|||||||
Plus Phase 2 gap fixes and full-model validation behaviour.
|
Plus Phase 2 gap fixes and full-model validation behaviour.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
@@ -22,6 +25,25 @@ from axolotl.utils.schemas.enums import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _capture_axolotl_warnings(caplog):
|
||||||
|
"""Capture WARNINGs from `axolotl.*` loggers via caplog.
|
||||||
|
|
||||||
|
`axolotl.cli` calls `configure_logging()` at import time, which sets
|
||||||
|
`propagate=False` on the `axolotl` logger so records do not reach the root
|
||||||
|
logger that pytest's `caplog` hooks. This helper temporarily re-enables
|
||||||
|
propagation for the duration of the block.
|
||||||
|
"""
|
||||||
|
ax_logger = logging.getLogger("axolotl")
|
||||||
|
old_propagate = ax_logger.propagate
|
||||||
|
ax_logger.propagate = True
|
||||||
|
try:
|
||||||
|
with caplog.at_level(logging.WARNING, logger="axolotl"):
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
ax_logger.propagate = old_propagate
|
||||||
|
|
||||||
|
|
||||||
class TestNormalizerLegacyMapping:
|
class TestNormalizerLegacyMapping:
|
||||||
"""Legacy boolean flags map to canonical attn_implementation."""
|
"""Legacy boolean flags map to canonical attn_implementation."""
|
||||||
|
|
||||||
@@ -361,41 +383,35 @@ class TestPhase2GapFixes:
|
|||||||
"""Regression tests for the validator gaps closed in Phase 2."""
|
"""Regression tests for the validator gaps closed in Phase 2."""
|
||||||
|
|
||||||
def test_sample_packing_with_eager_warns(self, min_base_cfg, caplog):
|
def test_sample_packing_with_eager_warns(self, min_base_cfg, caplog):
|
||||||
import logging
|
|
||||||
|
|
||||||
cfg = min_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
attn_implementation="eager", sample_packing=True
|
attn_implementation="eager", sample_packing=True
|
||||||
)
|
)
|
||||||
with caplog.at_level(logging.WARNING):
|
with _capture_axolotl_warnings(caplog):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
assert any(
|
assert any(
|
||||||
"does not handle cross-sample decontamination" in r.message
|
"does not handle cross-sample decontamination" in r.getMessage()
|
||||||
for r in caplog.records
|
for r in caplog.records
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_sample_packing_with_sdpa_warns(self, min_base_cfg, caplog):
|
def test_sample_packing_with_sdpa_warns(self, min_base_cfg, caplog):
|
||||||
import logging
|
|
||||||
|
|
||||||
cfg = min_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
attn_implementation="sdpa", sample_packing=True
|
attn_implementation="sdpa", sample_packing=True
|
||||||
)
|
)
|
||||||
with caplog.at_level(logging.WARNING):
|
with _capture_axolotl_warnings(caplog):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
assert any(
|
assert any(
|
||||||
"does not handle cross-sample decontamination" in r.message
|
"does not handle cross-sample decontamination" in r.getMessage()
|
||||||
for r in caplog.records
|
for r in caplog.records
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_sample_packing_with_flash_does_not_warn(self, min_base_cfg, caplog):
|
def test_sample_packing_with_flash_does_not_warn(self, min_base_cfg, caplog):
|
||||||
import logging
|
|
||||||
|
|
||||||
cfg = min_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
attn_implementation="flash_attention_2", sample_packing=True
|
attn_implementation="flash_attention_2", sample_packing=True
|
||||||
)
|
)
|
||||||
with caplog.at_level(logging.WARNING):
|
with _capture_axolotl_warnings(caplog):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
assert not any(
|
assert not any(
|
||||||
"does not handle cross-sample decontamination" in r.message
|
"does not handle cross-sample decontamination" in r.getMessage()
|
||||||
for r in caplog.records
|
for r in caplog.records
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user