unset fa
This commit is contained in:
@@ -14,8 +14,6 @@ from huggingface_hub import snapshot_download
|
||||
from packaging import version
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.cli.args import PreprocessCliArgs
|
||||
from axolotl.cli.preprocess import do_preprocess
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
||||
@@ -50,8 +48,11 @@ def sft_base_cfg():
|
||||
flash_attention=True,
|
||||
learning_rate=0.00001,
|
||||
optimizer="adamw_8bit",
|
||||
# these need to be set since we aren't running schema validation
|
||||
micro_batch_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
@@ -61,7 +62,21 @@ def sft_prepared_dataset_alpaca_cfg(module_temp_dir, sft_base_cfg):
|
||||
cfg = sft_base_cfg | DictDefault(
|
||||
dataset_prepared_path=dataset_prepared_path,
|
||||
)
|
||||
do_preprocess(cfg, PreprocessCliArgs())
|
||||
|
||||
Path(module_temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(module_temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"preprocess",
|
||||
str(Path(module_temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
# unset flash attention since we have some flex attention tests too
|
||||
cfg.flash_attention = None
|
||||
return cfg
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user