diff --git a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml index 512784e50..71692958f 100644 --- a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml +++ b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml @@ -47,7 +47,6 @@ learning_rate: 2e-5 bf16: true tf32: true -attn_implementation: flash_attention_2 attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml b/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml index e36cd5192..5912f876b 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml @@ -43,7 +43,6 @@ learning_rate: 2e-5 bf16: true tf32: true -attn_implementation: flash_attention_2 attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml index cd85460d8..b1a0fef4a 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml @@ -44,7 +44,6 @@ learning_rate: 2e-5 bf16: true tf32: true -attn_implementation: flash_attention_2 attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml index 2ebfd1a80..f97174cd9 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml @@ -43,7 +43,6 @@ learning_rate: 2e-5 bf16: true tf32: true -attn_implementation: flash_attention_2 attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml index dd632e4a0..122fb0b6c 100644 --- a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml +++ b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml @@ -56,7 +56,6 @@ learning_rate: 2e-4 bf16: true tf32: true -attn_implementation: flash_attention_2 attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml b/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml index d57f9501d..7ba5f29b5 100644 --- a/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml +++ b/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml @@ -56,7 +56,6 @@ learning_rate: 2e-4 bf16: true tf32: true -attn_implementation: flash_attention_2 attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/tests/test_attn_implementation.py b/tests/test_attn_implementation.py index e0769ffc5..2a817b12c 100644 --- a/tests/test_attn_implementation.py +++ b/tests/test_attn_implementation.py @@ -9,6 +9,9 @@ Covers the Phase 1 contract: Plus Phase 2 gap fixes and full-model validation behaviour. """ +import logging +from contextlib import contextmanager + import pytest from axolotl.utils.config import validate_config @@ -22,6 +25,25 @@ from axolotl.utils.schemas.enums import ( ) +@contextmanager +def _capture_axolotl_warnings(caplog): + """Capture WARNINGs from `axolotl.*` loggers via caplog. + + `axolotl.cli` calls `configure_logging()` at import time, which sets + `propagate=False` on the `axolotl` logger so records do not reach the root + logger that pytest's `caplog` hooks. This helper temporarily re-enables + propagation for the duration of the block. + """ + ax_logger = logging.getLogger("axolotl") + old_propagate = ax_logger.propagate + ax_logger.propagate = True + try: + with caplog.at_level(logging.WARNING, logger="axolotl"): + yield + finally: + ax_logger.propagate = old_propagate + + class TestNormalizerLegacyMapping: """Legacy boolean flags map to canonical attn_implementation.""" @@ -361,41 +383,35 @@ class TestPhase2GapFixes: """Regression tests for the validator gaps closed in Phase 2.""" def test_sample_packing_with_eager_warns(self, min_base_cfg, caplog): - import logging - cfg = min_base_cfg | DictDefault( attn_implementation="eager", sample_packing=True ) - with caplog.at_level(logging.WARNING): + with _capture_axolotl_warnings(caplog): validate_config(cfg) assert any( - "does not handle cross-sample decontamination" in r.message + "does not handle cross-sample decontamination" in r.getMessage() for r in caplog.records ) def test_sample_packing_with_sdpa_warns(self, min_base_cfg, caplog): - import logging - cfg = min_base_cfg | DictDefault( attn_implementation="sdpa", sample_packing=True ) - with caplog.at_level(logging.WARNING): + with _capture_axolotl_warnings(caplog): validate_config(cfg) assert any( - "does not handle cross-sample decontamination" in r.message + "does not handle cross-sample decontamination" in r.getMessage() for r in caplog.records ) def test_sample_packing_with_flash_does_not_warn(self, min_base_cfg, caplog): - import logging - cfg = min_base_cfg | DictDefault( attn_implementation="flash_attention_2", sample_packing=True ) - with caplog.at_level(logging.WARNING): + with _capture_axolotl_warnings(caplog): validate_config(cfg) assert not any( - "does not handle cross-sample decontamination" in r.message + "does not handle cross-sample decontamination" in r.getMessage() for r in caplog.records )