diff --git a/.gitignore b/.gitignore
index b7a09516c..93a4f81b5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,163 @@
 **/axolotl.egg-info
-**/__pycache__
-.idea
 configs
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock

+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
\ No newline at end of file
diff --git a/README.md b/README.md
index f3985fea8..f79a49a1f 100644
--- a/README.md
+++ b/README.md
@@ -97,6 +97,18 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"instruction": "...", "input": "...", "output": "...", "reflection": "...", "corrected": "..."}
   ```
+- `explainchoice`: question, choices, (solution OR explanation)
+  ```json
+  {"question": "...", "choices": ["..."], "solution": "...", "explanation": "..."}
+  ```
+- `concisechoice`: question, choices, (solution OR explanation)
+  ```json
+  {"question": "...", "choices": ["..."], "solution": "...", "explanation": "..."}
+  ```
+- `summarizetldr`: article and summary
+  ```json
+  {"article": "...", "summary": "..."}
+  ```
 
 > Have some new format to propose? Check if it's already defined in [data.py](src/axolotl/utils/data.py) in `dev` branch!
 
@@ -124,17 +136,17 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
 - loading
   ```yaml
-  load_4bit: true
+  load_in_4bit: true
   load_in_8bit: true
-  bf16: true
+  bf16: true # requires >=ampere
   fp16: true
-  tf32: true
+  tf32: true # requires >=ampere
   ```
   Note: Repo does not do 4-bit quantization.
 
 - lora
   ```yaml
-  adapter: lora # blank for full finetune
+  adapter: lora # qlora or leave blank for full finetune
   lora_r: 8
   lora_alpha: 16
   lora_dropout: 0.05
@@ -163,28 +175,32 @@ tokenizer_type: AutoTokenizer
 # Trust remote code for untrusted source
 trust_remote_code:
 
-# whether you are training a 4-bit quantized model
+# whether you are training a 4-bit GPTQ quantized model
 load_4bit: true
 gptq_groupsize: 128 # group size
 gptq_model_v1: false # v1 or v2
 
 # this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
 load_in_8bit: true
+# use bitsandbytes 4 bit
+load_in_4bit:
+
 # Use CUDA bf16
-bf16: true
+bf16: true # bool or 'full' for `bf16_full_eval`. requires >=ampere
 # Use CUDA fp16
 fp16: true
 # Use CUDA tf32
-tf32: true
+tf32: true # requires >=ampere
 
 # a list of one or more datasets to finetune the model with
 datasets:
   # this can be either a hf dataset, or relative path
   - path: vicgalle/alpaca-gpt4
     # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
-    type: alpaca
+    type: alpaca # format OR format:prompt_style (chat/instruct)
     data_files: # path to source data files
+    shards: # true if using a subset of the data; make sure to also set the `shards` param
+shards: # number of shards to split dataset into
 
 # axolotl attempts to save the dataset as an arrow after packing the data together so
 # subsequent training attempts load faster, relative path
@@ -201,7 +217,7 @@ sequence_len: 2048
 # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
 max_packed_sequence_len: 1024
 
-# if you want to use lora, leave blank to train all parameters in original model
+# set to 'lora' or 'qlora', or leave blank to train all parameters in the original model
 adapter: lora
 # if you already have a lora model trained that you want to load, put that here
 # lora hyperparameters
@@ -224,6 +240,7 @@ lora_out_dir:
 lora_fan_in_fan_out: false
 
 # wandb configuration if you're using it
+wandb_mode:
 wandb_project:
 wandb_watch:
 wandb_run_id:
@@ -252,8 +269,18 @@ gradient_checkpointing: false
 # stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
 early_stopping_patience: 3
-# specify a scheduler to use with the optimizer. only one_cycle is supported currently
-lr_scheduler:
+
+# specify a scheduler and kwargs to use with the optimizer
+lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
+lr_scheduler_kwargs:
+
+# for the one_cycle scheduler
+lr_div_factor: # learning rate div factor
+
+# for the log_sweep scheduler
+log_sweep_min_lr:
+log_sweep_max_lr:
+
 # specify optimizer
 optimizer:
 # specify weight decay
@@ -262,7 +289,7 @@ weight_decay:
 
 # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
 xformers_attention:
 # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
-flash_attention:
+flash_attention: # requires an a100 for llama
 # resume from a specific checkpoint dir
 resume_from_checkpoint:
@@ -288,11 +315,17 @@ fsdp_config:
 
 # Deepspeed
 deepspeed:
 
-# TODO
+# Path to torch distx for optim 'adamw_anyprecision'
 torchdistx_path:
 
+# Set padding for data collator to 'longest'
+collator_pad_to_longest:
+
 # Debug mode
 debug:
+
+# Seed
+seed:
 ```
@@ -317,12 +350,16 @@ accelerate launch scripts/finetune.py configs/your_config.yml
 
 ### Inference
 
-Add `--inference` flag to train command above
+Pass the appropriate flag to the train command:
 
-If you are inferencing a pretrained LORA, pass
-```bash
---lora_model_dir ./completed-model
-```
+- Pretrained LORA:
+  ```bash
+  --inference --lora_model_dir ./completed-model
+  ```
+- Full weights finetune:
+  ```bash
+  --inference --base_model ./completed-model
+  ```
 
 ### Merge LORA to base
@@ -341,8 +378,11 @@ Please reduce any below
 - `eval_batch_size`
 - `sequence_len`
-
-## Need help
+> RuntimeError: expected scalar type Float but found Half
+
+Try setting `fp16: true`
+
+## Need help? 🙋‍♂️
 
 Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index 943bae3b0..c63f5a496 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -43,11 +43,11 @@ RUN git clone https://github.com/HazyResearch/flash-attention.git && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
-    cd csrc/xentropy && \
+    cd ../xentropy && \
     python3 setup.py bdist_wheel && \
-    cd csrc/rotary && \
+    cd ../rotary && \
     python3 setup.py bdist_wheel && \
-    cd csrc/layer_norm && \
+    cd ../layer_norm && \
     python3 setup.py bdist_wheel
 
 FROM base-builder AS deepspeed-builder
diff --git a/examples/lora-openllama-3b/config.yml b/examples/lora-openllama-3b/config.yml
new file mode 100644
index 000000000..6665044e0
--- /dev/null
+++ b/examples/lora-openllama-3b/config.yml
@@ -0,0 +1,67 @@
+base_model: openlm-research/open_llama_3b_600bt_preview
+base_model_config: openlm-research/open_llama_3b_600bt_preview
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+push_dataset_to_hub:
+datasets:
+  - path: teknium/GPT4-LLM-Cleaned
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.02
+adapter: lora
+lora_model_dir:
+sequence_len: 256
+max_packed_sequence_len:
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.0
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+lora_fan_in_fan_out:
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./lora-out
+batch_size: 16
+micro_batch_size: 4
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+torchdistx_path:
+lr_scheduler: cosine
+learning_rate: 0.0002
+train_on_inputs: false
+group_by_length: false
+bf16: false
+fp16: true
+tf32: false
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention: true
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 10
+eval_steps: 50
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index a6d237a11..fd9dfc8d4 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -17,8 +17,8 @@ class AlpacaPrompter:
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
     prompt_style = None
 
-    def __init__(self, prompt_style="instruct"):
-        self.prompt_style = prompt_style
+    def __init__(self, prompt_style=PromptStyle.instruct.value):
+        self.prompt_style = prompt_style if prompt_style else PromptStyle.instruct.value
         self.match_prompt_style()
 
     def match_prompt_style(self):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 5b243bec4..de04e9333 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -211,12 +211,12 @@ def load_model(
     try:
         if is_llama_derived_model and "LlamaTokenizer" in globals():
             tokenizer = LlamaTokenizer.from_pretrained(
-                model,
+                base_model_config,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
             )
         else:
             tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-                model,
+                base_model_config,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
             )
     except:
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
new file mode 100644
index 000000000..9bef37406
--- /dev/null
+++ b/src/axolotl/utils/validation.py
@@ -0,0 +1,10 @@
+def validate_config(cfg):
+    if cfg.adapter == "qlora":
+        # qlora loads via bitsandbytes 4-bit; 8-bit and GPTQ 4-bit loading conflict with it
+        assert cfg.load_in_8bit is False
+        assert cfg.load_4bit is False
+        assert cfg.load_in_4bit is True
+    # TODO
+    # MPT 7b
+    # https://github.com/facebookresearch/bitsandbytes/issues/25
+    # no 8bit adamw w bf16
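
A note on how the new `src/axolotl/utils/validation.py` is meant to be used: `validate_config` expects the parsed config before any model loading, so contradictory quantization flags fail fast. Below is a minimal, self-contained sketch of that wiring. It is illustrative only; `AttrDict` and `load_and_validate` are hypothetical stand-ins for axolotl's own attribute-access config handling, not code from this PR.

```python
# Sketch only: calling validate_config after parsing a YAML config and
# before model loading. AttrDict and load_and_validate are hypothetical.
import yaml

from axolotl.utils.validation import validate_config


class AttrDict(dict):
    """Minimal attribute-access dict so cfg.adapter-style lookups work."""

    def __getattr__(self, key):
        # missing keys read as None, mimicking an unset config option
        return self.get(key)


def load_and_validate(config_path: str) -> AttrDict:
    with open(config_path, encoding="utf-8") as file:
        cfg = AttrDict(yaml.safe_load(file))
    # raises AssertionError on contradictory flags, e.g. adapter: qlora
    # combined with load_in_8bit: true or without load_in_4bit: true
    validate_config(cfg)
    return cfg


cfg = load_and_validate("examples/lora-openllama-3b/config.yml")
```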