diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d20db7065..30d4774db 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,11 +23,6 @@ jobs: python_version: "3.10" pytorch: 2.0.1 axolotl_extras: - - cuda: 118 - cuda_version: 11.8.0 - python_version: "3.9" - pytorch: 2.0.1 - axolotl_extras: gptq runs-on: self-hosted steps: - name: Checkout @@ -73,11 +68,6 @@ jobs: pytorch: 2.0.1 axolotl_extras: is_latest: true - - cuda: 118 - cuda_version: 11.8.0 - python_version: "3.9" - pytorch: 2.0.1 - axolotl_extras: gptq runs-on: self-hosted steps: - name: Checkout diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 01703cd51..d5184def6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,7 +24,7 @@ jobs: - name: Install dependencies run: | - pip install -e .[peft] + pip install -e . pip install -r requirements-tests.txt - name: Run tests diff --git a/docker/Dockerfile b/docker/Dockerfile index b429d50f2..683ca75ff 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,14 +11,13 @@ RUN apt-get update && \ WORKDIR /workspace -RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git # If AXOLOTL_EXTRAS is set, append it in brackets RUN cd axolotl && \ if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \ + pip install -e .[flash-attn,gptq,$AXOLOTL_EXTRAS]; \ else \ - pip install -e .[flash-attn]; \ + pip install -e .[flash-attn,gptq]; \ fi # fix so that git fetch/pull from remote works diff --git a/examples/gptq-lora-7b/README.md b/examples/gptq-lora-7b/README.md deleted file mode 100644 index 0bde51b06..000000000 --- a/examples/gptq-lora-7b/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# LLaMa 7B using LoRA - -This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed. - -```shell -accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml - -``` diff --git a/examples/gptq-lora-7b/config.yml b/examples/gptq-lora-7b/config.yml deleted file mode 100644 index d909f7d07..000000000 --- a/examples/gptq-lora-7b/config.yml +++ /dev/null @@ -1,63 +0,0 @@ -base_model: Neko-Institute-of-Science/LLaMA-7B-4bit-128g -base_model_config: Neko-Institute-of-Science/LLaMA-7B-4bit-128g -model_type: LlamaForCausalLM -tokenizer_type: LlamaTokenizer -trust_remote_code: -load_in_8bit: true -gptq: true -datasets: - - path: vicgalle/alpaca-gpt4 - type: alpaca -dataset_prepared_path: last_run_prepared -val_set_size: 0.02 -adapter: -lora_model_dir: -sequence_len: 2048 -max_packed_sequence_len: -lora_r: 8 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: - - q_proj - - v_proj -lora_fan_in_fan_out: false -wandb_project: llama-7b-lora-int4 -wandb_entity: -wandb_watch: -wandb_run_id: -wandb_log_model: -output_dir: ./llama-7b-lora-int4 -gradient_accumulation_steps: 1 -micro_batch_size: 1 -num_epochs: 3 -optimizer: adamw_bnb_8bit -torchdistx_path: -lr_scheduler: cosine -learning_rate: 0.0000002 -train_on_inputs: false -group_by_length: false -fp16: true -bf16: false -tf32: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 5 -xformers_attention: -flash_attention: -gradient_checkpointing: true -gptq_groupsize: 128 -gptq_model_v1: false -warmup_steps: 20 -eval_steps: 110 -save_steps: 660 -debug: -deepspeed: -weight_decay: 0.0001 -fsdp: -fsdp_config: -tokens: - pad_token: "" - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml new file mode 100644 index 000000000..dbce2a6b3 --- /dev/null +++ b/examples/llama-2/gptq-lora.yml @@ -0,0 +1,76 @@ +base_model: TheBloke/Llama-2-7B-GPTQ +base_model_config: TheBloke/Llama-2-7B-GPTQ +is_llama_derived_model: false +gptq: true +gptq_bits: 4 +model_type: AutoModelForCausalLM +tokenizer_type: LlamaTokenizer +tokenizer_use_fast: true +tokenizer_legacy: true +load_in_8bit: false +load_in_4bit: false +strict: false +push_dataset_to_hub: +hf_use_auth_token: true +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +adapter: lora +lora_model_dir: +sequence_len: 4096 +sample_packing: +lora_r: 8 +lora_alpha: 32 +lora_dropout: 0.05 +lora_target_modules: + - k_proj + - o_proj + - q_proj + - v_proj +lora_target_linear: +lora_fan_in_fan_out: +wandb_project: +wandb_watch: +wandb_run_id: +wandb_log_model: +output_dir: ./model-out +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 3 +optimizer: adamw_torch +adam_beta2: 0.95 +adam_eps: 0.00001 +max_grad_norm: 1.0 +torchdistx_path: +lr_scheduler: cosine +lr_quadratic_warmup: true +learning_rate: 0.000017 +train_on_inputs: false +group_by_length: false +bf16: false +fp16: false +float16: true +tf32: true +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: +sdp_attention: +flash_optimum: +gptq_groupsize: +gptq_model_v1: +warmup_steps: 100 +eval_steps: +save_steps: +debug: +deepspeed: +weight_decay: 0.1 +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/requirements.txt b/requirements.txt index fcd7f9292..1c8e97dff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,7 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 +--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ +torch==2.0.1 +auto-gptq packaging peft @ git+https://github.com/huggingface/peft.git transformers @ git+https://github.com/huggingface/transformers.git diff --git a/setup.py b/setup.py index 7b99794de..973d656cd 100644 --- a/setup.py +++ b/setup.py @@ -2,15 +2,27 @@ from setuptools import find_packages, setup -install_requires = [] -with open("./requirements.txt", encoding="utf-8") as requirements_file: - # don't include peft yet until we check the int4 - # need to manually install peft for now... - reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r] - reqs = [r for r in reqs if "flash-attn" not in r] - reqs = [r for r in reqs if r and r[0] != "#"] - for r in reqs: - install_requires.append(r) + +def parse_requirements(): + _install_requires = [] + _dependency_links = [] + with open("./requirements.txt", encoding="utf-8") as requirements_file: + lines = [ + r.strip() for r in requirements_file.readlines() if "auto-gptq" not in r + ] + for line in lines: + if line.startswith("--extra-index-url"): + # Handle custom index URLs + _, url = line.split() + _dependency_links.append(url) + elif "flash-attn" not in line and line and line[0] != "#": + # Handle standard packages + _install_requires.append(line) + return _install_requires, _dependency_links + + +install_requires, dependency_links = parse_requirements() + setup( name="axolotl", @@ -19,12 +31,10 @@ setup( package_dir={"": "src"}, packages=find_packages(), install_requires=install_requires, + dependency_links=dependency_links, extras_require={ "gptq": [ - "alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip", - ], - "gptq_triton": [ - "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip", + "auto-gptq", ], "flash-attn": [ "flash-attn==2.0.8", @@ -32,8 +42,5 @@ setup( "extras": [ "deepspeed", ], - "peft": [ - "peft @ git+https://github.com/huggingface/peft.git", - ], }, ) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 93a23f773..0fbccd205 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -108,9 +108,7 @@ def validate_config(cfg): "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.", ) if cfg.load_4bit: - raise ValueError( - "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq" - ) + raise ValueError("cfg.load_4bit parameter has been deprecated") if cfg.adapter == "qlora": if cfg.merge_lora: diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 9f0795af7..9ec51f4f7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -4,19 +4,19 @@ import logging import math import os -from pathlib import Path from typing import Optional, Tuple # noqa: F401 import bitsandbytes as bnb import torch import transformers from optimum.bettertransformer import BetterTransformer -from peft import PeftConfig +from peft import PeftConfig, prepare_model_for_kbit_training from transformers import ( # noqa: F401 AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, + GPTQConfig, LlamaConfig, PreTrainedModel, PreTrainedTokenizerBase, @@ -155,32 +155,17 @@ def load_model( LOG.info("patching _expand_mask") hijack_expand_mask() - try: - if cfg.gptq: - from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import ( - replace_peft_model_with_int4_lora_model, - ) - - replace_peft_model_with_int4_lora_model() - except Exception as err: - LOG.exception(err) - raise err - - if not cfg.gptq and ( - (cfg.adapter == "lora" and load_in_8bit) - or (cfg.adapter == "qlora" and cfg.load_in_4bit) - ): - try: - from peft import prepare_model_for_kbit_training - except ImportError: - # For backward compatibility - from peft import ( - prepare_model_for_int8_training as prepare_model_for_kbit_training, - ) - model_kwargs = {} if cfg.model_revision: model_kwargs["revision"] = cfg.model_revision + if cfg.gptq: + model_config = load_model_config(cfg) + if hasattr(model_config, "quantization_config"): + LOG.warning("model config does not contain quantization_config information") + else: + model_kwargs["quantization_config"] = GPTQConfig( + **model_config.quantization_config + ) if cfg.adapter == "qlora" and cfg.load_in_4bit: model_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, @@ -191,45 +176,7 @@ def load_model( bnb_4bit_quant_type="nf4", ) try: - if cfg.gptq and cfg.is_llama_derived_model: - from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram - from huggingface_hub import snapshot_download - - try: - snapshot_download_kwargs = {} - if cfg.base_model_ignore_patterns: - snapshot_download_kwargs[ - "ignore_patterns" - ] = cfg.base_model_ignore_patterns - cache_model_path = Path( - snapshot_download(base_model, **snapshot_download_kwargs) - ) - files = ( - list(cache_model_path.glob("*.pt")) - + list(cache_model_path.glob("*.safetensors")) - + list(cache_model_path.glob("*.bin")) - ) - if len(files) > 0: - model_path = str(files[0]) - else: - LOG.warning( - "unable to find a cached model file, this will likely fail..." - ) - model_path = str(cache_model_path) - except Exception: # pylint: disable=broad-exception-caught - model_path = cfg.base_model - model, _ = load_llama_model_4bit_low_ram( - base_model_config if base_model_config else base_model, - model_path, - device_map=cfg.device_map, - half=cfg.fp16, - groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1, - is_v1_model=cfg.gptq_model_v1 - if cfg.gptq_model_v1 is not None - else True, - ) - load_in_8bit = False - elif cfg.is_llama_derived_model and not cfg.trust_remote_code: + if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: from transformers import LlamaForCausalLM config_kwargs = {} @@ -275,15 +222,24 @@ def load_model( # ) # model.train() # sets to train instead of eval mode elif model_type and not cfg.trust_remote_code: - model = getattr(transformers, model_type).from_pretrained( - base_model, - device_map=cfg.device_map, - load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, - load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - torch_dtype=cfg.torch_dtype, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, - ) + if cfg.gptq: + model = AutoModelForCausalLM.from_pretrained( + base_model, + device_map=cfg.device_map, + torch_dtype=cfg.torch_dtype, + trust_remote_code=cfg.trust_remote_code or False, + **model_kwargs, + ) + else: + model = getattr(transformers, model_type).from_pretrained( + base_model, + device_map=cfg.device_map, + load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, + load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, + torch_dtype=cfg.torch_dtype, + trust_remote_code=cfg.trust_remote_code or False, + **model_kwargs, + ) else: config = AutoConfig.from_pretrained( base_model, @@ -359,11 +315,12 @@ def load_model( module.to(torch.float32) needs_fa2_dtype = cfg.adapter or cfg.fsdp - if not cfg.gptq and ( - (cfg.adapter == "lora" and load_in_8bit) - or (cfg.adapter == "qlora" and cfg.load_in_4bit) + if (cfg.adapter == "lora" and load_in_8bit) or ( + cfg.adapter == "qlora" and cfg.load_in_4bit ): LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") + if cfg.gradient_checkpointing: + model.gradient_checkpointing_enable() model = prepare_model_for_kbit_training( model, use_gradient_checkpointing=cfg.gradient_checkpointing ) @@ -385,22 +342,10 @@ def load_model( if cfg.ddp and not load_in_8bit: model.to(f"cuda:{cfg.local_rank}") - if cfg.gptq: - # Scales to half - LOG.info("Fitting 4bit scales and zeros to half") - for _, module in model.named_modules(): - if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str( - type(module) - ): - if hasattr(module, "is_v1_model") and module.is_v1_model: - module.zeros = module.zeros.half() - module.scales = module.scales.half() - module.bias = module.bias.half() - if ( torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1 - and (cfg.gptq or cfg.load_in_4bit) + and (cfg.load_in_4bit) ): # llama is PROBABLY model parallelizable, but the default isn't that it is # so let's only set it for the 4bit, see diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f91f4e318..c3d6b85cb 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -514,23 +514,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ training_arguments_kwargs["seed"] = cfg.seed if cfg.gradient_checkpointing: - if cfg.gptq: - from alpaca_lora_4bit.gradient_checkpointing import ( - apply_gradient_checkpointing, - ) - - gradient_checkpointing_ratio = ( - cfg.gradient_checkpointing_ratio - if cfg.gradient_checkpointing_ratio - else 1.0 - ) - apply_gradient_checkpointing( - model, checkpoint_ratio=gradient_checkpointing_ratio - ) - else: - training_arguments_kwargs[ - "gradient_checkpointing" - ] = cfg.gradient_checkpointing + training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing if cfg.fsdp: training_arguments_kwargs["fsdp"] = cfg.fsdp if cfg.fsdp_config: