diff --git a/examples/cerebras/btlm-ft.yml b/examples/cerebras/btlm-ft.yml index 4fd34aa5f..e598fc865 100644 --- a/examples/cerebras/btlm-ft.yml +++ b/examples/cerebras/btlm-ft.yml @@ -1,5 +1,4 @@ base_model: cerebras/btlm-3b-8k-base -base_model_config: cerebras/btlm-3b-8k-base model_type: AutoModelForCausalLM tokenizer_type: GPT2Tokenizer trust_remote_code: true diff --git a/examples/cerebras/qlora.yml b/examples/cerebras/qlora.yml index a13517f3e..352952a54 100644 --- a/examples/cerebras/qlora.yml +++ b/examples/cerebras/qlora.yml @@ -1,5 +1,4 @@ base_model: cerebras/Cerebras-GPT-1.3B -base_model_config: cerebras/Cerebras-GPT-1.3B load_in_8bit: false load_in_4bit: true strict: false diff --git a/examples/code-llama/13b/lora.yml b/examples/code-llama/13b/lora.yml index 91807846b..2909e2477 100644 --- a/examples/code-llama/13b/lora.yml +++ b/examples/code-llama/13b/lora.yml @@ -1,5 +1,4 @@ base_model: codellama/CodeLlama-13b-hf -base_model_config: codellama/CodeLlama-13b-hf model_type: LlamaForCausalLM tokenizer_type: CodeLlamaTokenizer is_llama_derived_model: true diff --git a/examples/code-llama/13b/qlora.yml b/examples/code-llama/13b/qlora.yml index 9fa05ffab..dff95a5aa 100644 --- a/examples/code-llama/13b/qlora.yml +++ b/examples/code-llama/13b/qlora.yml @@ -1,5 +1,4 @@ base_model: codellama/CodeLlama-13b-hf -base_model_config: codellama/CodeLlama-13b-hf model_type: LlamaForCausalLM tokenizer_type: CodeLlamaTokenizer is_llama_derived_model: true diff --git a/examples/code-llama/34b/lora.yml b/examples/code-llama/34b/lora.yml index a342b6ebc..5601b2e0b 100644 --- a/examples/code-llama/34b/lora.yml +++ b/examples/code-llama/34b/lora.yml @@ -1,5 +1,4 @@ base_model: codellama/CodeLlama-34b-hf -base_model_config: codellama/CodeLlama-34b-hf model_type: LlamaForCausalLM tokenizer_type: CodeLlamaTokenizer is_llama_derived_model: true diff --git a/examples/code-llama/34b/qlora.yml b/examples/code-llama/34b/qlora.yml index 1501dd9a3..71a39e534 100644 --- a/examples/code-llama/34b/qlora.yml +++ b/examples/code-llama/34b/qlora.yml @@ -1,5 +1,4 @@ base_model: codellama/CodeLlama-34b-hf -base_model_config: codellama/CodeLlama-34b-hf model_type: LlamaForCausalLM tokenizer_type: CodeLlamaTokenizer is_llama_derived_model: true diff --git a/examples/code-llama/7b/lora.yml b/examples/code-llama/7b/lora.yml index 638dddc43..345745681 100644 --- a/examples/code-llama/7b/lora.yml +++ b/examples/code-llama/7b/lora.yml @@ -1,5 +1,4 @@ base_model: codellama/CodeLlama-7b-hf -base_model_config: codellama/CodeLlama-7b-hf model_type: LlamaForCausalLM tokenizer_type: CodeLlamaTokenizer is_llama_derived_model: true diff --git a/examples/code-llama/7b/qlora.yml b/examples/code-llama/7b/qlora.yml index 5b3b33822..25357ad28 100644 --- a/examples/code-llama/7b/qlora.yml +++ b/examples/code-llama/7b/qlora.yml @@ -1,5 +1,4 @@ base_model: codellama/CodeLlama-7b-hf -base_model_config: codellama/CodeLlama-7b-hf model_type: LlamaForCausalLM tokenizer_type: CodeLlamaTokenizer is_llama_derived_model: true diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml index f45deb643..d46702152 100644 --- a/examples/falcon/config-7b-lora.yml +++ b/examples/falcon/config-7b-lora.yml @@ -1,5 +1,4 @@ base_model: tiiuae/falcon-7b -base_model_config: tiiuae/falcon-7b trust_remote_code: true model_type: AutoModelForCausalLM tokenizer_type: AutoTokenizer diff --git a/examples/falcon/config-7b-qlora.yml b/examples/falcon/config-7b-qlora.yml index f59341965..3c201ff5f 100644 --- a/examples/falcon/config-7b-qlora.yml +++ b/examples/falcon/config-7b-qlora.yml @@ -1,7 +1,6 @@ # 1b: tiiuae/falcon-rw-1b # 40b: tiiuae/falcon-40b base_model: tiiuae/falcon-7b -base_model_config: tiiuae/falcon-7b # required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main trust_remote_code: true model_type: AutoModelForCausalLM diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml index 777a97b31..96039db60 100644 --- a/examples/falcon/config-7b.yml +++ b/examples/falcon/config-7b.yml @@ -1,5 +1,4 @@ base_model: tiiuae/falcon-7b -base_model_config: tiiuae/falcon-7b trust_remote_code: true model_type: AutoModelForCausalLM tokenizer_type: AutoTokenizer diff --git a/examples/gptj/qlora.yml b/examples/gptj/qlora.yml index 696747dfe..b9455624a 100644 --- a/examples/gptj/qlora.yml +++ b/examples/gptj/qlora.yml @@ -1,5 +1,4 @@ base_model: EleutherAI/gpt-j-6b -base_model_config: EleutherAI/gpt-j-6b load_in_8bit: false load_in_4bit: true strict: false diff --git a/examples/jeopardy-bot/config.yml b/examples/jeopardy-bot/config.yml index 32e7a34ee..710a74fdf 100644 --- a/examples/jeopardy-bot/config.yml +++ b/examples/jeopardy-bot/config.yml @@ -1,5 +1,4 @@ base_model: huggyllama/llama-7b -base_model_config: huggyllama/llama-7b model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer load_in_8bit: false diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml index a96c1cfb8..e7bbfc1c9 100644 --- a/examples/llama-2/fft_optimized.yml +++ b/examples/llama-2/fft_optimized.yml @@ -1,5 +1,4 @@ base_model: NousResearch/Llama-2-7b-hf -base_model_config: NousResearch/Llama-2-7b-hf model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer is_llama_derived_model: true diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml index 257433f26..759b304d8 100644 --- a/examples/llama-2/gptq-lora.yml +++ b/examples/llama-2/gptq-lora.yml @@ -1,5 +1,4 @@ base_model: TheBloke/Llama-2-7B-GPTQ -base_model_config: TheBloke/Llama-2-7B-GPTQ is_llama_derived_model: false gptq: true gptq_disable_exllama: true diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index 8c0e3e910..5afe3d7d1 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -1,5 +1,4 @@ base_model: NousResearch/Llama-2-7b-hf -base_model_config: NousResearch/Llama-2-7b-hf model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer is_llama_derived_model: true diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index b8209934c..447761f7e 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -1,5 +1,4 @@ base_model: NousResearch/Llama-2-7b-hf -base_model_config: NousResearch/Llama-2-7b-hf model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer is_llama_derived_model: true diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml index 9f27cafea..2e6923811 100644 --- a/examples/llama-2/relora.yml +++ b/examples/llama-2/relora.yml @@ -1,5 +1,4 @@ base_model: NousResearch/Llama-2-7b-hf -base_model_config: NousResearch/Llama-2-7b-hf model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer is_llama_derived_model: true diff --git a/examples/llama-2/tiny-llama.yml b/examples/llama-2/tiny-llama.yml index 0b56ea7d3..af05830aa 100644 --- a/examples/llama-2/tiny-llama.yml +++ b/examples/llama-2/tiny-llama.yml @@ -1,5 +1,4 @@ base_model: PY007/TinyLlama-1.1B-step-50K-105b -base_model_config: PY007/TinyLlama-1.1B-step-50K-105b model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml index 2a4498a11..67a663a62 100644 --- a/examples/mistral/config.yml +++ b/examples/mistral/config.yml @@ -1,5 +1,4 @@ base_model: mistralai/Mistral-7B-v0.1 -base_model_config: mistralai/Mistral-7B-v0.1 model_type: MistralForCausalLM tokenizer_type: LlamaTokenizer is_mistral_derived_model: true diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml index 09639d006..4312a0265 100644 --- a/examples/mistral/qlora.yml +++ b/examples/mistral/qlora.yml @@ -1,5 +1,4 @@ base_model: mistralai/Mistral-7B-v0.1 -base_model_config: mistralai/Mistral-7B-v0.1 model_type: MistralForCausalLM tokenizer_type: LlamaTokenizer is_mistral_derived_model: true diff --git a/examples/mpt-7b/config.yml b/examples/mpt-7b/config.yml index 8d9b429b1..7d124e2c0 100644 --- a/examples/mpt-7b/config.yml +++ b/examples/mpt-7b/config.yml @@ -1,5 +1,4 @@ base_model: mosaicml/mpt-7b -base_model_config: mosaicml/mpt-7b tokenizer_type: AutoTokenizer trust_remote_code: true # required for mpt as their model class is not merged into transformers yet load_in_8bit: false diff --git a/examples/openllama-3b/config.yml b/examples/openllama-3b/config.yml index dd11d53b0..df6b26893 100644 --- a/examples/openllama-3b/config.yml +++ b/examples/openllama-3b/config.yml @@ -1,5 +1,4 @@ base_model: openlm-research/open_llama_3b_v2 -base_model_config: openlm-research/open_llama_3b_v2 model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer load_in_8bit: false diff --git a/examples/openllama-3b/lora.yml b/examples/openllama-3b/lora.yml index fad3fb551..7221abcbd 100644 --- a/examples/openllama-3b/lora.yml +++ b/examples/openllama-3b/lora.yml @@ -1,5 +1,4 @@ base_model: openlm-research/open_llama_3b_v2 -base_model_config: openlm-research/open_llama_3b_v2 model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer load_in_8bit: true diff --git a/examples/openllama-3b/qlora.yml b/examples/openllama-3b/qlora.yml index 80d4d727b..9fe5968d2 100644 --- a/examples/openllama-3b/qlora.yml +++ b/examples/openllama-3b/qlora.yml @@ -1,5 +1,4 @@ base_model: openlm-research/open_llama_3b_v2 -base_model_config: openlm-research/open_llama_3b_v2 model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer load_in_8bit: false diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml index 668eea317..183a715e3 100644 --- a/examples/phi/phi-ft.yml +++ b/examples/phi/phi-ft.yml @@ -1,5 +1,4 @@ base_model: microsoft/phi-1_5 -base_model_config: microsoft/phi-1_5 model_type: MixFormerSequentialForCausalLM tokenizer_type: AutoTokenizer is_llama_derived_model: false diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml index a548b3f05..8fe5e98b1 100644 --- a/examples/phi/phi-qlora.yml +++ b/examples/phi/phi-qlora.yml @@ -1,5 +1,4 @@ base_model: microsoft/phi-1_5 -base_model_config: microsoft/phi-1_5 model_type: AutoModelForCausalLM tokenizer_type: AutoTokenizer is_llama_derived_model: false diff --git a/examples/pythia-12b/config.yml b/examples/pythia-12b/config.yml index 4e0e1523a..00693a164 100644 --- a/examples/pythia-12b/config.yml +++ b/examples/pythia-12b/config.yml @@ -1,5 +1,4 @@ base_model: EleutherAI/pythia-12b-deduped -base_model_config: EleutherAI/pythia-12b-deduped base_model_ignore_patterns: pytorch* # prefer safetensors model_type: GPTNeoXForCausalLM tokenizer_type: AutoTokenizer diff --git a/examples/pythia/lora.yml b/examples/pythia/lora.yml index 6ff036621..c256429d8 100644 --- a/examples/pythia/lora.yml +++ b/examples/pythia/lora.yml @@ -1,5 +1,4 @@ base_model: EleutherAI/pythia-1.4b-deduped -base_model_config: EleutherAI/pythia-1.4b-deduped load_in_8bit: true datasets: - path: teknium/GPT4-LLM-Cleaned diff --git a/examples/redpajama/config-3b.yml b/examples/redpajama/config-3b.yml index 97f31c87a..30c198193 100644 --- a/examples/redpajama/config-3b.yml +++ b/examples/redpajama/config-3b.yml @@ -1,5 +1,4 @@ base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1 -base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1 model_type: GPTNeoXForCausalLM tokenizer_type: AutoTokenizer trust_remote_code: diff --git a/examples/replit-3b/config-lora.yml b/examples/replit-3b/config-lora.yml index d345e25a0..cc882c212 100644 --- a/examples/replit-3b/config-lora.yml +++ b/examples/replit-3b/config-lora.yml @@ -1,5 +1,4 @@ base_model: replit/replit-code-v1-3b -base_model_config: replit/replit-code-v1-3b trust_remote_code: true load_in_8bit: false datasets: diff --git a/examples/xgen-7b/xgen-7b-8k-qlora.yml b/examples/xgen-7b/xgen-7b-8k-qlora.yml index 352dcb610..f6fced944 100644 --- a/examples/xgen-7b/xgen-7b-8k-qlora.yml +++ b/examples/xgen-7b/xgen-7b-8k-qlora.yml @@ -1,7 +1,6 @@ # An example finetuning Saleforce's XGen-7b model with 8k context using qlora # on Tim Dettmer's Guanaco dataset. base_model: Salesforce/xgen-7b-8k-base -base_model_config: Salesforce/xgen-7b-8k-base trust_remote_code: true model_type: AutoModelForCausalLM tokenizer_type: AutoTokenizer diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 9260577db..81660ae65 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -79,6 +79,9 @@ def normalize_config(cfg): cfg.dataset_processes = cfg.dataset_processes or os.cpu_count() + if not cfg.base_model_config: + cfg.base_model_config = cfg.base_model + model_config = load_model_config(cfg) cfg.model_config_type = model_config.model_type diff --git a/tests/e2e/test_fused_llama.py b/tests/e2e/test_fused_llama.py index beb41bdee..9363f333c 100644 --- a/tests/e2e/test_fused_llama.py +++ b/tests/e2e/test_fused_llama.py @@ -31,7 +31,6 @@ class TestFusedLlama(unittest.TestCase): cfg = DictDefault( { "base_model": "JackFram/llama-68m", - "base_model_config": "JackFram/llama-68m", "flash_attention": True, "flash_attn_fuse_qkv": True, "flash_attn_fuse_mlp": True, diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index 7d4b75cce..4f50d8194 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -29,7 +29,6 @@ class TestLoraLlama(unittest.TestCase): cfg = DictDefault( { "base_model": "JackFram/llama-68m", - "base_model_config": "JackFram/llama-68m", "tokenizer_type": "LlamaTokenizer", "sequence_len": 1024, "load_in_8bit": True, @@ -72,7 +71,6 @@ class TestLoraLlama(unittest.TestCase): cfg = DictDefault( { "base_model": "JackFram/llama-68m", - "base_model_config": "JackFram/llama-68m", "tokenizer_type": "LlamaTokenizer", "sequence_len": 1024, "sample_packing": True, @@ -117,7 +115,6 @@ class TestLoraLlama(unittest.TestCase): cfg = DictDefault( { "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ", - "base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ", "model_type": "AutoModelForCausalLM", "tokenizer_type": "LlamaTokenizer", "sequence_len": 1024, diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py index f3098f058..f2928a727 100644 --- a/tests/e2e/test_mistral.py +++ b/tests/e2e/test_mistral.py @@ -31,7 +31,6 @@ class TestMistral(unittest.TestCase): cfg = DictDefault( { "base_model": "openaccess-ai-collective/tiny-mistral", - "base_model_config": "openaccess-ai-collective/tiny-mistral", "flash_attention": True, "sequence_len": 1024, "load_in_8bit": True, @@ -77,7 +76,6 @@ class TestMistral(unittest.TestCase): cfg = DictDefault( { "base_model": "openaccess-ai-collective/tiny-mistral", - "base_model_config": "openaccess-ai-collective/tiny-mistral", "flash_attention": True, "sequence_len": 1024, "val_set_size": 0.1, diff --git a/tests/e2e/test_mistral_samplepack.py b/tests/e2e/test_mistral_samplepack.py index 623d20b0c..5fadf0959 100644 --- a/tests/e2e/test_mistral_samplepack.py +++ b/tests/e2e/test_mistral_samplepack.py @@ -31,7 +31,6 @@ class TestMistral(unittest.TestCase): cfg = DictDefault( { "base_model": "openaccess-ai-collective/tiny-mistral", - "base_model_config": "openaccess-ai-collective/tiny-mistral", "flash_attention": True, "sample_packing": True, "sequence_len": 1024, @@ -78,7 +77,6 @@ class TestMistral(unittest.TestCase): cfg = DictDefault( { "base_model": "openaccess-ai-collective/tiny-mistral", - "base_model_config": "openaccess-ai-collective/tiny-mistral", "flash_attention": True, "sample_packing": True, "sequence_len": 1024, diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index a84ef0778..f9ea52ea2 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -27,7 +27,6 @@ class TestPhi(unittest.TestCase): cfg = DictDefault( { "base_model": "microsoft/phi-1_5", - "base_model_config": "microsoft/phi-1_5", "trust_remote_code": True, "model_type": "MixFormerSequentialForCausalLM", "tokenizer_type": "AutoTokenizer", @@ -71,7 +70,6 @@ class TestPhi(unittest.TestCase): cfg = DictDefault( { "base_model": "microsoft/phi-1_5", - "base_model_config": "microsoft/phi-1_5", "trust_remote_code": True, "model_type": "MixFormerSequentialForCausalLM", "tokenizer_type": "AutoTokenizer", diff --git a/tests/test_normalize_config.py b/tests/test_normalize_config.py index 01b8c162c..1397b23af 100644 --- a/tests/test_normalize_config.py +++ b/tests/test_normalize_config.py @@ -37,3 +37,10 @@ class NormalizeConfigTestCase(unittest.TestCase): normalize_config(cfg) assert cfg.learning_rate == 0.00005 + + def test_base_model_config_set_when_empty(self): + cfg = self._get_base_cfg() + del cfg.base_model_config + normalize_config(cfg) + + assert cfg.base_model_config == cfg.base_model