From 5c4705b18561c97293b3bcd5d71f70bda1323e9e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 6 Jul 2025 13:27:55 -0400
Subject: [PATCH] unset fa

---
 tests/e2e/multigpu/test_llama.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index 7341c6dc4..5da2fefd3 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -14,8 +14,6 @@ from huggingface_hub import snapshot_download
 from packaging import version
 from transformers.testing_utils import get_torch_dist_unique_port

-from axolotl.cli.args import PreprocessCliArgs
-from axolotl.cli.preprocess import do_preprocess
 from axolotl.utils.dict import DictDefault

 from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
@@ -50,8 +48,11 @@ def sft_base_cfg():
         flash_attention=True,
         learning_rate=0.00001,
         optimizer="adamw_8bit",
+        # these need to be set since we aren't running schema validation
         micro_batch_size=2,
+        gradient_accumulation_steps=1,
     )
+
     return cfg


@@ -61,7 +62,21 @@ def sft_prepared_dataset_alpaca_cfg(module_temp_dir, sft_base_cfg):
     cfg = sft_base_cfg | DictDefault(
         dataset_prepared_path=dataset_prepared_path,
     )
-    do_preprocess(cfg, PreprocessCliArgs())
+
+    Path(module_temp_dir).mkdir(parents=True, exist_ok=True)
+    with open(Path(module_temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+    execute_subprocess_async(
+        [
+            "axolotl",
+            "preprocess",
+            str(Path(module_temp_dir) / "config.yaml"),
+        ]
+    )
+
+    # unset flash attention since we have some flex attention tests too
+    cfg.flash_attention = None
     return cfg

