Compare commits
2 Commits
torch-211-
...
testingci
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e36d3c9f30 | ||
|
|
53614391ed |
@@ -1,351 +0,0 @@
|
||||
"""
|
||||
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts
|
||||
"""
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def fsdp2_load_full_state_dict(
|
||||
_accelerator, model: torch.nn.Module, full_sd: dict, offload_to_cpu: bool = False
|
||||
):
|
||||
"""
|
||||
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
|
||||
parameters from rank 0 to all other ranks. This function modifies the model in-place.
|
||||
Args:
|
||||
accelerator (`Accelerator`): The accelerator instance
|
||||
model (`torch.nn.Module`):
|
||||
The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
|
||||
full_sd (`dict`): The full state dict to load, can only be on rank 0
|
||||
"""
|
||||
from torch.distributed.tensor import distribute_tensor
|
||||
|
||||
LOG.info("Broadcasting full state dict to all ranks...")
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
meta_sharded_sd = model.state_dict()
|
||||
sharded_sd = {}
|
||||
for param_name, full_tensor in full_sd.items():
|
||||
sharded_meta_param = meta_sharded_sd.get(param_name)
|
||||
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
|
||||
if hasattr(sharded_meta_param, "device_mesh"):
|
||||
sharded_param = distribute_tensor(
|
||||
full_tensor,
|
||||
sharded_meta_param.device_mesh,
|
||||
sharded_meta_param.placements,
|
||||
src_data_rank=0,
|
||||
)
|
||||
else:
|
||||
sharded_param = full_tensor
|
||||
|
||||
if offload_to_cpu:
|
||||
sharded_param = sharded_param.cpu()
|
||||
|
||||
sharded_sd[param_name] = nn.Parameter(sharded_param)
|
||||
del full_tensor
|
||||
full_sd[param_name] = None
|
||||
model.load_state_dict(sharded_sd, assign=True, strict=True)
|
||||
end_time = time.time()
|
||||
LOG.debug(
|
||||
f"Time taken to load full state dict: {(end_time - start_time):.2f} seconds"
|
||||
)
|
||||
log_gpu_memory_usage(LOG, "Memory usage after broadcasting full state dict", 0)
|
||||
return model
|
||||
|
||||
|
||||
def get_state_dict(self, model, unwrap=True):
|
||||
"""
|
||||
Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full
|
||||
precision.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`):
|
||||
A PyTorch model sent through [`Accelerator.prepare`]
|
||||
unwrap (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict
|
||||
|
||||
Returns:
|
||||
`dict`: The state dictionary of the model potentially without full precision.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from accelerate import Accelerator
|
||||
|
||||
>>> accelerator = Accelerator()
|
||||
>>> net = torch.nn.Linear(2, 2)
|
||||
>>> net = accelerator.prepare(net)
|
||||
>>> state_dict = accelerator.get_state_dict(net)
|
||||
```
|
||||
"""
|
||||
from accelerate import DistributedType
|
||||
from accelerate.utils import compare_versions
|
||||
|
||||
if self.distributed_type == DistributedType.DEEPSPEED:
|
||||
zero3_sharding = self.deepspeed_config["zero_optimization"]["stage"] == 3
|
||||
tp_sharding = (
|
||||
self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1
|
||||
)
|
||||
if zero3_sharding or tp_sharding:
|
||||
if model.zero_gather_16bit_weights_on_model_save():
|
||||
if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"):
|
||||
raise ImportError(
|
||||
"Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`."
|
||||
)
|
||||
state_dict = (
|
||||
model._consolidated_16bit_state_dict() # pylint: disable=protected-access
|
||||
if tp_sharding
|
||||
else model._zero3_consolidated_16bit_state_dict() # pylint: disable=protected-access
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
|
||||
"To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
|
||||
"set `zero3_save_16bit_model` to True when using `accelerate config`. "
|
||||
"To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
|
||||
)
|
||||
else:
|
||||
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
|
||||
|
||||
state_dict = clone_tensors_for_torch_save(
|
||||
self.unwrap_model(model).state_dict()
|
||||
)
|
||||
elif self.is_fsdp2:
|
||||
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
|
||||
state_dict = {}
|
||||
sharded_state_dict = model.state_dict()
|
||||
for param_name, param in sharded_state_dict.items():
|
||||
if param.is_cpu:
|
||||
param = param.to(torch.device("cuda"))
|
||||
|
||||
param = param.full_tensor()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
state_dict[param_name] = param.cpu()
|
||||
torch.distributed.barrier()
|
||||
elif self.distributed_type == DistributedType.FSDP:
|
||||
from torch.distributed.fsdp import FullStateDictConfig
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import StateDictType
|
||||
|
||||
full_state_dict_config = FullStateDictConfig(
|
||||
offload_to_cpu=True, rank0_only=True
|
||||
)
|
||||
with FSDP.state_dict_type(
|
||||
model, StateDictType.FULL_STATE_DICT, full_state_dict_config
|
||||
):
|
||||
state_dict = model.state_dict()
|
||||
else:
|
||||
if unwrap:
|
||||
model = self.unwrap_model(model)
|
||||
state_dict = model.state_dict()
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||
"""Helper function to process LoRA modules for FSDP2."""
|
||||
from torch.distributed.fsdp import fully_shard
|
||||
|
||||
log_bias_dtype_mismatch = False
|
||||
|
||||
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
|
||||
if module.base_layer.bias is not None:
|
||||
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
|
||||
log_bias_dtype_mismatch = True
|
||||
module.base_layer.bias.data = module.base_layer.bias.data.to(
|
||||
module.base_layer.weight.dtype
|
||||
)
|
||||
|
||||
for active_adapter in module.active_adapters:
|
||||
if module.lora_A:
|
||||
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_B:
|
||||
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_embedding_A:
|
||||
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_embedding_B:
|
||||
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_magnitude_vector:
|
||||
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
|
||||
return log_bias_dtype_mismatch
|
||||
|
||||
|
||||
def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
"""Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
|
||||
|
||||
Args:
|
||||
accelerator (`Accelerator`): The accelerator instance
|
||||
model (`torch.nn.Module`): The model to prepare
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: Prepared model
|
||||
"""
|
||||
from accelerate.utils import get_module_children_bottom_up, is_compiled_module
|
||||
from accelerate.utils.fsdp_utils import fsdp2_prepare_auto_wrap_policy
|
||||
from accelerate.utils.modeling import get_non_persistent_buffers
|
||||
from peft import PeftModel
|
||||
from peft.tuners.lora import LoraLayer
|
||||
from torch.distributed.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
FSDPModule,
|
||||
MixedPrecisionPolicy,
|
||||
fully_shard,
|
||||
)
|
||||
|
||||
is_type_fsdp = isinstance(model, FSDPModule) or (
|
||||
is_compiled_module(model)
|
||||
and isinstance(model._orig_mod, FSDPModule) # pylint: disable=protected-access
|
||||
)
|
||||
if is_type_fsdp:
|
||||
return model
|
||||
|
||||
fsdp2_plugin = accelerator.state.fsdp_plugin
|
||||
|
||||
original_sd = model.state_dict()
|
||||
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
size_based_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
|
||||
# We need the `auto_wrap_policy` original type to create a custom poilicy function for sharding
|
||||
# This is because `fully_shard` doesn't support old auto wrap policies, rather we have to imitate the behaviour
|
||||
if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy:
|
||||
pass # auto_wrap_policy_type = "transformer"
|
||||
elif fsdp2_plugin.auto_wrap_policy is size_based_auto_wrap_policy:
|
||||
pass # auto_wrap_policy_type = "size"
|
||||
|
||||
# We set `auto_wrap_policy` to `functools.partial` to avoid creating it again
|
||||
# This is because of `apply_activation_checkpointing` which will can reuse this function
|
||||
fsdp2_plugin.set_auto_wrap_policy(model)
|
||||
|
||||
if fsdp2_plugin.activation_checkpointing:
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
CheckpointImpl,
|
||||
apply_activation_checkpointing,
|
||||
checkpoint_wrapper,
|
||||
)
|
||||
|
||||
# Apply activation checkpointing before applying `fully_shard`
|
||||
apply_activation_checkpointing(
|
||||
model,
|
||||
checkpoint_wrapper_fn=functools.partial(
|
||||
checkpoint_wrapper,
|
||||
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
|
||||
),
|
||||
auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
|
||||
)
|
||||
|
||||
fsdp2_kwargs = {
|
||||
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
|
||||
"offload_policy": fsdp2_plugin.cpu_offload,
|
||||
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
||||
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
||||
}
|
||||
|
||||
model_has_params4bit = False
|
||||
for _, param in model.named_parameters():
|
||||
# this is a temporary fix whereby loading models with bnb params cannot be moved from
|
||||
# GPU to a meta device due with FSDP2 because torch operations don't return the original class type
|
||||
# bypassing the move to meta will still cause the VRAM spike, but at least it still will load
|
||||
if param.__class__.__name__ == "Params4bit":
|
||||
model_has_params4bit = True
|
||||
break
|
||||
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
|
||||
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
|
||||
# For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
|
||||
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.emtpy`), `fully_shard` would move it to GPU
|
||||
# Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
|
||||
|
||||
# We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
|
||||
# Also, these buffers aren't getting sharded by default
|
||||
# We get the FQNs of all non-persistent buffers, to re-register them after
|
||||
non_persistent_buffer_fqns = get_non_persistent_buffers(
|
||||
model, recurse=True, fqns=True
|
||||
)
|
||||
original_non_persistent_buffers = copy.deepcopy(
|
||||
{k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
|
||||
)
|
||||
# We move the model to meta device, as then sharding happens on meta device
|
||||
model = model.to(torch.device("meta"))
|
||||
# We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage
|
||||
# We assume `transformers` models have a `tie_weights` method if they support it
|
||||
if hasattr(model, "tie_weights"):
|
||||
model.tie_weights()
|
||||
|
||||
is_peft_model = isinstance(model, PeftModel)
|
||||
|
||||
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
||||
log_bias_dtype_mismatch = False
|
||||
if auto_wrap_policy is not None:
|
||||
for module in get_module_children_bottom_up(model)[:-1]:
|
||||
if is_peft_model and isinstance(module, LoraLayer):
|
||||
module_log_bias_mismatch = _process_lora_module_for_fsdp(
|
||||
module, fsdp2_kwargs
|
||||
)
|
||||
log_bias_dtype_mismatch |= module_log_bias_mismatch
|
||||
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
|
||||
fully_shard(module, **fsdp2_kwargs)
|
||||
|
||||
fully_shard(model, **fsdp2_kwargs)
|
||||
|
||||
if log_bias_dtype_mismatch:
|
||||
LOG.warning(
|
||||
"Bias dtype mismatch detected in LoRA base linear layer. Bias parameters have been cast to weight dtype."
|
||||
)
|
||||
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading:
|
||||
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
|
||||
fsdp2_load_full_state_dict(
|
||||
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
|
||||
)
|
||||
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
|
||||
# We re-register the buffers, as they may not be in the state_dict
|
||||
for fqn, buffer_tensor in original_non_persistent_buffers.items():
|
||||
buffer_tensor = buffer_tensor.to(accelerator.device)
|
||||
|
||||
if "." in fqn:
|
||||
parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
|
||||
parent_module = model.get_submodule(parent_fqn)
|
||||
else:
|
||||
local_buffer_name = fqn
|
||||
parent_module = model
|
||||
|
||||
parent_module.register_buffer(
|
||||
local_buffer_name, buffer_tensor, persistent=False
|
||||
)
|
||||
|
||||
# We need to tie the weights again, as call to `load_full_state_dict` breaks the tie
|
||||
# Needs to be called both here and above
|
||||
# removing this call makes the have slightly different loss
|
||||
# removing the call above leads to extra memory usage as explained in the comment above
|
||||
if hasattr(model, "tie_weights"):
|
||||
model.tie_weights()
|
||||
return model
|
||||
|
||||
|
||||
def patch_accelerate_fsdp2():
|
||||
import accelerate
|
||||
|
||||
accelerate.accelerator.fsdp2_prepare_model = fsdp2_prepare_model
|
||||
accelerate.Accelerator.get_state_dict = get_state_dict
|
||||
setattr(
|
||||
sys.modules["accelerate"],
|
||||
"Accelerator.get_state_dict",
|
||||
get_state_dict,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user