FSDP + QLoRA (#1378)
* wip qlora + fsdp fixes
* more fixes
* make sure to load the lora 🤦
* only set up quantized meta on non-zero ranks
* only run setup_quantized_peft_meta_for_training for qlora+fsdp
* more fixes for qlora+fsdp
* chore: lint
* add example yml
* support mistral too
* fix for model_type and add mixtral support too
* set cpu_offload: false to reduce VRAM, constrain new accelerator logic to qlora + fsdp
* refactor for duplicate code
src/axolotl/core/policies/__init__.py (new file, 0 lines)
src/axolotl/core/policies/auto_wrap.py (new file, 55 lines)
@@ -0,0 +1,55 @@
"""module for building the auto wrap policy for FSDP"""
import functools

from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
    "llama",
    "mistral",
    "mixtral",
]


def get_wrapping_policy_factory(model_type):
    if model_type == "llama":
        layer_to_wrap = LlamaDecoderLayer
    elif model_type == "mistral":
        layer_to_wrap = MistralDecoderLayer
    elif model_type == "mixtral":
        layer_to_wrap = MixtralDecoderLayer

    def get_wrapping_policy():
        """This checks for lora layers (has weight and requires_grad)"""

        def lambda_policy_fn(module):
            return (
                len(list(module.named_children())) == 0
                and getattr(module, "weight", None) is not None
                and module.weight.requires_grad
            )

        lambda_policy = functools.partial(
            lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
        )
        transformer_layer_name = layer_to_wrap
        transformer_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=(
                PrefixEncoder,
                PromptEncoder,
                PromptEmbedding,
                transformer_layer_name,
            ),
        )
        policies = [lambda_policy, transformer_wrap_policy]
        return functools.partial(_or_policy, policies=policies)

    return get_wrapping_policy
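Editorial note on the policy above: `lambda_policy_fn` selects leaf modules whose `weight` is trainable, i.e. the LoRA adapter projections, so they get their own FSDP units, while the frozen 4-bit base layers are only wrapped at the decoder-layer level by the transformer policy. A minimal sketch (toy modules, not part of this commit) of what the predicate matches:

import torch.nn as nn

def lambda_policy_fn(module):
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )

base_proj = nn.Linear(16, 16)            # stands in for a frozen (quantized) base projection
base_proj.weight.requires_grad = False
lora_a = nn.Linear(16, 8, bias=False)    # stands in for a trainable LoRA A matrix

assert not lambda_policy_fn(base_proj)   # frozen leaf: falls through to the transformer policy
assert lambda_policy_fn(lora_a)          # trainable leaf: wrapped on its own by the lambda policy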
@@ -8,6 +8,7 @@ import importlib
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
@@ -17,7 +18,10 @@ from typing import List, Optional, Type, Union

import torch
import transformers
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import str_to_bool
from datasets import Dataset
from torch.distributed.fsdp import MixedPrecision
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -30,6 +34,7 @@ from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer

from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -191,6 +196,10 @@ class AxolotlTrainingArguments(TrainingArguments):
        default=1e-6,
        metadata={"help": "loraplus learning rate for lora embedding layers."},
    )
    qlora: bool = field(
        default=False,
        metadata={"help": "whether this is a qlora training"},
    )


class AxolotlTrainer(Trainer):
@@ -468,6 +477,56 @@ class AxolotlTrainer(Trainer):

        return super().push_to_hub(*args, **kwargs)

    @wraps(Trainer.create_accelerator_and_postprocess)
    def create_accelerator_and_postprocess(self):
        rank = int(os.environ.get("LOCAL_RANK", 0))
        res = super().create_accelerator_and_postprocess()

        if self.args.qlora is False:
            return res

        # the rest of this method override is specific to fsdp + qlora (for now)
        sync_module_states = (
            str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
        )

        mp_policy = None
        amp = os.environ["ACCELERATE_MIXED_PRECISION"]
        if amp == "fp16":
            mp_policy = MixedPrecision(
                param_dtype=torch.float32,
                reduce_dtype=torch.float32,
                buffer_dtype=torch.float32,
            )
        elif amp == "bf16":
            mp_policy = MixedPrecision(
                param_dtype=torch.float32,
                reduce_dtype=torch.float32,
                buffer_dtype=torch.float32,
            )

        # If somehow we figure out how we want to parameterize we want to autocast buffers...
        # mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
        # load_param_skip_names = ['inv_freq']

        if self.is_fsdp_enabled:
            wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
            fsdp_plugin = FullyShardedDataParallelPlugin(
                auto_wrap_policy=wrapping_policy(),
                cpu_offload=False,
                use_orig_params=False,
                limit_all_gathers=True,
                param_init_fn=lambda module: module.to_empty(
                    device=torch.device("cuda"), recurse=False
                )
                if (rank != 0 and sync_module_states)
                else None,
                mixed_precision_policy=mp_policy,
            )
            self.accelerator.state.fsdp_plugin = fsdp_plugin

        return res


class AxolotlMambaTrainer(AxolotlTrainer):
    """
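For orientation (editorial, not part of the diff), the override reads its configuration from the distributed launcher's environment; the variable names below are the real accelerate/torchrun ones, the values are only illustrative:

import os

os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"  # chosen by `accelerate launch`; selects the MixedPrecision policy above
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"     # rank 0 broadcasts its module states after wrapping
os.environ["LOCAL_RANK"] = "0"                     # set per process by torchrun/accelerate

rank = int(os.environ.get("LOCAL_RANK", 0))
sync_module_states = os.environ["FSDP_SYNC_MODULE_STATES"].lower() in ("1", "true", "yes")
# With syncing enabled, non-zero ranks only materialize empty CUDA modules via to_empty();
# their real weights arrive when FSDP copies rank 0's state during wrapping.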
@@ -787,6 +846,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        if self.cfg.fsdp_config:
            training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)

        if self.cfg.adapter == "qlora":
            training_arguments_kwargs["qlora"] = True

        # deepspeed
        if self.cfg.deepspeed:
            training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed

@@ -24,9 +24,9 @@ def check_cuda_device(default_value):
            or not torch.cuda.is_available()
            or device == "auto"
            or torch.device(device).type == "cpu"
            or torch.device(device).type == "meta"
        ):
            return default_value

        return func(*args, **kwargs)

    return wrapper

@@ -1,13 +1,20 @@
"""Module for models and model loading"""
# pylint: disable=too-many-lines

import logging
import math
import os
from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
import types
from typing import Any, Dict, List, Optional, Tuple, Type, Union  # noqa: F401

import addict
import bitsandbytes as bnb
import safetensors
import torch
import transformers
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from peft import (
    LoftQConfig,
    PeftConfig,
@@ -16,6 +23,7 @@ from peft import (
    prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from torch import Tensor, nn
from transformers import (  # noqa: F401
    AddedToken,
    AutoConfig,
@@ -27,7 +35,9 @@ from transformers import ( # noqa: F401
    PreTrainedTokenizerBase,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub

from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
    SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -262,6 +272,117 @@ def load_tokenizer(cfg):
    return tokenizer


def replace_linear(
    model: nn.Module,
    linear_replacement: Type[nn.Module],
    quant_config: Union[dict, None] = None,
    skip_modules=None,
    **kwargs,
):
    """
    Replace linear modules with a new Linear module.
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments.
            If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
            List of modules names not to convert. Defaults to `lm_head`.
    """
    if skip_modules is None:
        skip_modules = ["lm_head"]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear(
                module, linear_replacement, quant_config, skip_modules, **kwargs
            )

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            if issubclass(linear_replacement, Linear4bit):
                model._modules[  # pylint: disable=protected-access
                    name
                ] = linear_replacement(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    **kwargs,
                )
            else:
                raise ValueError(
                    f"Unsupported linear replacement: {type(linear_replacement)}"
                )
    return model


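# Editorial sketch, not part of this commit: exercising replace_linear above on a toy
# module. It assumes bitsandbytes >= 0.43 so Linear4bit accepts quant_storage; every
# nn.Linear except the default skip module ("lm_head") is swapped for a 4-bit layer
# that quantizes its weights once they are loaded/moved to a CUDA device.
import torch
import torch.nn as nn
from bitsandbytes.nn import Linear4bit

toy = nn.Sequential()
toy.add_module("proj", nn.Linear(64, 64, bias=False))
toy.add_module("lm_head", nn.Linear(64, 1024, bias=False))

toy = replace_linear(
    toy,
    Linear4bit,
    compute_dtype=torch.bfloat16,  # dtype used for the dequantized matmul
    quant_type="nf4",
    quant_storage=torch.float32,   # storage dtype, so FSDP can shard the packed weights
)
assert isinstance(toy.proj, Linear4bit)         # converted
assert not isinstance(toy.lm_head, Linear4bit)  # skipped via the default skip_modules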
def load_and_quantize(
    module: nn.Module,
    name: str,
    value: Tensor,
    device: torch.device = None,
    dtype: torch.dtype = None,
    skip_names: Optional[List[str]] = None,
    is_meta_rank: bool = False,
    low_memory: bool = True,
    verbose: bool = False,
    quant_method: str = "bnb",
):
    """
    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.

    Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
    """

    if skip_names is None:
        skip_names = []

    def place_on_device(value):
        if is_meta_rank:
            device = "meta"
        elif low_memory:
            device = "cpu"
        else:
            device = "cuda"
        return value.to(device=device, dtype=dtype)

    if any(skip_name in name for skip_name in skip_names):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    module_key, _, value_key = name.rpartition(".")
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as exc:
        print(f"Module {module_key} not found:\n{exc}")
        return

    try:
        if quant_method == "bnb":
            param = submodule.get_parameter(value_key)
            if isinstance(param, Params4bit):
                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
                # shape as the quantized Params4bit with an initialized quant_state. However,
                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
                value = type(param)(
                    value.to(device=device, dtype=dtype).data, **param.__dict__
                ).cuda(device)
                if is_meta_rank:
                    value = type(param)(value.data.to("meta"), **value.__dict__)
                elif low_memory:
                    value = type(param)(value.data.to("cpu"), **value.__dict__)
            else:
                value = type(param)(place_on_device(value).data)

    except AttributeError:
        # it's a buffer
        value = place_on_device(value)

    setattr(submodule, value_key, value)


def load_model(
    cfg: DictDefault,
    tokenizer: PreTrainedTokenizerBase,
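A short editorial sketch of `load_and_quantize` on the plain-parameter path (the target is not a `Params4bit`, so the `else` branch runs and the tensor is simply re-wrapped and parked on CPU); with a real 4-bit checkpoint the `Params4bit` branch instead quantizes on GPU and keeps the data on "cpu" or "meta" depending on rank:

import torch
import torch.nn as nn

toy = nn.Sequential()
toy.add_module("proj", nn.Linear(8, 8, bias=False))
new_weight = torch.randn(8, 8)

# Rank 0 would call this with is_meta_rank=False (as here); other ranks pass is_meta_rank=True
# so the freshly quantized data is immediately replaced by a meta tensor to free memory.
load_and_quantize(toy, "proj.weight", new_weight, dtype=torch.float32)
assert torch.equal(toy.proj.weight.data, new_weight)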
@@ -394,7 +515,7 @@ def load_model(

    if max_memory is not None:
        # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
        from accelerate import infer_auto_device_map, init_empty_weights
        from accelerate import infer_auto_device_map

        with init_empty_weights():
            model_canvas = AutoModelForCausalLM.from_config(model_config)
@@ -496,8 +617,78 @@ def load_model(
        model_kwargs["attn_implementation"] = "eager"
        model_config._attn_implementation = "eager"  # pylint: disable=protected-access

    qlora_fsdp = (
        cfg.fsdp
        and cfg.adapter == "qlora"
        and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
    )

    try:
        if (
        if qlora_fsdp:
            if cfg.bf16 or cfg.bfloat16:
                torch_dtype, compute_dtype = torch.float32, torch.bfloat16
            elif cfg.fp16 or cfg.float16:
                torch_dtype, compute_dtype = torch.float32, torch.float16
            else:
                torch_dtype, compute_dtype = torch.float32, torch.float16

            with init_empty_weights():
                LOG.info("Loading model with empty weights.")
                model = AutoModelForCausalLM.from_config(model_config)
                model.model = replace_linear(
                    model.model,
                    Linear4bit,
                    compute_dtype=compute_dtype,
                    quant_type="nf4",
                    quant_storage=torch_dtype,
                )

            model.is_loaded_in_4bit = True

            # Grab the safetensors files that hold the weights
            try:
                idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
                files, _ = hub.get_checkpoint_shard_files(base_model, idx)
            except OSError:
                try:
                    # This means the model doesn't have a model.safetensors.index.json because it is not sharded
                    files = []
                    files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
                except OSError as exc:
                    # This means the model probably doesn't have a safetensors file
                    raise exc

            # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
            # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
            def load_and_quantize_parallel(name_param, model, **kwargs):
                name, param = name_param
                load_and_quantize(model, name, param, **kwargs)

            param_count = sum((p.numel() for n, p in model.named_parameters()))
            for filename in files:
                weights = safetensors.torch.load_file(filename)
                quant_method = "bnb"
                devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
                left = int(os.cpu_count() / torch.cuda.device_count())
                right = int(
                    8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
                )
                n_workers = min(left, right)
                parallel(
                    load_and_quantize_parallel,
                    weights.items(),
                    n_workers=n_workers,
                    threadpool=True,
                    model=model,
                    dtype=torch_dtype,
                    device=cfg.local_rank,
                    skip_names=[],
                    is_meta_rank=(cfg.local_rank != 0),
                    verbose=False,
                    quant_method=quant_method,
                )

        elif (
            model_config.model_type == "llama"
            and not cfg.trust_remote_code
            and not cfg.gptq
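The `n_workers` heuristic in the loop above takes the smaller of a CPU budget (cores per GPU) and a memory/size budget (8 workers scaled by GPU memory relative to 40GB and by 70B relative to the model size). A worked example with illustrative hardware numbers (editorial, not from the diff):

cpu_count, gpu_count = 128, 8          # illustrative host
total_memory_gb = 80.0                 # e.g. an 80GB accelerator
param_count = 70e9                     # a 70B-parameter checkpoint

left = int(cpu_count / gpu_count)      # 16: CPU cores available per GPU process
right = int(8 * (total_memory_gb / 40) * (70 / (param_count / 1e9)))  # 16: memory/size budget
n_workers = min(left, right)           # -> 16 loader threads per rank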
@@ -613,7 +804,7 @@ def load_model(
        LOG.exception(err)
        raise err

    if isinstance(model, (PeftModel, PeftModelForCausalLM)):
    if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
        model = model.merge_and_unload()

    embeddings_len = (
@@ -692,6 +883,9 @@ def load_model(
    if cfg.adapter == "lora" and loftq_bits:
        skip_prepare_model_for_kbit_training = True

    if qlora_fsdp:
        skip_prepare_model_for_kbit_training = True

    if cfg.adapter in ["lora", "qlora"]:
        if cfg.gradient_checkpointing:
            model.gradient_checkpointing_enable()
@@ -706,7 +900,7 @@ def load_model(

    # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
    # convert them back to fp16/bf16 for flash-attn compatibility.
    if needs_fa2_dtype or cfg.flash_attention:
    if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
        LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
        for name, module in model.named_modules():
            if "norm" in name:
@@ -724,7 +918,12 @@ def load_model(
    else:
        model, lora_config = load_adapter(model, cfg, cfg.adapter)

    if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
    if (
        cfg.ddp
        and not load_in_8bit
        and not (cfg.rl and cfg.load_in_4bit)
        and not qlora_fsdp
    ):
        # TODO revaldate this conditional
        model.to(f"cuda:{cfg.local_rank}")

@@ -813,6 +1012,30 @@ def find_all_linear_names(model):
    return list(lora_module_names)


def setup_quantized_meta_for_peft(model: nn.Module):
    """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""

    def temp_to_method(self, *args, **kwargs):  # pylint: disable=unused-argument
        return self

    for param in model.parameters():
        if isinstance(param, Params4bit):
            param.quant_state._orig_to = (  # pylint: disable=protected-access
                param.quant_state.to
            )
            param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)


def setup_quantized_peft_meta_for_training(model: nn.Module):
    """Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
    for param in model.parameters():
        if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
            param.quant_state.to = (
                param.quant_state._orig_to  # pylint: disable=protected-access
            )
            param.quant_state._orig_to = None  # pylint: disable=protected-access


def load_lora(model, cfg, inference=False, config_only=False):
    # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]

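To show the mechanism the two helpers above rely on, here is a tiny editorial sketch with a stand-in object (not bitsandbytes' real `QuantState`): the instance's `to` is swapped for a bound no-op so PEFT cannot move the quantization metadata, then restored before training:

import types

class FakeQuantState:  # stand-in for bnb's QuantState
    def to(self, *args, **kwargs):
        return "moved"

state = FakeQuantState()
state._orig_to = state.to
state.to = types.MethodType(lambda self, *args, **kwargs: self, state)  # dummy, as in setup_quantized_meta_for_peft

assert state.to("meta") is state  # the move request is now a no-op
state.to = state._orig_to         # restore, as in setup_quantized_peft_meta_for_training
assert state.to("meta") == "moved"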
@@ -849,6 +1072,11 @@ def load_lora(model, cfg, inference=False, config_only=False):
    if config_only:
        return None, lora_config

    rank = int(os.environ.get("LOCAL_RANK", 0))

    if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
        setup_quantized_meta_for_peft(model)

    if cfg.lora_model_dir:
        LOG.debug("Loading pretrained PEFT - LoRA")
        model_kwargs: Any = {}
@@ -864,6 +1092,9 @@ def load_lora(model, cfg, inference=False, config_only=False):
    else:
        model = get_peft_model(model, lora_config)

    model.print_trainable_parameters()
    if rank == 0:
        model.print_trainable_parameters()
    elif cfg.fsdp and cfg.adapter == "qlora":
        setup_quantized_peft_meta_for_training(model)

    return model, lora_config
