integrate qlora? maybe?
This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
peft @ git+https://github.com/huggingface/peft.git
|
peft @ git+https://github.com/huggingface/peft.git
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git
|
transformers @ git+https://github.com/huggingface/transformers.git
|
||||||
|
bitsandbytes @ git+https://github.com/TimDettmers/bitsandbytes.git
|
||||||
attrdict
|
attrdict
|
||||||
fire
|
fire
|
||||||
PyYAML==6.0
|
PyYAML==6.0
|
||||||
black
|
black
|
||||||
bitsandbytes==0.37.2
|
|
||||||
datasets
|
datasets
|
||||||
accelerate>=0.19.0
|
accelerate>=0.19.0
|
||||||
sentencepiece
|
sentencepiece
|
||||||
|
|||||||
@@ -6,11 +6,12 @@ from typing import Optional, Tuple, TYPE_CHECKING
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
from torch import nn
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
AutoConfig,
|
AutoConfig, BitsAndBytesConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -81,6 +82,16 @@ def load_model(
|
|||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
model_kwargs = {}
|
||||||
|
if cfg.adapter == "qlora":
|
||||||
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
llm_int8_threshold=6.0,
|
||||||
|
llm_int8_has_fp16_weight=False,
|
||||||
|
bnb_4bit_compute_dtype=torch.float16,
|
||||||
|
bnb_4bit_use_double_quant=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if cfg.load_4bit and is_llama_derived_model:
|
if cfg.load_4bit and is_llama_derived_model:
|
||||||
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
||||||
@@ -125,6 +136,7 @@ def load_model(
|
|||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||||
# This is a WIP, still an issue with the backward pass
|
# This is a WIP, still an issue with the backward pass
|
||||||
@@ -159,6 +171,7 @@ def load_model(
|
|||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||||
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
@@ -172,6 +185,7 @@ def load_model(
|
|||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||||
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(
|
logging.error(
|
||||||
@@ -184,8 +198,24 @@ def load_model(
|
|||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||||
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""### Post-processing on the model
|
||||||
|
Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
|
||||||
|
"""
|
||||||
|
if cfg.adapter == "qlora":
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False # freeze the model - train adapters later
|
||||||
|
if param.ndim == 1:
|
||||||
|
# cast the small parameters (e.g. layernorm) to fp32 for stability
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
class CastOutputToFloat(nn.Sequential):
|
||||||
|
def forward(self, x):
|
||||||
|
return super().forward(x).to(torch.float32)
|
||||||
|
|
||||||
|
model.lm_head = CastOutputToFloat(model.lm_head)
|
||||||
|
|
||||||
if not tokenizer:
|
if not tokenizer:
|
||||||
try:
|
try:
|
||||||
if is_llama_derived_model and "LlamaTokenizer" in globals():
|
if is_llama_derived_model and "LlamaTokenizer" in globals():
|
||||||
@@ -270,7 +300,7 @@ def load_adapter(model, cfg, adapter):
|
|||||||
|
|
||||||
if adapter is None:
|
if adapter is None:
|
||||||
return model, None
|
return model, None
|
||||||
if adapter == "lora":
|
if adapter == "lora" or adapter == "qlora":
|
||||||
return load_lora(model, cfg)
|
return load_lora(model, cfg)
|
||||||
if adapter == "llama-adapter":
|
if adapter == "llama-adapter":
|
||||||
return load_llama_adapter(model, cfg)
|
return load_llama_adapter(model, cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user