From 62a774140bb668ba825b9305bc03bef5dd0dcd90 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 18 Sep 2023 21:14:32 -0400 Subject: [PATCH] Fix for check with cfg and merge_lora (#600) --- .github/workflows/tests.yml | 2 +- src/axolotl/cli/__init__.py | 2 +- src/axolotl/cli/merge_lora.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 017c1b1b6..18ff575c1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -61,7 +61,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: "3.10" - cache: 'pip' # caching pip dependencies +# cache: 'pip' # caching pip dependencies - name: Install dependencies run: | diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index ff8eb3b91..90e1d508b 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -70,7 +70,7 @@ def do_merge_lora( model.to(dtype=torch.float16) if cfg.local_rank == 0: - LOG.info("saving merged model") + LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}") model.save_pretrained( str(Path(cfg.output_dir) / "merged"), safe_serialization=safe_serialization, diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 79b7112b5..0caee4c28 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -13,12 +13,12 @@ from axolotl.common.cli import TrainerCliArgs def do_cli(config: Path = Path("examples/"), **kwargs): # pylint: disable=duplicate-code print_axolotl_text_art() - parsed_cfg = load_cfg(config, **kwargs) parser = transformers.HfArgumentParser((TrainerCliArgs)) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) parsed_cli_args.merge_lora = True + parsed_cfg = load_cfg(config, merge_lora=True, **kwargs) do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)