Ensure device mesh patching is applied (#2842)

* move patches; make patch stronger

* fix broken tests

* guard sequence_parallel_degree comparison against none

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Dan Saunders
2025-06-29 22:16:32 -04:00
committed by GitHub
parent cb811f8bf1
commit 35fdbce102
4 changed files with 52 additions and 19 deletions

View File

@@ -396,7 +396,7 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code
def test_kernel_training_integration():
def test_kernel_training_integration(temp_dir):
"""Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
@@ -426,6 +426,14 @@ def test_kernel_training_integration():
}
)
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Load model
model, _, _ = load_model_and_tokenizer(cfg=cfg)
@@ -505,7 +513,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
assert found_patched_attn
def test_kernel_training_integration_dropout_non_zero():
def test_kernel_training_integration_dropout_non_zero(temp_dir):
"""Test model loading with dropout non-zero should not patch."""
from axolotl.cli.utils import load_model_and_tokenizer
@@ -533,6 +541,14 @@ def test_kernel_training_integration_dropout_non_zero():
}
)
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Get original attention class
attention_cls = get_attention_cls_from_config(cfg)