Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
3202f19f52 add save_only_model arg 2024-04-10 16:09:08 -04:00
4 changed files with 21 additions and 86 deletions

View File

@@ -8,7 +8,6 @@ import transformers
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.dict import DictDefault
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
@@ -28,26 +27,21 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
flash_attention=False, flash_attention=False,
**kwargs, **kwargs,
) )
cfg = modify_cfg_for_merge(parsed_cfg)
do_merge_lora(cfg=cfg, cli_args=parsed_cli_args) if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
parsed_cfg.lora_model_dir = parsed_cfg.output_dir
if not Path(parsed_cfg.lora_model_dir).exists():
def modify_cfg_for_merge(cfg: DictDefault) -> DictDefault:
if not cfg.lora_model_dir and cfg.output_dir:
cfg.lora_model_dir = cfg.output_dir
if not Path(cfg.lora_model_dir).exists():
raise ValueError( raise ValueError(
f"Target directory for merge: `{cfg.lora_model_dir}` does not exist." f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
) )
cfg.load_in_4bit = False parsed_cfg.load_in_4bit = False
cfg.load_in_8bit = False parsed_cfg.load_in_8bit = False
cfg.flash_attention = False parsed_cfg.flash_attention = False
cfg.deepspeed = None parsed_cfg.deepspeed = None
cfg.fsdp = None parsed_cfg.fsdp = None
return cfg do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1058,6 +1058,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.save_safetensors is not None: if self.cfg.save_safetensors is not None:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.save_only_model is not None:
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
if self.cfg.sample_packing_eff_est: if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[ training_arguments_kwargs[
"sample_packing_efficiency" "sample_packing_efficiency"

View File

@@ -355,6 +355,7 @@ class ModelOutputConfig(BaseModel):
hub_model_id: Optional[str] = None hub_model_id: Optional[str] = None
hub_strategy: Optional[str] = None hub_strategy: Optional[str] = None
save_safetensors: Optional[bool] = None save_safetensors: Optional[bool] = None
save_only_model: Optional[bool] = None
class MLFlowConfig(BaseModel): class MLFlowConfig(BaseModel):

View File

@@ -1,16 +1,13 @@
""" """
E2E tests for lora llama E2E tests for lora llama
""" """
import json
import logging import logging
import os import os
import unittest import unittest
from pathlib import Path from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available from axolotl.cli import load_datasets
from axolotl.cli import do_merge_lora, load_datasets
from axolotl.cli.merge_lora import modify_cfg_for_merge
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
@@ -42,6 +39,11 @@ class TestLoraLlama(unittest.TestCase):
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [ "datasets": [
{ {
"path": "mhenrichsen/alpaca_2k_test", "path": "mhenrichsen/alpaca_2k_test",
@@ -55,7 +57,6 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"max_steps": 10,
} }
) )
normalize_config(cfg) normalize_config(cfg)
@@ -64,67 +65,3 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_lora_merge(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
cfg.lora_model_dir = cfg.output_dir
cfg.load_in_4bit = False
cfg.load_in_8bit = False
cfg.flash_attention = False
cfg.deepspeed = None
cfg.fsdp = None
cfg = modify_cfg_for_merge(cfg)
cfg.merge_lora = True
cli_args = TrainerCliArgs(merge_lora=True)
do_merge_lora(cfg=cfg, cli_args=cli_args)
assert (Path(temp_dir) / "merged/pytorch_model.bin").exists()
with open(
Path(temp_dir) / "merged/config.json", "r", encoding="utf-8"
) as f_handle:
config = f_handle.read()
config = json.loads(config)
if is_torch_bf16_gpu_available():
assert config["torch_dtype"] == "bfloat16"
else:
assert config["torch_dtype"] == "float16"