Files
axolotl/src/axolotl/monkeypatch/accelerate/fsdp2.py
NanoCode012 682a9cf79b Fix: add delinearization and make qlora work with fsdp2 (#2515)
* fixes for delinearization, and make qlora work with fsdp2

* Add back mistakenly removed lm_eval

* typo [skip ci]

* patch evals for torch.compile + fsdp2

* also check torch_compile w fsdp2

* lots of fixes for flex attn with llama4

* fix patch check and patch llama4 too

* attempt to make the patches stick

* use transformers 4.51.2

* update configs and README for llama4

* remove torch.compile for CI test

* cleanup any existing singletons

* set singleton cache to None instead of deleting

* use importlib reload with monkeypatch

* don't worry about transformers version, mark inputs with grads, fix regex

* make sure embeds aren't on cpu

* logging and mem improvements

* vllm version and add to docker, make sure to save processor on conversion

* fix ambiguous tensor bool check

* fix vllm to not use v1, upgrade hf transformers

* fix tests

* make flex_attn_compile_kwargs configurable, since this depends on model params

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-04-15 23:31:39 -07:00

64 lines
2.3 KiB
Python

"""
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration
"""
import logging
import sys
import torch
# Module-level logger, named after this module via `__name__`.
LOG = logging.getLogger(__name__)
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
    """
    Populate a sharded model from a full state dict by broadcasting every entry from rank 0.

    For each key of `model.state_dict()`, the main process takes the matching tensor from `full_sd`
    and moves it to CUDA, while every other process allocates an empty CUDA tensor using the size and
    dtype of its own sharded entry. The tensor is broadcast from rank 0 over the entry's device-mesh
    group, re-sharded with `distribute_tensor` using the entry's placements, and finally loaded back
    into the model with `assign=True`. The model is modified in-place; nothing is returned.

    Args:
        accelerator (`Accelerator`): The accelerator instance; only `is_main_process` is read here
        model (`torch.nn.Module`): The model to load the state dict into
        full_sd (`dict`): The full state dict to load; it is only read on the main process and is
            expected to contain every key of the model's state dict
    """
    import torch.distributed as dist
    from torch.distributed.tensor import distribute_tensor

    LOG.info("Broadcasting full state dict to all ranks...")
    state = model.state_dict()

    # Walk a sorted snapshot of the keys rather than the dict itself, since entries of `state`
    # are reassigned inside the loop.
    for key in sorted(state):
        shard = state[key]
        device_mesh = shard.device_mesh

        if accelerator.is_main_process:
            # Source rank: send the real weights (assumes `key` exists in full_sd)
            buffer = full_sd[key].detach().cuda()
        else:
            # Receiving ranks: allocate a buffer of matching shape and dtype
            buffer = torch.empty(shard.size(), device="cuda", dtype=shard.dtype)

        dist.broadcast(buffer, src=0, group=device_mesh.get_group())
        state[key] = distribute_tensor(buffer, device_mesh, shard.placements)

    model.load_state_dict(state, assign=True)
def patch_accelerate_fsdp_utils():
    """
    Install this module's `fsdp2_load_full_state_dict` over the one in `accelerate.utils.fsdp_utils`.

    The replacement is assigned twice: once on the module object obtained through the import, and
    once on the `sys.modules["accelerate.utils.fsdp_utils"]` entry.
    """
    from accelerate.utils import fsdp_utils

    replacement = fsdp2_load_full_state_dict
    setattr(fsdp_utils, "fsdp2_load_full_state_dict", replacement)
    sys.modules["accelerate.utils.fsdp_utils"].fsdp2_load_full_state_dict = replacement