Compare commits

...

2 Commits

Author SHA1 Message Date
Wing Lian
4c92b51cd5 fix the torch dtype check 2024-04-11 08:56:46 -04:00
Wing Lian
5767eea874 add tests for merging lora and validating the dtype 2024-04-10 13:00:37 -04:00
2 changed files with 86 additions and 17 deletions

View File

@@ -8,6 +8,7 @@ import transformers
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.dict import DictDefault
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
@@ -27,21 +28,26 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
flash_attention=False, flash_attention=False,
**kwargs, **kwargs,
) )
cfg = modify_cfg_for_merge(parsed_cfg)
if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir: do_merge_lora(cfg=cfg, cli_args=parsed_cli_args)
parsed_cfg.lora_model_dir = parsed_cfg.output_dir
if not Path(parsed_cfg.lora_model_dir).exists():
def modify_cfg_for_merge(cfg: DictDefault) -> DictDefault:
if not cfg.lora_model_dir and cfg.output_dir:
cfg.lora_model_dir = cfg.output_dir
if not Path(cfg.lora_model_dir).exists():
raise ValueError( raise ValueError(
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist." f"Target directory for merge: `{cfg.lora_model_dir}` does not exist."
) )
parsed_cfg.load_in_4bit = False cfg.load_in_4bit = False
parsed_cfg.load_in_8bit = False cfg.load_in_8bit = False
parsed_cfg.flash_attention = False cfg.flash_attention = False
parsed_cfg.deepspeed = None cfg.deepspeed = None
parsed_cfg.fsdp = None cfg.fsdp = None
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) return cfg
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,13 +1,16 @@
""" """
E2E tests for lora llama E2E tests for lora llama
""" """
import json
import logging import logging
import os import os
import unittest import unittest
from pathlib import Path from pathlib import Path
from axolotl.cli import load_datasets from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import do_merge_lora, load_datasets
from axolotl.cli.merge_lora import modify_cfg_for_merge
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
@@ -39,11 +42,6 @@ class TestLoraLlama(unittest.TestCase):
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [ "datasets": [
{ {
"path": "mhenrichsen/alpaca_2k_test", "path": "mhenrichsen/alpaca_2k_test",
@@ -57,6 +55,7 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"max_steps": 10,
} }
) )
normalize_config(cfg) normalize_config(cfg)
@@ -65,3 +64,67 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_lora_merge(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
cfg.lora_model_dir = cfg.output_dir
cfg.load_in_4bit = False
cfg.load_in_8bit = False
cfg.flash_attention = False
cfg.deepspeed = None
cfg.fsdp = None
cfg = modify_cfg_for_merge(cfg)
cfg.merge_lora = True
cli_args = TrainerCliArgs(merge_lora=True)
do_merge_lora(cfg=cfg, cli_args=cli_args)
assert (Path(temp_dir) / "merged/pytorch_model.bin").exists()
with open(
Path(temp_dir) / "merged/config.json", "r", encoding="utf-8"
) as f_handle:
config = f_handle.read()
config = json.loads(config)
if is_torch_bf16_gpu_available():
assert config["torch_dtype"] == "bfloat16"
else:
assert config["torch_dtype"] == "float16"