From dd0065773a3365891e142cd9fb3c1fcf3af25a2d Mon Sep 17 00:00:00 2001
From: Thytu
Date: Sat, 27 May 2023 12:36:03 +0000
Subject: [PATCH] refactor(param): rename load_4bit config param to gptq

Signed-off-by: Thytu
---
 README.md                        |  2 +-
 configs/quickstart.yml           |  2 +-
 examples/4bit-lora-7b/config.yml |  2 +-
 src/axolotl/utils/models.py      | 10 +++++-----
 src/axolotl/utils/trainer.py     |  4 ++--
 src/axolotl/utils/validation.py  |  8 ++++++--
 6 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index ae8e3e2c0..32817b709 100644
--- a/README.md
+++ b/README.md
@@ -176,7 +176,7 @@ tokenizer_type: AutoTokenizer
 trust_remote_code:
 
 # whether you are training a 4-bit GPTQ quantized model
-load_4bit: true
+gptq: true
 gptq_groupsize: 128 # group size
 gptq_model_v1: false # v1 or v2
 
diff --git a/configs/quickstart.yml b/configs/quickstart.yml
index a2cbdff4d..f29df1478 100644
--- a/configs/quickstart.yml
+++ b/configs/quickstart.yml
@@ -40,6 +40,6 @@ early_stopping_patience: 3
 resume_from_checkpoint:
 auto_resume_from_checkpoints: true
 local_rank:
-load_4bit: true
+gptq: true
 xformers_attention: true
 flash_attention:
diff --git a/examples/4bit-lora-7b/config.yml b/examples/4bit-lora-7b/config.yml
index 345e0812e..0d57a6d8d 100644
--- a/examples/4bit-lora-7b/config.yml
+++ b/examples/4bit-lora-7b/config.yml
@@ -4,7 +4,7 @@ model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 trust_remote_code:
 load_in_8bit: true
-load_4bit: true
+gptq: true
 datasets:
   - path: vicgalle/alpaca-gpt4
     type: alpaca
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 405c9e4b2..721584888 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -73,7 +73,7 @@ def load_model(
     else:
         torch_dtype = torch.float32
     try:
-        if cfg.load_4bit:
+        if cfg.gptq:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
                 replace_peft_model_with_int4_lora_model,
             )
@@ -95,7 +95,7 @@ def load_model(
             bnb_4bit_quant_type="nf4",
         )
     try:
-        if cfg.load_4bit and is_llama_derived_model:
+        if cfg.gptq and is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
             from huggingface_hub import snapshot_download
 
@@ -248,7 +248,7 @@ def load_model(
 
     if (
         ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
-        and not cfg.load_4bit
+        and not cfg.gptq
         and (load_in_8bit or cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
@@ -259,7 +259,7 @@ def load_model(
     if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
 
-    if cfg.load_4bit:
+    if cfg.gptq:
         # Scales to half
         logging.info("Fitting 4bit scales and zeros to half")
         for n, m in model.named_modules():
@@ -274,7 +274,7 @@ def load_model(
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
-        and cfg.load_4bit
+        and cfg.gptq
     ):
         # llama is PROBABLY model parallelizable, but the default isn't that it is
         # so let's only set it for the 4bit, see
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 285075109..cb67eac7d 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -63,7 +63,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
     if cfg.gradient_checkpointing is not None:
-        if cfg.load_4bit:
+        if cfg.gptq:
             from alpaca_lora_4bit.gradient_checkpointing import (
                 apply_gradient_checkpointing,
             )
@@ -138,7 +138,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         importlib.import_module("torchdistx")
     if (
         cfg.optimizer == "adamw_bnb_8bit"
-        and not cfg.load_4bit
+        and not cfg.gptq
         and not "deepspeed" in training_arguments_kwargs
         and not cfg.fsdp
     ):
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 2fe7f99db..d56f28f6d 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -2,16 +2,20 @@ import logging
 
 
 def validate_config(cfg):
+    if cfg.load_4bit:
+        raise ValueError("cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq")
+
     if cfg.adapter == "qlora":
         if cfg.merge_lora:
             # can't merge qlora if loaded in 8bit or 4bit
             assert cfg.load_in_8bit is False
-            assert cfg.load_4bit is False
+            assert cfg.gptq is False
             assert cfg.load_in_4bit is False
         else:
             assert cfg.load_in_8bit is False
-            assert cfg.load_4bit is False
+            assert cfg.gptq is False
             assert cfg.load_in_4bit is True
+
     if not cfg.load_in_8bit and cfg.adapter == "lora":
         logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
 
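
The behavioural intent of the validation change can be sketched in isolation as follows. This is an illustration only, not part of the patch: SimpleNamespace stands in for axolotl's config object, and getattr with a None default mimics how missing config keys resolve, so only the deprecation check added above is exercised.

# Minimal sketch of the new deprecation guard (illustration, not patch content).
from types import SimpleNamespace


def validate_config(cfg):
    # Mirror the check added to src/axolotl/utils/validation.py: reject configs
    # that still use the old key and point the user at the new one.
    if getattr(cfg, "load_4bit", None):
        raise ValueError(
            "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
        )


old_cfg = SimpleNamespace(load_4bit=True)  # pre-rename config: load_4bit: true
new_cfg = SimpleNamespace(gptq=True)       # post-rename config: gptq: true

try:
    validate_config(old_cfg)
except ValueError as err:
    print(err)  # prompts the user to switch to `gptq: true`

validate_config(new_cfg)  # passes: only `gptq` is recognized

Existing configs need only the one-line rename shown in configs/quickstart.yml and examples/4bit-lora-7b/config.yml: replace `load_4bit: true` with `gptq: true`.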