unset fa
This commit is contained in:
@@ -14,8 +14,6 @@ from huggingface_hub import snapshot_download
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.cli.args import PreprocessCliArgs
|
|
||||||
from axolotl.cli.preprocess import do_preprocess
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
||||||
@@ -50,8 +48,11 @@ def sft_base_cfg():
|
|||||||
flash_attention=True,
|
flash_attention=True,
|
||||||
learning_rate=0.00001,
|
learning_rate=0.00001,
|
||||||
optimizer="adamw_8bit",
|
optimizer="adamw_8bit",
|
||||||
|
# these need to be set since we aren't running schema validation
|
||||||
micro_batch_size=2,
|
micro_batch_size=2,
|
||||||
|
gradient_accumulation_steps=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
@@ -61,7 +62,21 @@ def sft_prepared_dataset_alpaca_cfg(module_temp_dir, sft_base_cfg):
|
|||||||
cfg = sft_base_cfg | DictDefault(
|
cfg = sft_base_cfg | DictDefault(
|
||||||
dataset_prepared_path=dataset_prepared_path,
|
dataset_prepared_path=dataset_prepared_path,
|
||||||
)
|
)
|
||||||
do_preprocess(cfg, PreprocessCliArgs())
|
|
||||||
|
Path(module_temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(module_temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"preprocess",
|
||||||
|
str(Path(module_temp_dir) / "config.yaml"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# unset flash attention since we have some flex attention tests too
|
||||||
|
cfg.flash_attention = None
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user