"""E2E smoke test for evaluate CLI command""" import os from pathlib import Path import yaml from accelerate.test_utils import execute_subprocess_async from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault os.environ["WANDB_DISABLED"] = "true" class TestE2eEvaluate: """Test cases for evaluate CLI""" def test_evaluate(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "JackFram/llama-68m", "tokenizer_type": "LlamaTokenizer", "sequence_len": 1024, "val_set_size": 0.02, "special_tokens": { "unk_token": "", "bos_token": "", "eos_token": "", }, "datasets": [ { "path": "mhenrichsen/alpaca_2k_test", "type": "alpaca", }, ], "num_epochs": 1, "micro_batch_size": 8, "gradient_accumulation_steps": 1, "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", "max_steps": 20, } ) # write cfg to yaml file Path(temp_dir).mkdir(parents=True, exist_ok=True) with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) execute_subprocess_async( [ "accelerate", "launch", "--num-processes", "2", "--main_process_port", f"{get_torch_dist_unique_port()}", "-m", "axolotl.cli.evaluate", str(Path(temp_dir) / "config.yaml"), ] )