strip out hacky qlora-fsdp workarounds now that qlora-fsdp fixes are upstreamed (#1428)

This commit is contained in:
Wing Lian
2024-03-21 11:56:13 -04:00
committed by GitHub
parent 7d55607368
commit 2a1589f6f6
8 changed files with 27 additions and 323 deletions

View File

@@ -36,7 +36,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 4 micro_batch_size: 4
num_epochs: 4 num_epochs: 4
optimizer: paged_adamw_8bit optimizer: adamw_torch
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.00001 learning_rate: 0.00001
@@ -66,5 +66,11 @@ weight_decay: 0.0
fsdp: fsdp:
- full_shard - full_shard
fsdp_config: fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT
special_tokens: special_tokens:

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.9.0 peft==0.9.0
transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa transformers @ git+https://github.com/huggingface/transformers.git@73a73b415e36f41481369f6129cb4b62bb127a78
tokenizers==0.15.0 tokenizers==0.15.0
bitsandbytes>=0.43.0 bitsandbytes==0.43.0
accelerate==0.26.1 accelerate==0.28.0
deepspeed==0.13.1 deepspeed==0.13.1
pydantic==2.6.3 pydantic==2.6.3
addict addict
@@ -40,4 +40,3 @@ gcsfs
# adlfs # adlfs
trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90 trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
fastcore>=1.5.29

View File

@@ -1,55 +0,0 @@
"""module for building the auto wrap policy for FSDP"""
import functools
from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
_or_policy,
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
]
def get_wrapping_policy_factory(model_type):
if model_type == "llama":
layer_to_wrap = LlamaDecoderLayer
elif model_type == "mistral":
layer_to_wrap = MistralDecoderLayer
elif model_type == "mixtral":
layer_to_wrap = MixtralDecoderLayer
def get_wrapping_policy():
"""This checks for lora layers (has weight and requires_grad)"""
def lambda_policy_fn(module):
return (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
)
lambda_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
)
transformer_layer_name = layer_to_wrap
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
transformer_layer_name,
),
)
policies = [lambda_policy, transformer_wrap_policy]
return functools.partial(_or_policy, policies=policies)
return get_wrapping_policy

View File

@@ -8,7 +8,6 @@ import importlib
import importlib.util import importlib.util
import logging import logging
import math import math
import os
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from collections import defaultdict from collections import defaultdict
@@ -19,10 +18,7 @@ from typing import Dict, List, Literal, Optional, Type, Union
import torch import torch
import transformers import transformers
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import str_to_bool
from datasets import Dataset from datasets import Dataset
from torch.distributed.fsdp import MixedPrecision
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import ( from transformers import (
@@ -35,7 +31,6 @@ from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer from trl import DPOTrainer
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.loraplus import create_loraplus_optimizer from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -591,51 +586,14 @@ class AxolotlTrainer(Trainer):
@wraps(Trainer.create_accelerator_and_postprocess) @wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self): def create_accelerator_and_postprocess(self):
rank = int(os.environ.get("LOCAL_RANK", 0))
res = super().create_accelerator_and_postprocess() res = super().create_accelerator_and_postprocess()
if self.args.qlora is False:
return res
# the rest of this method override is specific to fsdp + qlora (for now)
sync_module_states = (
str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
)
mp_policy = None
amp = os.environ["ACCELERATE_MIXED_PRECISION"]
if amp == "fp16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
elif amp == "bf16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
# If somehow we figure out how we want to parameterize we want to autocast buffers...
# mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
# load_param_skip_names = ['inv_freq']
if self.is_fsdp_enabled: if self.is_fsdp_enabled:
wrapping_policy = get_wrapping_policy_factory(self.args.model_type) if (
fsdp_plugin = FullyShardedDataParallelPlugin( "limit_all_gathers" in self.args.fsdp_config
auto_wrap_policy=wrapping_policy(), and self.args.fsdp_config["limit_all_gathers"]
cpu_offload=False, ):
use_orig_params=False, self.accelerator.state.fsdp_plugin.limit_all_gathers = True
limit_all_gathers=True,
param_init_fn=lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if (rank != 0 and sync_module_states)
else None,
mixed_precision_policy=mp_policy,
)
self.accelerator.state.fsdp_plugin = fsdp_plugin
return res return res

View File

@@ -5,16 +5,14 @@ import logging
import math import math
import os import os
import types import types
from typing import Any, Dict, List, Optional, Tuple, Type, Union # noqa: F401 from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict import addict
import bitsandbytes as bnb import bitsandbytes as bnb
import safetensors
import torch import torch
import transformers import transformers
from accelerate import init_empty_weights from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit from bitsandbytes.nn import Params4bit
from fastcore.parallel import parallel
from peft import ( from peft import (
LoftQConfig, LoftQConfig,
PeftConfig, PeftConfig,
@@ -23,7 +21,7 @@ from peft import (
prepare_model_for_kbit_training, prepare_model_for_kbit_training,
) )
from peft.tuners.lora import QuantLinear from peft.tuners.lora import QuantLinear
from torch import Tensor, nn from torch import nn
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AddedToken, AddedToken,
AutoConfig, AutoConfig,
@@ -35,9 +33,7 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import ( from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES, SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -272,117 +268,6 @@ def load_tokenizer(cfg):
return tokenizer return tokenizer
def replace_linear(
model: nn.Module,
linear_replacement: Type[nn.Module],
quant_config: Union[dict, None] = None,
skip_modules=None,
**kwargs,
):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
List of modules names not to convert. Defaults to `lm_head`.
"""
if skip_modules is None:
skip_modules = ["lm_head"]
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(
module, linear_replacement, quant_config, skip_modules, **kwargs
)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
if issubclass(linear_replacement, Linear4bit):
model._modules[ # pylint: disable=protected-access
name
] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
**kwargs,
)
else:
raise ValueError(
f"Unsupported linear replacement: {type(linear_replacement)}"
)
return model
def load_and_quantize(
module: nn.Module,
name: str,
value: Tensor,
device: torch.device = None,
dtype: torch.dtype = None,
skip_names: Optional[List[str]] = None,
is_meta_rank: bool = False,
low_memory: bool = True,
verbose: bool = False,
quant_method: str = "bnb",
):
"""
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
"""
if skip_names is None:
skip_names = []
def place_on_device(value):
if is_meta_rank:
device = "meta"
elif low_memory:
device = "cpu"
else:
device = "cuda"
return value.to(device=device, dtype=dtype)
if any(skip_name in name for skip_name in skip_names):
if verbose:
print(f"Skipping {name} because it is in skip_names")
return
module_key, _, value_key = name.rpartition(".")
try:
submodule = module.get_submodule(module_key)
except AttributeError as exc:
print(f"Module {module_key} not found:\n{exc}")
return
try:
if quant_method == "bnb":
param = submodule.get_parameter(value_key)
if isinstance(param, Params4bit):
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
# shape as the quantized Params4bit with an initialized quant_state. However,
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device)
if is_meta_rank:
value = type(param)(value.data.to("meta"), **value.__dict__)
elif low_memory:
value = type(param)(value.data.to("cpu"), **value.__dict__)
else:
value = type(param)(place_on_device(value).data)
except AttributeError:
# it's a buffer
value = place_on_device(value)
setattr(submodule, value_key, value)
def load_model( def load_model(
cfg: DictDefault, cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
@@ -568,6 +453,7 @@ def load_model(
"bnb_4bit_compute_dtype": cfg.torch_dtype, "bnb_4bit_compute_dtype": cfg.torch_dtype,
"bnb_4bit_use_double_quant": True, "bnb_4bit_use_double_quant": True,
"bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
} }
if cfg.bnb_config_kwargs: if cfg.bnb_config_kwargs:
@@ -617,78 +503,10 @@ def load_model(
model_kwargs["attn_implementation"] = "eager" model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = "eager" # pylint: disable=protected-access model_config._attn_implementation = "eager" # pylint: disable=protected-access
qlora_fsdp = ( qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
cfg.fsdp
and cfg.adapter == "qlora"
and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
)
try: try:
if qlora_fsdp: if (
if cfg.bf16 or cfg.bfloat16:
torch_dtype, compute_dtype = torch.float32, torch.bfloat16
elif cfg.fp16 or cfg.float16:
torch_dtype, compute_dtype = torch.float32, torch.float16
else:
torch_dtype, compute_dtype = torch.float32, torch.float16
with init_empty_weights():
LOG.info("Loading model with empty weights.")
model = AutoModelForCausalLM.from_config(model_config)
model.model = replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=torch_dtype,
)
model.is_loaded_in_4bit = True
# Grab the safetensors files that hold the weights
try:
idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(base_model, idx)
except OSError:
try:
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
files = []
files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
except OSError as exc:
# This means the model probably doesn't have a safetensors file
raise exc
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
def load_and_quantize_parallel(name_param, model, **kwargs):
name, param = name_param
load_and_quantize(model, name, param, **kwargs)
param_count = sum((p.numel() for n, p in model.named_parameters()))
for filename in files:
weights = safetensors.torch.load_file(filename)
quant_method = "bnb"
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
left = int(os.cpu_count() / torch.cuda.device_count())
right = int(
8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
)
n_workers = min(left, right)
parallel(
load_and_quantize_parallel,
weights.items(),
n_workers=n_workers,
threadpool=True,
model=model,
dtype=torch_dtype,
device=cfg.local_rank,
skip_names=[],
is_meta_rank=(cfg.local_rank != 0),
verbose=False,
quant_method=quant_method,
)
elif (
model_config.model_type == "llama" model_config.model_type == "llama"
and not cfg.trust_remote_code and not cfg.trust_remote_code
and not cfg.gptq and not cfg.gptq
@@ -715,32 +533,6 @@ def load_model(
if cfg.flash_attn_fuse_qkv: if cfg.flash_attn_fuse_qkv:
LOG.info("patching with fused QKV") LOG.info("patching with fused QKV")
replace_llama_qkv_with_fused(model) replace_llama_qkv_with_fused(model)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
# This is a WIP, still an issue with the backward pass
# RuntimeError: grad can be implicitly created only for scalar outputs
# TODO: try config.sequence_parallel = False
# # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
# # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
# # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
# from flash_attn.utils.pretrained import state_dict_from_pretrained
# from flash_attn.models.gpt import GPTLMHeadModel
# from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
# from transformers import GPTNeoXConfig
# config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
# config.use_flash_attn = True
# config.fused_bias_fc = True
# config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
# config.activation_function = "gelu_fast"
# config.fused_dropout_add_ln = True
# # config.residual_in_fp32 = True
#
# model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
# base_model,
# config,
# dtype=torch_dtype,
# device=cfg.device,
# )
# model.train() # sets to train instead of eval mode
elif model_type == "MambaLMHeadModel": elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work # FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name

View File

@@ -304,6 +304,10 @@ def setup_fsdp_envs(cfg):
os.environ["FSDP_OFFLOAD_PARAMS"] = "true" os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
if cfg.fsdp_config.fsdp_sync_module_states: if cfg.fsdp_config.fsdp_sync_module_states:
os.environ["FSDP_SYNC_MODULE_STATES"] = "true" os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
if cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
if cfg.fsdp_config.fsdp_use_orig_params:
os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
if cfg.fsdp_config.fsdp_state_dict_type: if cfg.fsdp_config.fsdp_state_dict_type:
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap: if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:

View File

@@ -77,7 +77,7 @@ class TestMixtral(unittest.TestCase):
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.uint8 == torch.float32
) )
assert (Path(temp_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.uint8 == torch.float32
) )
assert (Path(temp_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()