From 193c73bce040fe965f5ea66d235e8823bd19e5e7 Mon Sep 17 00:00:00 2001
From: Angainor Development <54739135+AngainorDev@users.noreply.github.com>
Date: Thu, 8 Jun 2023 09:18:58 +0200
Subject: [PATCH 1/6] Fix training over existing lora

When training with LoRA and starting from existing LoRA weights, the
current code produces a model with 0 trainable params, so training
cannot work. Adding the "is_trainable" param allows the loaded PEFT
model to be trained and fixes the bug.

---
 src/axolotl/utils/models.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 58e0e97ec..b5d5124cb 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -402,6 +402,7 @@ def load_lora(model, cfg):
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
+            is_trainable=True,
             device_map=cfg.device_map,
             # torch_dtype=torch.float16,
         )

From 813cfa4c14f990c53ed42e9decd84b3e41a91102 Mon Sep 17 00:00:00 2001
From: Angainor Development <54739135+AngainorDev@users.noreply.github.com>
Date: Fri, 9 Jun 2023 08:49:32 +0200
Subject: [PATCH 2/6] WIP: Rely on cfg.inference

---
 src/axolotl/utils/models.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index b5d5124cb..c3f988e52 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -80,8 +80,7 @@ def load_model(
     model_type,
     tokenizer,
     cfg,
-    adapter="lora",
-    inference=False,
+    adapter="lora"
 ):
     # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
@@ -95,7 +94,7 @@ def load_model(
     )
 
     if is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference is False:
+        if cfg.device not in ["mps", "cpu"] and cfg.inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
 
             logging.info("patching with flash attention")
@@ -402,7 +401,7 @@ def load_lora(model, cfg):
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
-            is_trainable=True,
+            is_trainable=not cfg.inference,
             device_map=cfg.device_map,
             # torch_dtype=torch.float16,
         )

From bd3b53734459e0ace9795579150a3eff0ff4eaeb Mon Sep 17 00:00:00 2001
From: Angainor Development <54739135+AngainorDev@users.noreply.github.com>
Date: Fri, 9 Jun 2023 08:59:05 +0200
Subject: [PATCH 3/6] Feed cfg.inference

---
 scripts/finetune.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 7c4d865fa..ab8f068aa 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -182,6 +182,9 @@ def train(
     if cfg.bf16:
         cfg.fp16 = True
         cfg.bf16 = False
+
+    # Store inference mode into cfg when passed via args
+    cfg.inference = True if "inference" in kwargs else cfg.get("inference", False)
 
     # load the tokenizer first
     tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
@@ -189,8 +192,8 @@ def train(
     tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
     if check_not_in(
-        ["inference", "shard", "merge_lora"], kwargs
-    ):  # don't need to load dataset for these
+        ["shard", "merge_lora"], kwargs
+    ) and not cfg.inference:  # don't need to load dataset for these
         train_dataset, eval_dataset = load_prepare_datasets(
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
@@ -216,8 +219,7 @@ def train(
         cfg.model_type,
         tokenizer,
        cfg,
-        adapter=cfg.adapter,
-        inference=("inference" in kwargs),
+        adapter=cfg.adapter
     )
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -230,7 +232,7 @@ def train(
         model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
         return
 
-    if "inference" in kwargs:
+    if cfg.inference:
         logging.info("calling do_inference function")
         do_inference(cfg, model, tokenizer)
         return

From c2508987a6354084bf8bbf328c8eccc81e3a9814 Mon Sep 17 00:00:00 2001
From: Angainor Development <54739135+AngainorDev@users.noreply.github.com>
Date: Sat, 10 Jun 2023 19:06:10 +0200
Subject: [PATCH 4/6] Remove explicit definition of cfg.inference

---
 scripts/finetune.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index ab8f068aa..faf1bb31d 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -182,9 +182,6 @@ def train(
     if cfg.bf16:
         cfg.fp16 = True
         cfg.bf16 = False
-
-    # Store inference mode into cfg when passed via args
-    cfg.inference = True if "inference" in kwargs else cfg.get("inference", False)
 
     # load the tokenizer first
     tokenizer_config = cfg.tokenizer_config or cfg.base_model_config

From a808bf913f79fa29596e283a0bb70954caac0645 Mon Sep 17 00:00:00 2001
From: Angainor Development <54739135+AngainorDev@users.noreply.github.com>
Date: Sat, 10 Jun 2023 20:28:49 +0200
Subject: [PATCH 5/6] Fix missing cfg.

---
 src/axolotl/utils/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 7156adec0..67facd607 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -96,7 +96,7 @@ def load_model(
     )
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference is False:
+        if cfg.device not in ["mps", "cpu"] and not cfg.inference:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
 
             logging.info("patching with flash attention")

From b565ecf0a1d6bcecbcfa7366dc2ca04983ca0523 Mon Sep 17 00:00:00 2001
From: AngainorDev
Date: Sun, 11 Jun 2023 15:23:38 +0200
Subject: [PATCH 6/6] Fix strict and Lint

---
 scripts/finetune.py         | 10 +++++-----
 src/axolotl/utils/models.py |  9 ++-------
 2 files changed, 7 insertions(+), 12 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 3222afd81..49bd505ce 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -158,7 +158,7 @@ def train(
     cfg_keys = cfg.keys()
     for k, _ in kwargs.items():
         # if not strict, allow writing to cfg even if it's not in the yml already
-        if k in cfg_keys or cfg.strict is False:
+        if k in cfg_keys or not cfg.strict:
             # handle booleans
             if isinstance(cfg[k], bool):
                 cfg[k] = bool(kwargs[k])
@@ -198,9 +198,9 @@ def train(
     logging.info(f"loading tokenizer... {tokenizer_config}")
     tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
-    if check_not_in(
-        ["shard", "merge_lora"], kwargs
-    ) and not cfg.inference:  # don't need to load dataset for these
+    if (
+        check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
+    ):  # don't need to load dataset for these
         train_dataset, eval_dataset = load_prepare_datasets(
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
@@ -226,7 +226,7 @@ def train(
         cfg.model_type,
         tokenizer,
         cfg,
-        adapter=cfg.adapter
+        adapter=cfg.adapter,
     )
 
     if "merge_lora" in kwargs and cfg.adapter is not None:

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 67facd607..3a87392fc 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -77,14 +77,9 @@ def load_tokenizer(
 
 
 def load_model(
-    base_model,
-    base_model_config,
-    model_type,
-    tokenizer,
-    cfg,
-    adapter="lora"
+    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
 ):
-    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
     Load a model from a base model and a model type.
     """