From bdfe7c9201eeb1ae159fc60560534f5e6b817566 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 28 May 2023 02:28:06 +0900
Subject: [PATCH 1/9] Convert attrdict to addict

---
 requirements.txt            |  2 +-
 scripts/finetune.py         |  4 ++--
 src/axolotl/utils/models.py | 10 +++++-----
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 1af103e17..27b31a139 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
 bitsandbytes>=0.39.0
-attrdict
+addict
 fire
 PyYAML==6.0
 black
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 8d7a18a4a..954ce1625 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -10,7 +10,7 @@ from typing import Optional, List, Dict, Any, Union
 import fire
 import torch
 import yaml
-from attrdict import AttrDefault
+from addict import Dict
 
 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
@@ -131,7 +131,7 @@ def train(
 
     # load the config from the yaml file
    with open(config, "r") as f:
-        cfg: AttrDefault = AttrDefault(lambda: None, yaml.load(f, Loader=yaml.Loader))
+        cfg: Dict = Dict(lambda: None, yaml.load(f, Loader=yaml.Loader))
     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
     # then overwrite the value
     cfg_keys = dict(cfg).keys()
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index fe9f18979..6538086fb 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -29,7 +29,7 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
     from peft import PeftModel, PeftConfig
-    from attrdict import AttrDefault
+    from addict import Dict
     from transformers import PreTrainedTokenizer
 
 
@@ -79,7 +79,7 @@ def load_model(
     adapter="lora",
     inference=False,
 ):
-    # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (str, str, str, str, Dict, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
@@ -294,7 +294,7 @@ def load_model(
 
 
 def load_adapter(model, cfg, adapter):
-    # type: (PreTrainedModel, AttrDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (PreTrainedModel, Dict, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
     if adapter is None:
         return model, None
@@ -307,7 +307,7 @@ def load_adapter(model, cfg, adapter):
 
 
 def load_llama_adapter(model, cfg):
-    # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (PreTrainedModel, Dict) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     from peft import (
         AdaptionPromptConfig,
         get_peft_model,
@@ -355,7 +355,7 @@ def find_all_linear_names(bits, model):
 
 
 def load_lora(model, cfg):
-    # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (PreTrainedModel, Dict) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
     from peft import (
         LoraConfig,

From 93acb648bda4a138dcc3c3187d0355d395ed3c6a Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 28 May 2023 02:55:46 +0900
Subject: [PATCH 2/9] Fix load error

---
 scripts/finetune.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 954ce1625..57d2520d3 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -131,10 +131,10 @@ def train(
 
     # load the config from the yaml file
     with open(config, "r") as f:
-        cfg: Dict = Dict(lambda: None, yaml.load(f, Loader=yaml.Loader))
+        cfg: Dict = Dict(yaml.load(f, Loader=yaml.Loader))
     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
     # then overwrite the value
-    cfg_keys = dict(cfg).keys()
+    cfg_keys = cfg.keys()
     for k in kwargs:
         # if not strict, allow writing to cfg even if it's not in the yml already
         if k in cfg_keys or cfg.strict is False:
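
Reviewer note on the two patches above: `attrdict.AttrDefault` took a default
factory as its first positional argument, while `addict.Dict` treats positional
arguments as mappings to merge, which is what broke the straight port in patch 1
and is dropped again in patch 2. The remaining behavioral gap is that addict
returns a fresh empty `Dict` for missing keys rather than `None`. A minimal
sketch of that gap, not part of the patch (the `some_unset_option` key is
hypothetical, and `addict` must be installed):

```python
from addict import Dict

cfg = Dict({"load_in_8bit": True})

# attrdict's AttrDefault(lambda: None, ...) returned None for unset options;
# addict instead returns a fresh, falsy-but-not-None empty Dict:
missing = cfg.some_unset_option
print(missing)          # {} -- an empty Dict
print(missing is None)  # False -- `cfg.x is None` checks silently stop firing
```

Patch 3 below closes this gap by overriding `__missing__`.
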
From 18d41cee4a80fbf3aa029f1ef90b4b6f2fdb41cd Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 28 May 2023 22:08:39 +0900
Subject: [PATCH 3/9] Add DictDefault

---
 src/axolotl/utils/dict.py | 9 +++++++++
 1 file changed, 9 insertions(+)
 create mode 100644 src/axolotl/utils/dict.py

diff --git a/src/axolotl/utils/dict.py b/src/axolotl/utils/dict.py
new file mode 100644
index 000000000..f7297efb2
--- /dev/null
+++ b/src/axolotl/utils/dict.py
@@ -0,0 +1,9 @@
+from addict import Dict
+
+
+class DictDefault(Dict):
+    '''
+    A Dict that returns None instead of returning empty Dict for missing keys.
+    '''
+    def __missing__(self, key):
+        return None
\ No newline at end of file

From 8bd7a49cd7be78e370db45df8c985a30754001a6 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 28 May 2023 22:09:04 +0900
Subject: [PATCH 4/9] Refactor to use DictDefault instead

---
 scripts/finetune.py         |  6 +++---
 src/axolotl/utils/models.py | 14 +++++++-------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 57d2520d3..b25412e7f 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -10,11 +10,11 @@ from typing import Optional, List, Dict, Any, Union
 import fire
 import torch
 import yaml
-from addict import Dict
 
 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.validation import validate_config
+from axolotl.utils.dict import DictDefault
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
@@ -83,7 +83,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
             temperature=0.9,
             top_p=0.95,
             top_k=40,
-            return_dict_in_generate=True,
+            return_DictDefault_in_generate=True,
             output_attentions=False,
             output_hidden_states=False,
             output_scores=False,
@@ -131,7 +131,7 @@ def train(
 
     # load the config from the yaml file
     with open(config, "r") as f:
-        cfg: Dict = Dict(yaml.load(f, Loader=yaml.Loader))
+        cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader))
     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
     # then overwrite the value
     cfg_keys = cfg.keys()
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 6538086fb..80e2d2447 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -29,7 +29,7 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
     from peft import PeftModel, PeftConfig
-    from addict import Dict
+    from axolotl.utils.dict import DictDefault
     from transformers import PreTrainedTokenizer
 
 
@@ -79,7 +79,7 @@ def load_model(
     adapter="lora",
     inference=False,
 ):
-    # type: (str, str, str, str, Dict, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+    # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
@@ -184,9 +184,9 @@ def load_model(
 #     # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
 #     # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
 #     # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
-#     from flash_attn.utils.pretrained import state_dict_from_pretrained
+#     from flash_attn.utils.pretrained import state_DictDefault_from_pretrained
 #     from flash_attn.models.gpt import GPTLMHeadModel
-#     from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
+#     from flash_attn.models.gpt_neox import remap_state_DictDefault_hf_gpt_neox, gpt_neox_config_to_gpt2_config
 #     from transformers import GPTNeoXConfig
 #     config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
 #     config.use_flash_attn = True
@@ -294,7 +294,7 @@ def load_model(
 
 
 def load_adapter(model, cfg, adapter):
-    # type: (PreTrainedModel, Dict, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
     if adapter is None:
         return model, None
@@ -307,7 +307,7 @@ def load_adapter(model, cfg, adapter):
 
 
 def load_llama_adapter(model, cfg):
-    # type: (PreTrainedModel, Dict) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     from peft import (
         AdaptionPromptConfig,
         get_peft_model,
@@ -355,7 +355,7 @@ def find_all_linear_names(bits, model):
 
 
 def load_lora(model, cfg):
-    # type: (PreTrainedModel, Dict) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
     from peft import (
         LoraConfig,
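
Reviewer note on patch 4: with it applied, the YAML config loads directly into
`DictDefault`, so unset options once again read as `None`. A small usage
sketch, not part of the patch (the `lora_r` and `base_model` keys are
illustrative, and `yaml.safe_load` is used here for brevity where the script
uses `yaml.load` with an explicit loader):

```python
import yaml

from axolotl.utils.dict import DictDefault

raw = "load_in_8bit: true\nbase_model: some/base-model\n"
cfg = DictDefault(yaml.safe_load(raw))

print(cfg.load_in_8bit)  # True -- YAML keys keep attribute-style access
print(cfg.lora_r)        # None -- unset options compare cleanly against None
```

Note also that the blanket rename in this patch accidentally rewrote
`return_dict_in_generate` and the commented flash-attn imports; patch 5
reverts exactly those lines.
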
From 56f9ca57098bb8d4b502f48ab1516711e607a368 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 28 May 2023 22:25:42 +0900
Subject: [PATCH 5/9] refactor: fix previous refactors

---
 scripts/finetune.py         | 2 +-
 src/axolotl/utils/dict.py   | 2 +-
 src/axolotl/utils/models.py | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index b25412e7f..1d1eb9f95 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -83,7 +83,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
             temperature=0.9,
             top_p=0.95,
             top_k=40,
-            return_DictDefault_in_generate=True,
+            return_dict_in_generate=True,
             output_attentions=False,
             output_hidden_states=False,
             output_scores=False,
diff --git a/src/axolotl/utils/dict.py b/src/axolotl/utils/dict.py
index f7297efb2..003a9fa9e 100644
--- a/src/axolotl/utils/dict.py
+++ b/src/axolotl/utils/dict.py
@@ -6,4 +6,4 @@ class DictDefault(Dict):
     A Dict that returns None instead of returning empty Dict for missing keys.
     '''
     def __missing__(self, key):
-        return None
\ No newline at end of file
+        return None
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 80e2d2447..774802a7d 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -184,9 +184,9 @@ def load_model(
 #     # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
 #     # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
 #     # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
-#     from flash_attn.utils.pretrained import state_DictDefault_from_pretrained
+#     from flash_attn.utils.pretrained import state_dict_from_pretrained
 #     from flash_attn.models.gpt import GPTLMHeadModel
-#     from flash_attn.models.gpt_neox import remap_state_DictDefault_hf_gpt_neox, gpt_neox_config_to_gpt2_config
+#     from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
 #     from transformers import GPTNeoXConfig
 #     config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
 #     config.use_flash_attn = True
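
Reviewer note ahead of the tests: one deliberate consequence of `__missing__`
returning `None` is that chained access through a missing parent key now raises
instead of silently auto-nesting the way stock addict does. A short sketch of
the behavior the next patch pins down in its test suite:

```python
from axolotl.utils.dict import DictDefault

cfg = DictDefault({})

print(cfg.random_key)  # None -- top-level misses are fine

# Stock addict would hand back an empty Dict here and keep chaining;
# DictDefault's None surfaces the typo as an AttributeError instead:
try:
    cfg.random_key.another_random_key
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'another_random_key'
```
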
From 923151ffab445d8662a6d5d9012f371c2dd3b948 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 28 May 2023 23:05:09 +0900
Subject: [PATCH 6/9] Add test for DictDefault

---
 tests/test_dict.py | 91 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 91 insertions(+)
 create mode 100644 tests/test_dict.py

diff --git a/tests/test_dict.py b/tests/test_dict.py
new file mode 100644
index 000000000..aea932a16
--- /dev/null
+++ b/tests/test_dict.py
@@ -0,0 +1,91 @@
+import unittest
+
+import pytest
+
+from axolotl.utils.dict import DictDefault
+
+
+class DictDefaultTest(unittest.TestCase):
+    def test_dict_default(self):
+        cfg = DictDefault(
+            {
+                "key_a": {"key_b": "value_a"},
+                "key_c": "value_c",
+                "key_d": ["value_d", "value_e"],
+            }
+        )
+
+        assert (
+            cfg.key_a.key_b == "value_a"
+        ), "DictDefault should return value for existing nested keys"
+
+        assert (
+            cfg.key_c == "value_c"
+        ), "DictDefault should return value for existing keys"
+
+        assert (
+            cfg.key_d[0] == "value_d"
+        ), "DictDefault should return value for existing keys in list"
+
+        assert (
+            "value_e" in cfg.key_d,
+            "DictDefault should support in operator for existing keys in list",
+        )
+
+    def test_dict_or_operator(self):
+        cfg = DictDefault(
+            {
+                "key_a": {"key_b": "value_a"},
+                "key_c": "value_c",
+                "key_d": ["value_d", "value_e"],
+                "key_f": "value_f",
+            }
+        )
+
+        cfg = cfg | DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
+
+        assert (
+            cfg.key_a.key_b == "value_b"
+        ), "DictDefault should support OR operator for existing nested keys"
+
+        assert cfg.key_c == "value_c", "DictDefault should not delete existing key"
+
+        assert cfg.key_d == [
+            "value_d",
+            "value_e",
+        ], "DictDefault should not overwrite existing keys in list"
+
+        assert (
+            cfg.key_f == "value_g"
+        ), "DictDefault should support OR operator for existing key"
+
+    def test_dict_missingkey(self):
+        cfg = DictDefault({})
+
+        assert cfg.random_key is None, "DictDefault should return None for missing keys"
+
+    def test_dict_nested_missingparentkey(self):
+        """
+        Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.
+        """
+        cfg = DictDefault({})
+
+        with pytest.raises(
+            AttributeError,
+            match=r"'NoneType' object has no attribute 'another_random_key'",
+        ):
+            cfg.random_key.another_random_key
+
+    def test_dict_shorthand_assignment(self):
+        """
+        Shorthand assignment is said to not be supported if subclassed. However, their example raises error instead of None.
+        This test ensures that it is supported for current implementation.
+
+        Ref: https://github.com/mewwts/addict#default-values
+        """
+
+        cfg = DictDefault({"key_a": {"key_b": "value_a"}})
+
+        cfg.key_a.key_b = "value_b"
+
+        assert cfg.key_a.key_b == "value_b", "Shorthand assignment should be supported"
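
Reviewer note on the `in`-operator assertion above: wrapping the condition and
the message in one parenthesized, comma-separated pair builds a two-element
tuple, and any non-empty tuple is truthy, so that assert can never fail (CPython
even warns "assertion is always true" for this pattern). Patch 9 at the end of
the series fixes it; a standalone reproduction of the pitfall:

```python
# Always "passes", even though the condition is False: the parenthesized
# condition-comma-message pair is a tuple, not a condition plus message.
assert (
    "value_x" in ["value_d", "value_e"],
    "this string is the tuple's second element, not an assert message",
)

# The corrected form from patch 9: the comma sits outside the parentheses,
# so the condition is actually evaluated and the string is the message.
assert (
    "value_e" in ["value_d", "value_e"]
), "DictDefault should support in operator for existing keys in list"
```
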
+ """ + cfg = DictDefault({}) + + with pytest.raises( + AttributeError, + match=r"'NoneType' object has no attribute 'another_random_key'", + ): + cfg.random_key.another_random_key + + def test_dict_shorthand_assignment(self): + """ + Shorthand assignment is said to not be supported if subclassed. However, their example raises error instead of None. + This test ensures that it is supported for current implementation. + + Ref: https://github.com/mewwts/addict#default-values + """ + + cfg = DictDefault({"key_a": {"key_b": "value_a"}}) + + cfg.key_a.key_b = "value_b" + + assert cfg.key_a.key_b == "value_b", "Shorthand assignment should be supported" From 7bf2069afd349e1e049b2086c5cbc02b5ab0430e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 28 May 2023 23:14:04 +0900 Subject: [PATCH 7/9] Apply black formatter --- src/axolotl/utils/dict.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/dict.py b/src/axolotl/utils/dict.py index 003a9fa9e..e3a0a517d 100644 --- a/src/axolotl/utils/dict.py +++ b/src/axolotl/utils/dict.py @@ -2,8 +2,9 @@ from addict import Dict class DictDefault(Dict): - ''' + """ A Dict that returns None instead of returning empty Dict for missing keys. - ''' + """ + def __missing__(self, key): return None From dd83a20c27bf0ad2c8e34bff85368c88531a7bf7 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 28 May 2023 23:30:17 +0900 Subject: [PATCH 8/9] Update test to run on PR --- .github/workflows/tests.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5530632be..67b989cb0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,5 +1,7 @@ name: PyTest -on: push +on: + push: + pull_request: jobs: test: From f87bd205556657725a1433fc8cdd56dd79e1e1ca Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 28 May 2023 23:35:29 +0900 Subject: [PATCH 9/9] Fix incorrect syntax in test --- tests/test_dict.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_dict.py b/tests/test_dict.py index aea932a16..81a528fe4 100644 --- a/tests/test_dict.py +++ b/tests/test_dict.py @@ -28,9 +28,8 @@ class DictDefaultTest(unittest.TestCase): ), "DictDefault should return value for existing keys in list" assert ( - "value_e" in cfg.key_d, - "DictDefault should support in operator for existing keys in list", - ) + "value_e" in cfg.key_d + ), "DictDefault should support in operator for existing keys in list" def test_dict_or_operator(self): cfg = DictDefault(