From 56f9ca57098bb8d4b502f48ab1516711e607a368 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 28 May 2023 22:25:42 +0900
Subject: [PATCH] refactor: fix previous refactors

---
 scripts/finetune.py         | 2 +-
 src/axolotl/utils/dict.py   | 2 +-
 src/axolotl/utils/models.py | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index b25412e7f..1d1eb9f95 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -83,7 +83,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
                 temperature=0.9,
                 top_p=0.95,
                 top_k=40,
-                return_DictDefault_in_generate=True,
+                return_dict_in_generate=True,
                 output_attentions=False,
                 output_hidden_states=False,
                 output_scores=False,
diff --git a/src/axolotl/utils/dict.py b/src/axolotl/utils/dict.py
index f7297efb2..003a9fa9e 100644
--- a/src/axolotl/utils/dict.py
+++ b/src/axolotl/utils/dict.py
@@ -6,4 +6,4 @@ class DictDefault(Dict):
     A Dict that returns None instead of returning empty Dict for missing keys.
     '''
     def __missing__(self, key):
-        return None
\ No newline at end of file
+        return None
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 80e2d2447..774802a7d 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -184,9 +184,9 @@ def load_model(
         #     # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
         #     # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
         #     # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
-        #     from flash_attn.utils.pretrained import state_DictDefault_from_pretrained
+        #     from flash_attn.utils.pretrained import state_dict_from_pretrained
         #     from flash_attn.models.gpt import GPTLMHeadModel
-        #     from flash_attn.models.gpt_neox import remap_state_DictDefault_hf_gpt_neox, gpt_neox_config_to_gpt2_config
+        #     from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
         #     from transformers import GPTNeoXConfig
         #     config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
         #     config.use_flash_attn = True
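
Note (not part of the patch): the identifiers corrected above appear to have been broken by
an earlier blanket rename of "dict" to "DictDefault". For context, here is a minimal sketch
of the DictDefault behaviour the patch preserves, assuming the class subclasses addict's
Dict as the hunk in src/axolotl/utils/dict.py suggests; the example config values are
hypothetical:

    from addict import Dict


    class DictDefault(Dict):
        '''
        A Dict that returns None instead of returning empty Dict for missing keys.
        '''

        def __missing__(self, key):
            return None


    # Missing keys resolve to None rather than an empty Dict, so config lookups
    # can rely on plain truthiness checks.
    cfg = DictDefault({"base_model": "EleutherAI/pythia-12b"})
    print(cfg.base_model)       # "EleutherAI/pythia-12b"
    print(cfg.flash_attention)  # None -> falsy, option treated as unset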