Fix for check with cfg and merge_lora (#600)

This commit is contained in:
Wing Lian
2023-09-18 21:14:32 -04:00
committed by GitHub
parent 31b9e0c6e8
commit 62a774140b
3 changed files with 3 additions and 3 deletions

View File

@@ -61,7 +61,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
# cache: 'pip' # caching pip dependencies
- name: Install dependencies
run: |

View File

@@ -70,7 +70,7 @@ def do_merge_lora(
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,

View File

@@ -13,12 +13,12 @@ from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)