commit 5c4705b185
Author: Wing Lian
Date:   2025-07-06 13:27:55 -04:00
Parent: 47a88da330

@@ -14,8 +14,6 @@ from huggingface_hub import snapshot_download
 from packaging import version
 from transformers.testing_utils import get_torch_dist_unique_port
-from axolotl.cli.args import PreprocessCliArgs
-from axolotl.cli.preprocess import do_preprocess
 from axolotl.utils.dict import DictDefault
 from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
@@ -50,8 +48,11 @@ def sft_base_cfg():
         flash_attention=True,
         learning_rate=0.00001,
         optimizer="adamw_8bit",
+        # these need to be set since we aren't running schema validation
+        micro_batch_size=2,
+        gradient_accumulation_steps=1,
     )
     return cfg
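
Why these fields matter: the base config fixture is consumed directly as a DictDefault, without schema validation, so defaults the validator would normally fill in must be set by hand. Below is a minimal sketch of that merge-and-fallback behavior; the class is a simplified stand-in for axolotl.utils.dict.DictDefault, not the real implementation:

```python
# Simplified stand-in for axolotl.utils.dict.DictDefault: a dict that merges
# with `|` and yields None (not a KeyError/AttributeError) for missing keys.
class DictDefault(dict):
    def __or__(self, other):
        merged = DictDefault(self)
        merged.update(other)
        return merged

    __getattr__ = dict.get  # cfg.missing_key -> None


base = DictDefault(
    optimizer="adamw_8bit",
    # set explicitly because no schema validation fills in defaults
    micro_batch_size=2,
    gradient_accumulation_steps=1,
)
cfg = base | DictDefault(dataset_prepared_path="prepared")

assert cfg.micro_batch_size == 2
assert cfg.flash_attention is None  # unset keys read as None
```
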
@@ -61,7 +62,21 @@ def sft_prepared_dataset_alpaca_cfg(module_temp_dir, sft_base_cfg):
     cfg = sft_base_cfg | DictDefault(
         dataset_prepared_path=dataset_prepared_path,
     )
-    do_preprocess(cfg, PreprocessCliArgs())
+    Path(module_temp_dir).mkdir(parents=True, exist_ok=True)
+    with open(Path(module_temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+    execute_subprocess_async(
+        [
+            "axolotl",
+            "preprocess",
+            str(Path(module_temp_dir) / "config.yaml"),
+        ]
+    )
+    # unset flash attention since we have some flex attention tests too
+    cfg.flash_attention = None
     return cfg
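
The replacement above swaps the in-process do_preprocess call for a CLI subprocess, so preprocessing runs in a fresh interpreter instead of inside the test process. A self-contained sketch of the same pattern, using subprocess.run where the fixture uses transformers' execute_subprocess_async (the helper name is hypothetical):

```python
import subprocess
from pathlib import Path

import yaml


def preprocess_via_cli(cfg: dict, workdir: str) -> Path:
    """Hypothetical helper mirroring the fixture: dump the config to YAML,
    then run `axolotl preprocess` on it in a subprocess."""
    workdir_path = Path(workdir)
    workdir_path.mkdir(parents=True, exist_ok=True)
    config_path = workdir_path / "config.yaml"
    with open(config_path, "w", encoding="utf-8") as fout:
        fout.write(yaml.dump(cfg, Dumper=yaml.Dumper))
    # check=True makes the caller (e.g. the test) fail if preprocessing errors out
    subprocess.run(["axolotl", "preprocess", str(config_path)], check=True)
    return config_path
```
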