diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5530632be..67b989cb0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,5 +1,7 @@ name: PyTest -on: push +on: + push: + pull_request: jobs: test: diff --git a/requirements.txt b/requirements.txt index 1af103e17..27b31a139 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ peft @ git+https://github.com/huggingface/peft.git transformers @ git+https://github.com/huggingface/transformers.git bitsandbytes>=0.39.0 -attrdict +addict fire PyYAML==6.0 black diff --git a/scripts/finetune.py b/scripts/finetune.py index 8d7a18a4a..1d1eb9f95 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -10,11 +10,11 @@ from typing import Optional, List, Dict, Any, Union import fire import torch import yaml -from attrdict import AttrDefault # add src to the pythonpath so we don't need to pip install this from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.validation import validate_config +from axolotl.utils.dict import DictDefault project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") @@ -131,10 +131,10 @@ def train( # load the config from the yaml file with open(config, "r") as f: - cfg: AttrDefault = AttrDefault(lambda: None, yaml.load(f, Loader=yaml.Loader)) + cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader)) # if there are any options passed in the cli, if it is something that seems valid from the yaml, # then overwrite the value - cfg_keys = dict(cfg).keys() + cfg_keys = cfg.keys() for k in kwargs: # if not strict, allow writing to cfg even if it's not in the yml already if k in cfg_keys or cfg.strict is False: diff --git a/src/axolotl/utils/dict.py b/src/axolotl/utils/dict.py new file mode 100644 index 000000000..e3a0a517d --- /dev/null +++ b/src/axolotl/utils/dict.py @@ -0,0 +1,10 @@ +from addict import Dict + + +class DictDefault(Dict): + """ + A Dict that returns None instead of returning empty Dict for missing keys. + """ + + def __missing__(self, key): + return None diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fe9f18979..774802a7d 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -29,7 +29,7 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN if TYPE_CHECKING: from peft import PeftModel, PeftConfig - from attrdict import AttrDefault + from axolotl.utils.dict import DictDefault from transformers import PreTrainedTokenizer @@ -79,7 +79,7 @@ def load_model( adapter="lora", inference=False, ): - # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] + # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]] # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit @@ -294,7 +294,7 @@ def load_model( def load_adapter(model, cfg, adapter): - # type: (PreTrainedModel, AttrDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]] + # type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]] if adapter is None: return model, None @@ -307,7 +307,7 @@ def load_adapter(model, cfg, adapter): def load_llama_adapter(model, cfg): - # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] + # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] from peft import ( AdaptionPromptConfig, get_peft_model, @@ -355,7 +355,7 @@ def find_all_linear_names(bits, model): def load_lora(model, cfg): - # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] + # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] from peft import ( LoraConfig, diff --git a/tests/test_dict.py b/tests/test_dict.py new file mode 100644 index 000000000..81a528fe4 --- /dev/null +++ b/tests/test_dict.py @@ -0,0 +1,90 @@ +import unittest + +import pytest + +from axolotl.utils.dict import DictDefault + + +class DictDefaultTest(unittest.TestCase): + def test_dict_default(self): + cfg = DictDefault( + { + "key_a": {"key_b": "value_a"}, + "key_c": "value_c", + "key_d": ["value_d", "value_e"], + } + ) + + assert ( + cfg.key_a.key_b == "value_a" + ), "DictDefault should return value for existing nested keys" + + assert ( + cfg.key_c == "value_c" + ), "DictDefault should return value for existing keys" + + assert ( + cfg.key_d[0] == "value_d" + ), "DictDefault should return value for existing keys in list" + + assert ( + "value_e" in cfg.key_d + ), "DictDefault should support in operator for existing keys in list" + + def test_dict_or_operator(self): + cfg = DictDefault( + { + "key_a": {"key_b": "value_a"}, + "key_c": "value_c", + "key_d": ["value_d", "value_e"], + "key_f": "value_f", + } + ) + + cfg = cfg | DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"}) + + assert ( + cfg.key_a.key_b == "value_b" + ), "DictDefault should support OR operator for existing nested keys" + + assert cfg.key_c == "value_c", "DictDefault should not delete existing key" + + assert cfg.key_d == [ + "value_d", + "value_e", + ], "DictDefault should not overwrite existing keys in list" + + assert ( + cfg.key_f == "value_g" + ), "DictDefault should support OR operator for existing key" + + def test_dict_missingkey(self): + cfg = DictDefault({}) + + assert cfg.random_key is None, "DictDefault should return None for missing keys" + + def test_dict_nested_missingparentkey(self): + """ + Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist. + """ + cfg = DictDefault({}) + + with pytest.raises( + AttributeError, + match=r"'NoneType' object has no attribute 'another_random_key'", + ): + cfg.random_key.another_random_key + + def test_dict_shorthand_assignment(self): + """ + Shorthand assignment is said to not be supported if subclassed. However, their example raises error instead of None. + This test ensures that it is supported for current implementation. + + Ref: https://github.com/mewwts/addict#default-values + """ + + cfg = DictDefault({"key_a": {"key_b": "value_a"}}) + + cfg.key_a.key_b = "value_b" + + assert cfg.key_a.key_b == "value_b", "Shorthand assignment should be supported"