diff --git a/.gitignore b/.gitignore
index b7a09516c..93a4f81b5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,163 @@
**/axolotl.egg-info
-**/__pycache__
-.idea
configs
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, when collaborating, platform-specific dependencies or dependencies that lack
+# cross-platform support can lead pipenv to install dependencies that don't work, or to
+# miss dependencies that are needed.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
\ No newline at end of file
diff --git a/README.md b/README.md
index f3985fea8..f79a49a1f 100644
--- a/README.md
+++ b/README.md
@@ -97,6 +97,18 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"instruction": "...", "input": "...", "output": "...", "reflection": "...", "corrected": "..."}
```
+- `explainchoice`: question, choices, (solution OR explanation)
+ ```json
+ {"question": "...", "choices": ["..."], "solution": "...", "explanation": "..."}
+ ```
+- `concisechoice`: question, choices, (solution OR explanation)
+ ```json
+ {"question": "...", "choices": ["..."], "solution": "...", "explanation": "..."}
+ ```
+- `summarizetldr`: article and summary
+ ```json
+ {"article": "...", "summary": "..."}
+ ```
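+
+A dataset in one of these formats is selected through the config's `datasets` section via its `type`; a minimal sketch (the path is illustrative):
+
+```yaml
+datasets:
+  - path: ./data/tldr.jsonl
+    type: summarizetldr
+```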
> Have some new format to propose? Check if it's already defined in [data.py](src/axolotl/utils/data.py) in `dev` branch!
@@ -124,17 +136,17 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
- loading
```yaml
- load_4bit: true
+ load_in_4bit: true
load_in_8bit: true
- bf16: true
+ bf16: true # requires >=ampere
fp16: true
- tf32: true
+ tf32: true # requires >=ampere
```
Note: Repo does not do 4-bit quantization.
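+
+For orientation: the new `load_in_4bit` flag refers to bitsandbytes on-the-fly quantization (distinct from GPTQ). A minimal sketch of the equivalent `transformers` call, assuming a recent version with `BitsAndBytesConfig` (the model id is illustrative):
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+# quantize weights to 4-bit at load time; compute in bf16 (requires >=ampere)
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+model = AutoModelForCausalLM.from_pretrained(
+    "openlm-research/open_llama_3b_600bt_preview",
+    quantization_config=bnb_config,
+)
+```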
- lora
```yaml
- adapter: lora # blank for full finetune
+ adapter: lora # qlora or leave blank for full finetune
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
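+ # target modules vary by model architecture; for llama-family models, e.g. (see the openllama example config):
+ # lora_target_modules:
+ #   - q_proj
+ #   - v_proj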
@@ -163,28 +175,32 @@ tokenizer_type: AutoTokenizer
# Trust remote code for untrusted source
trust_remote_code:
-# whether you are training a 4-bit quantized model
+# whether you are training a 4-bit GPTQ quantized model
load_4bit: true
gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
+# use bitsandbytes 4-bit
+load_in_4bit:
# Use CUDA bf16
-bf16: true
+bf16: true # bool or 'full' for `bf16_full_eval`. requires >=ampere
# Use CUDA fp16
fp16: true
# Use CUDA tf32
-tf32: true
+tf32: true # requires >=ampere
# a list of one or more datasets to finetune the model with
datasets:
# this can be either a hf dataset, or relative path
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
- type: alpaca
+ type: alpaca # format OR format:prompt_style (chat/instruct)
data_files: # path to source data files
+  shards: # set to true to use only a subset of the data; be sure to also set the top-level `shards` param
+shards: # number of shards to split the dataset into
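+# e.g. (illustrative) train on a tenth of the data:
+#   shards: 10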
# axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
@@ -201,7 +217,7 @@ sequence_len: 2048
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
max_packed_sequence_len: 1024
-# if you want to use lora, leave blank to train all parameters in original model
+# set to 'lora' or 'qlora', or leave blank to train all parameters of the original model
adapter: lora
# if you already have a lora model trained that you want to load, put that here
# lora hyperparameters
@@ -224,6 +240,7 @@ lora_out_dir:
lora_fan_in_fan_out: false
# wandb configuration if you're using it
+wandb_mode:
wandb_project:
wandb_watch:
wandb_run_id:
@@ -252,8 +269,18 @@ gradient_checkpointing: false
# stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
-# specify a scheduler to use with the optimizer. only one_cycle is supported currently
-lr_scheduler:
+
+# specify a scheduler and kwargs to use with the optimizer
+lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
+lr_scheduler_kwargs:
+
+# for one_cycle optim
+lr_div_factor: # div factor for the one_cycle scheduler (initial lr = max lr / div factor)
+
+# for log_sweep optim
+log_sweep_min_lr:
+log_sweep_max_lr:
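+# e.g. (illustrative) sweep the learning rate logarithmically between 1e-5 and 1e-1:
+#   lr_scheduler: log_sweep
+#   log_sweep_min_lr: 0.00001
+#   log_sweep_max_lr: 0.1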
+
# specify optimizer
optimizer:
# specify weight decay
@@ -262,7 +289,7 @@ weight_decay:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
-flash_attention:
+flash_attention:  # requires an A100 for llama
# resume from a specific checkpoint dir
resume_from_checkpoint:
@@ -288,11 +315,17 @@ fsdp_config:
# Deepspeed
deepspeed:
-# TODO
+# Path to torchdistx for optim 'adamw_anyprecision'
torchdistx_path:
+# Set padding for data collator to 'longest'
+collator_pad_to_longest:
+
# Debug mode
debug:
+
+# Seed
+seed:
```
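+
+For reference, the `lora_*` options above typically map onto a `peft.LoraConfig`; a minimal sketch assuming the `peft` library (illustrative values, not the repo's exact wiring):
+
+```python
+from transformers import AutoModelForCausalLM
+from peft import LoraConfig, get_peft_model
+
+model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_3b_600bt_preview")
+lora_config = LoraConfig(
+    r=8,                                  # lora_r
+    lora_alpha=16,                        # lora_alpha
+    lora_dropout=0.05,                    # lora_dropout
+    target_modules=["q_proj", "v_proj"],  # lora_target_modules
+    task_type="CAUSAL_LM",
+)
+model = get_peft_model(model, lora_config)  # wraps the base model with LoRA adapters
+```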
@@ -317,12 +350,16 @@ accelerate launch scripts/finetune.py configs/your_config.yml
### Inference
-Add `--inference` flag to train command above
+Pass the appropriate flags to the train command:
-If you are inferencing a pretrained LORA, pass
-```bash
---lora_model_dir ./completed-model
-```
+- Pretrained LORA:
+ ```bash
+ --inference --lora_model_dir ./completed-model
+ ```
+- Full weights finetune:
+ ```bash
+ --inference --base_model ./completed-model
+ ```
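+
+Put together with the training entrypoint above, a full invocation looks like (paths are illustrative):
+
+```bash
+accelerate launch scripts/finetune.py configs/your_config.yml \
+    --inference --lora_model_dir ./completed-model
+```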
### Merge LORA to base
@@ -341,8 +378,11 @@ Please reduce any below
- `eval_batch_size`
- `sequence_len`
-
-## Need help
+> RuntimeError: expected scalar type Float but found Half
+
+Try setting `fp16: true`
+
+## Need help? 🙋‍♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index 943bae3b0..c63f5a496 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -43,11 +43,11 @@ RUN git clone https://github.com/HazyResearch/flash-attention.git && \
python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \
- cd csrc/xentropy && \
+ cd ../xentropy && \
python3 setup.py bdist_wheel && \
- cd csrc/rotary && \
+ cd ../rotary && \
python3 setup.py bdist_wheel && \
- cd csrc/layer_norm && \
+ cd ../layer_norm && \
python3 setup.py bdist_wheel
FROM base-builder AS deepspeed-builder
diff --git a/examples/lora-openllama-3b/config.yml b/examples/lora-openllama-3b/config.yml
new file mode 100644
index 000000000..6665044e0
--- /dev/null
+++ b/examples/lora-openllama-3b/config.yml
@@ -0,0 +1,67 @@
+base_model: openlm-research/open_llama_3b_600bt_preview
+base_model_config: openlm-research/open_llama_3b_600bt_preview
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+push_dataset_to_hub:
+datasets:
+ - path: teknium/GPT4-LLM-Cleaned
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.02
+adapter: lora
+lora_model_dir:
+sequence_len: 256
+max_packed_sequence_len:
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.0
+lora_target_modules:
+ - gate_proj
+ - down_proj
+ - up_proj
+ - q_proj
+ - v_proj
+ - k_proj
+ - o_proj
+lora_fan_in_fan_out:
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./lora-out
+batch_size: 16
+micro_batch_size: 4
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+torchdistx_path:
+lr_scheduler: cosine
+learning_rate: 0.0002
+train_on_inputs: false
+group_by_length: false
+bf16: false
+fp16: true
+tf32: false
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention: true
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 10
+eval_steps: 50
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
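+
+# to train with this config (from the repo root):
+#   accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml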
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index a6d237a11..fd9dfc8d4 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -17,8 +17,8 @@ class AlpacaPrompter:
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
prompt_style = None
- def __init__(self, prompt_style="instruct"):
- self.prompt_style = prompt_style
+ def __init__(self, prompt_style=PromptStyle.instruct.value):
+ self.prompt_style = prompt_style if prompt_style else PromptStyle.instruct.value
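+        # e.g. AlpacaPrompter() -> instruct style; AlpacaPrompter(PromptStyle.chat.value) -> chat style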
self.match_prompt_style()
def match_prompt_style(self):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 5b243bec4..de04e9333 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -211,12 +211,12 @@ def load_model(
try:
if is_llama_derived_model and "LlamaTokenizer" in globals():
tokenizer = LlamaTokenizer.from_pretrained(
- model,
+ base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
else:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
- model,
+ base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
except:
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
new file mode 100644
index 000000000..9bef37406
--- /dev/null
+++ b/src/axolotl/utils/validation.py
@@ -0,0 +1,10 @@
+def validate_config(cfg):
+ if cfg.adapter == "qlora":
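+        # qlora implies bitsandbytes 4-bit loading; 8-bit and GPTQ 4-bit must be off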
+ assert cfg.load_in_8bit is False
+ assert cfg.load_4bit is False
+ assert cfg.load_in_4bit is True
+    # TODO:
+    # - MPT 7b
+    # - https://github.com/facebookresearch/bitsandbytes/issues/25
+    #   (no 8-bit AdamW with bf16)
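+
+# usage sketch (hypothetical cfg object with attribute access):
+#   from types import SimpleNamespace
+#   cfg = SimpleNamespace(adapter="qlora", load_in_8bit=False, load_4bit=False, load_in_4bit=True)
+#   validate_config(cfg)  # passes; load_in_8bit=True would raise an AssertionError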