diff --git a/README.md b/README.md index e267a9d6d..225ef0dd7 100644 --- a/README.md +++ b/README.md @@ -264,6 +264,8 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic bf16: true # require >=ampere fp16: true tf32: true # require >=ampere + bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP + float16: true # use instead of fp16 when you don't want AMP ``` Note: Repo does not do 4-bit quantization. @@ -522,6 +524,12 @@ Add below flag to train command above --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False ``` +If you run out of CUDA memory, you can try to merge in system RAM with + +```bash +CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ... +``` + ## Common Errors 🧰 > Cuda out of memory diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c6d380267..2ae9a26aa 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -11,13 +11,14 @@ import bitsandbytes as bnb import torch import transformers from optimum.bettertransformer import BetterTransformer -from transformers import PreTrainedModel # noqa: F401 -from transformers import ( +from transformers import ( # noqa: F401 AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, LlamaConfig, + PreTrainedModel, + PreTrainedTokenizerBase, ) from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN @@ -71,7 +72,7 @@ def load_tokenizer( def load_model( base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora" ): - # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]] + # type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]] """ Load a model from a base model and a model type. """ @@ -284,6 +285,7 @@ def load_model( model = AutoModelForCausalLM.from_pretrained( base_model, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, + load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, trust_remote_code=cfg.trust_remote_code or False,