Add tests for merging LoRA adapters and validating the merged model's dtype

This commit is contained in:
Wing Lian
2024-04-10 13:00:37 -04:00
parent 5ed29393e3
commit 5767eea874
2 changed files with 86 additions and 17 deletions

View File

@@ -8,6 +8,7 @@ import transformers
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.dict import DictDefault
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
@@ -27,21 +28,26 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
flash_attention=False, flash_attention=False,
**kwargs, **kwargs,
) )
cfg = modify_cfg_for_merge(parsed_cfg)
if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir: do_merge_lora(cfg=cfg, cli_args=parsed_cli_args)
parsed_cfg.lora_model_dir = parsed_cfg.output_dir
if not Path(parsed_cfg.lora_model_dir).exists():
def modify_cfg_for_merge(cfg: DictDefault) -> DictDefault:
    """
    Prepare a parsed config for merging a LoRA adapter.

    When ``lora_model_dir`` is unset, ``output_dir`` is used as the adapter
    directory. The directory must exist. Afterwards ``load_in_4bit``,
    ``load_in_8bit`` and ``flash_attention`` are set to ``False`` and
    ``deepspeed`` and ``fsdp`` are set to ``None``.

    The config is modified in place; the same object is returned.

    Raises:
        ValueError: if no adapter directory is configured, or if the
            configured directory does not exist.
    """
    if not cfg.lora_model_dir and cfg.output_dir:
        cfg.lora_model_dir = cfg.output_dir

    # Without this guard an unset directory reaches Path(None), which fails
    # with a TypeError that says nothing about the missing config keys.
    if not cfg.lora_model_dir:
        raise ValueError(
            "No adapter directory to merge: set `lora_model_dir` or `output_dir`."
        )

    if not Path(cfg.lora_model_dir).exists():
        raise ValueError(
            f"Target directory for merge: `{cfg.lora_model_dir}` does not exist."
        )

    cfg.load_in_4bit = False
    cfg.load_in_8bit = False
    cfg.flash_attention = False
    cfg.deepspeed = None
    cfg.fsdp = None

    return cfg
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,13 +1,16 @@
""" """
E2E tests for lora llama E2E tests for lora llama
""" """
import json
import logging import logging
import os import os
import unittest import unittest
from pathlib import Path from pathlib import Path
from axolotl.cli import load_datasets from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import do_merge_lora, load_datasets
from axolotl.cli.merge_lora import modify_cfg_for_merge
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
@@ -39,11 +42,6 @@ class TestLoraLlama(unittest.TestCase):
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [ "datasets": [
{ {
"path": "mhenrichsen/alpaca_2k_test", "path": "mhenrichsen/alpaca_2k_test",
@@ -57,6 +55,7 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"max_steps": 10,
} }
) )
normalize_config(cfg) normalize_config(cfg)
@@ -65,3 +64,67 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_lora_merge(self, temp_dir):
    """
    Train a small LoRA adapter, merge it, and check the merged dtype.

    The merge config is produced by ``modify_cfg_for_merge`` alone, so the
    helper is exercised the same way the merge CLI uses it.
    """
    # pylint: disable=duplicate-code
    cfg = DictDefault(
        {
            "base_model": "JackFram/llama-68m",
            "tokenizer_type": "LlamaTokenizer",
            "sequence_len": 1024,
            "load_in_8bit": True,
            "adapter": "lora",
            "lora_r": 32,
            "lora_alpha": 64,
            "lora_dropout": 0.05,
            "lora_target_linear": True,
            "val_set_size": 0.1,
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                },
            ],
            "num_epochs": 2,
            "micro_batch_size": 8,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch",
            "lr_scheduler": "cosine",
            "max_steps": 10,
            "bf16": "auto",
        }
    )
    normalize_config(cfg)
    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

    train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
    assert (Path(temp_dir) / "adapter_model.bin").exists()

    # Do not pre-set lora_model_dir / load_in_* / flash_attention /
    # deepspeed / fsdp by hand here: that would duplicate the helper and
    # leave it effectively untested.
    cfg = modify_cfg_for_merge(cfg)
    assert cfg.lora_model_dir == cfg.output_dir
    assert not cfg.load_in_8bit

    cfg.merge_lora = True
    cli_args = TrainerCliArgs(merge_lora=True)
    do_merge_lora(cfg=cfg, cli_args=cli_args)

    merged_dir = Path(temp_dir) / "merged"
    assert (merged_dir / "pytorch_model.bin").exists()

    with open(merged_dir / "config.json", "r", encoding="utf-8") as f_handle:
        config = json.load(f_handle)

    # transformers writes torch_dtype to config.json as a bare name such as
    # "bfloat16"; drop a leading "torch." if present so either spelling is
    # accepted.
    saved_dtype = config["torch_dtype"].split(".")[-1]
    expected_dtype = "bfloat16" if is_torch_bf16_gpu_available() else "float16"
    assert saved_dtype == expected_dtype