* fixes for delinearization, and make qlora work with fsdp2 * Add back mistakenly removed lm_eval * typo [skip ci] * patch evals for torch.compile + fsdp2 * also check torch_compile w fsdp2 * lots of fixes for flex attn with llama4 * fix patch check and patch llama4 too * attempt to make the patches stick * use transformers 4.51.2 * update configs and README for llama4 * remove torch.compile for CI test * cleanup any existing singletons * set singleton cache to None instead of deleting * use importlib reload with monkeypatch * don't worry about transformers version, mark inputs with grads, fix regex * make sure embeds aren't on cpu * logging and mem improvements * vllm version and add to docker, make sure to save processor on conversion * fix ambiguous tensor bool check * fix vllm to not use v1, upgrade hf transformers * fix tests * make flex_attn_compile_kwargs configurable, since this depends on model params --------- Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
64 lines
2.3 KiB
Python
64 lines
2.3 KiB
Python
"""
|
|
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration
|
|
"""
|
|
|
|
import logging
|
|
import sys
|
|
|
|
import torch
|
|
|
|
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)
|
|
|
|
|
|
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
    """
    Fill a sharded model from a full state dict by broadcasting every entry from rank 0.

    Each key of ``model.state_dict()`` is visited in sorted order. The main process
    moves the matching ``full_sd`` tensor to CUDA, while every other process allocates
    an empty CUDA tensor using the size and dtype reported by its sharded entry. The
    tensor is then broadcast with ``src=0`` over the entry's device-mesh group,
    re-distributed using the entry's placements, and stored back under the same key.
    Once all keys are processed the result is loaded into the model with
    ``assign=True``, so the model is modified in-place.

    Args:
        accelerator (`Accelerator`): The accelerator instance; only `is_main_process` is read
        model (`torch.nn.Module`): The model to load the state dict into
        full_sd (`dict`): The full state dict to load; only read on the main process
    """
    import torch.distributed as dist
    from torch.distributed.tensor import distribute_tensor

    LOG.info("Broadcasting full state dict to all ranks...")
    state = model.state_dict()

    # Walk a sorted snapshot of the keys: entries of `state` are reassigned inside
    # the loop, and sorting gives every rank the same visiting order.
    for key in sorted(state):
        shard = state[key]
        device_mesh = shard.device_mesh

        if accelerator.is_main_process:
            # Assumes `key` is present in full_sd on the main process.
            payload = full_sd[key].detach().cuda()
        else:
            # Receive buffer matching the sharded entry's reported size and dtype.
            payload = torch.empty(shard.size(), device="cuda", dtype=shard.dtype)

        # NOTE(review): the source rank is hard-coded to 0 — presumably the main
        # process is always rank 0; confirm against the launcher.
        dist.broadcast(payload, src=0, group=device_mesh.get_group())
        state[key] = distribute_tensor(payload, device_mesh, shard.placements)

    model.load_state_dict(state, assign=True)
|
|
|
|
|
|
def patch_accelerate_fsdp_utils():
    """
    Swap accelerate's `fsdp2_load_full_state_dict` for the implementation in this module.

    The replacement is assigned twice: once on the `fsdp_utils` module object obtained
    via import, and once on the `accelerate.utils.fsdp_utils` entry in `sys.modules`.
    """
    from accelerate.utils import fsdp_utils

    attr_name = "fsdp2_load_full_state_dict"
    replacement = fsdp2_load_full_state_dict

    fsdp_utils.fsdp2_load_full_state_dict = replacement

    # Also patch whatever module object is registered under the dotted name.
    registered_module = sys.modules["accelerate.utils.fsdp_utils"]
    setattr(registered_module, attr_name, replacement)
|