FSDP1 -> FSDP2 (#2760)
* FSDP2 args migration implementation This commit implements the migration to FSDP2 arguments including: - FSDP2 support with LoRA training - DPO integration with FSDP2 - Model loading fixes and refactoring - CPU offloading and PEFT handling - Test updates and CI improvements - Bug fixes for dtype errors and various edge cases
This commit is contained in:
@@ -16,6 +16,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
migrate_fsdp_config,
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
validate_config,
|
||||
@@ -226,6 +227,7 @@ def load_cfg(
|
||||
},
|
||||
)
|
||||
|
||||
migrate_fsdp_config(cfg)
|
||||
prepare_optim_env(cfg)
|
||||
prepare_opinionated_env(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
@@ -501,6 +501,10 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.reward_model or self.cfg.rl:
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
if self.cfg.fsdp_config or self.cfg.fsdp:
|
||||
training_args_kwargs["fsdp_config"] = self.cfg.fsdp_config
|
||||
training_args_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp else True
|
||||
|
||||
self._configure_reporting(training_args_kwargs)
|
||||
self._configure_hub_parameters(training_args_kwargs)
|
||||
self._configure_scheduler(training_args_kwargs)
|
||||
|
||||
@@ -151,14 +151,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||
total_num_steps
|
||||
)
|
||||
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = {
|
||||
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
|
||||
}
|
||||
|
||||
if self.cfg.adapter == "qlora":
|
||||
training_arguments_kwargs["qlora"] = True
|
||||
|
||||
|
||||
@@ -208,7 +208,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
callbacks=self.get_callbacks(),
|
||||
**trainer_kwargs,
|
||||
)
|
||||
if self.cfg.fsdp:
|
||||
if self.cfg.fsdp_config or self.cfg.fsdp:
|
||||
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
|
||||
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
|
||||
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
|
||||
@@ -218,21 +218,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
HF Factory class for PPO Trainer
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
# TODO: build PPOConfig
|
||||
raise NotImplementedError("PPO trainer builder is not implemented yet.")
|
||||
|
||||
@@ -14,5 +14,4 @@ from .trl import (
|
||||
AxolotlORPOTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
TRLPPOTrainer,
|
||||
)
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
"""Module for TRL PPO trainer"""
|
||||
"""Module for TRL RL trainers"""
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from trl import (
|
||||
CPOTrainer,
|
||||
KTOTrainer,
|
||||
ORPOTrainer,
|
||||
PPOTrainer,
|
||||
PRMTrainer,
|
||||
RewardTrainer,
|
||||
)
|
||||
@@ -16,64 +13,6 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, Optimizer
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
class TRLPPOTrainer(PPOTrainer):
|
||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||
|
||||
tag_names = ["axolotl", "ppo"]
|
||||
|
||||
def train(
|
||||
self,
|
||||
reward_pipe,
|
||||
resume_from_checkpoint=None, # pylint: disable=unused-argument
|
||||
):
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": self.tokenizer.eos_token_id,
|
||||
"max_new_tokens": 32,
|
||||
}
|
||||
sent_kwargs = {
|
||||
"return_all_scores": True,
|
||||
"function_to_apply": "none",
|
||||
"batch_size": 16,
|
||||
}
|
||||
|
||||
for _, batch in tqdm(enumerate(self.dataloader)):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
# generate model response
|
||||
response_tensors, ref_response_tensors = self.generate(
|
||||
query_tensors,
|
||||
return_prompt=False,
|
||||
generate_ref_response=True,
|
||||
**generation_kwargs,
|
||||
)
|
||||
batch["response"] = self.tokenizer.batch_decode(response_tensors)
|
||||
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
|
||||
|
||||
# Compute sentiment score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = reward_pipe(texts, **sent_kwargs)
|
||||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
||||
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
|
||||
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
|
||||
ref_rewards = [
|
||||
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
|
||||
]
|
||||
batch["ref_rewards"] = ref_rewards
|
||||
|
||||
# Run PPO step
|
||||
stats = self.step(query_tensors, response_tensors, rewards)
|
||||
self.log_stats(
|
||||
stats,
|
||||
batch,
|
||||
rewards,
|
||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||
)
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
|
||||
):
|
||||
|
||||
@@ -122,9 +122,9 @@ def load_lora(
|
||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
|
||||
if (
|
||||
cfg.fsdp
|
||||
cfg.fsdp_config
|
||||
and cfg.adapter
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
and rank != 0
|
||||
):
|
||||
setup_quantized_meta_for_peft(model)
|
||||
@@ -152,9 +152,9 @@ def load_lora(
|
||||
"Exception caught during model.print_trainable_parameters(): %s", exc
|
||||
)
|
||||
elif (
|
||||
cfg.fsdp
|
||||
cfg.fsdp_config
|
||||
and cfg.adapter
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
and rank != 0
|
||||
):
|
||||
setup_quantized_peft_meta_for_training(model)
|
||||
|
||||
@@ -140,10 +140,15 @@ class ModelLoader:
|
||||
"""Check if flash attention is installed."""
|
||||
return find_spec("flash_attn") is not None
|
||||
|
||||
@cached_property
|
||||
def qlora_fsdp(self):
|
||||
@property
|
||||
def is_fsdp_enabled(self):
|
||||
"""Property that determines if FSDP is enabled."""
|
||||
return self.cfg.fsdp_config is not None or self.cfg.fsdp is not None
|
||||
|
||||
@property
|
||||
def is_qlora_and_fsdp_enabled(self):
|
||||
"""Property that determines if FSDP with QLoRA is enabled."""
|
||||
return self.cfg.fsdp and self.cfg.adapter == "qlora"
|
||||
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
|
||||
|
||||
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
|
||||
"""Load and prepare the model with all configurations and patches.
|
||||
@@ -189,15 +194,15 @@ class ModelLoader:
|
||||
# Handle PeftModel if needed
|
||||
if (
|
||||
isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM))
|
||||
and not self.qlora_fsdp
|
||||
and not self.is_qlora_and_fsdp_enabled
|
||||
):
|
||||
self.model = self.model.merge_and_unload()
|
||||
|
||||
self._resize_token_embeddings()
|
||||
self._adjust_model_config()
|
||||
self._log_memory_usage()
|
||||
self._configure_embedding_dtypes()
|
||||
self._configure_qat()
|
||||
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
|
||||
|
||||
def _resize_token_embeddings(self):
|
||||
"""Resize token embeddings if needed."""
|
||||
@@ -251,22 +256,13 @@ class ModelLoader:
|
||||
):
|
||||
self.model.config.eos_token_id = self.tokenizer.eos_token_id
|
||||
|
||||
def _log_memory_usage(self):
|
||||
"""Log device memory usage after model load."""
|
||||
if hasattr(self.model, "device") and self.model.device.type in (
|
||||
"cuda",
|
||||
"mps",
|
||||
"npu",
|
||||
):
|
||||
log_gpu_memory_usage(LOG, "after model load", self.model.device)
|
||||
|
||||
def _configure_embedding_dtypes(self):
|
||||
"""Configure embedding module dtypes."""
|
||||
# Get embedding modules
|
||||
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
|
||||
|
||||
# Initial dtype conversion
|
||||
if not self.cfg.fsdp:
|
||||
if not self.is_fsdp_enabled:
|
||||
# We don't run this during FSDP because this will leave mixed and bfloat16
|
||||
# dtypes in the model which FSDP doesn't like
|
||||
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
|
||||
@@ -282,7 +278,7 @@ class ModelLoader:
|
||||
self._set_z3_leaf_modules()
|
||||
|
||||
# Apply gradient checkpointing if needed
|
||||
needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp
|
||||
needs_fa2_dtype = self.cfg.adapter or self.is_fsdp_enabled
|
||||
if self.cfg.adapter in ["lora", "qlora"]:
|
||||
needs_fa2_dtype = True
|
||||
if self.cfg.gradient_checkpointing:
|
||||
@@ -298,10 +294,12 @@ class ModelLoader:
|
||||
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
(
|
||||
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
|
||||
and not self.qlora_fsdp
|
||||
and not self.is_qlora_and_fsdp_enabled
|
||||
)
|
||||
or (
|
||||
# CCE requires embedding layers to be in fp16/bf16 for backward pass
|
||||
self.cfg.cut_cross_entropy
|
||||
)
|
||||
# CCE requires embedding layers to be in fp16/bf16 for backward pass
|
||||
or self.cfg.cut_cross_entropy
|
||||
)
|
||||
|
||||
if should_convert:
|
||||
@@ -357,7 +355,6 @@ class ModelLoader:
|
||||
and not (self.cfg.rl and self.cfg.load_in_4bit)
|
||||
and not skip_move_to_device
|
||||
):
|
||||
# TODO: validate this conditional
|
||||
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
|
||||
|
||||
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
@@ -430,7 +427,17 @@ class ModelLoader:
|
||||
|
||||
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
||||
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
is_ds_zero3 = is_deepspeed_zero3_enabled()
|
||||
|
||||
# FSDP requires control over device placement, so don't set device_map when FSDP is enabled
|
||||
if self.is_fsdp_enabled:
|
||||
# For QLoRA + FSDP, we still need to set device_map to "auto" for proper initialization
|
||||
if self.is_qlora_and_fsdp_enabled:
|
||||
self.model_kwargs["device_map"] = {
|
||||
"": int(os.environ.get("LOCAL_RANK", 0))
|
||||
}
|
||||
# For other FSDP cases, don't set device_map at all
|
||||
elif not is_ds_zero3:
|
||||
self.model_kwargs["device_map"] = device_map
|
||||
|
||||
cur_device = get_device_type()
|
||||
@@ -499,7 +506,7 @@ class ModelLoader:
|
||||
"bnb_4bit_quant_storage": torch.bfloat16,
|
||||
}
|
||||
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
|
||||
self.cfg.deepspeed or self.cfg.fsdp
|
||||
self.cfg.deepspeed or self.is_fsdp_enabled
|
||||
):
|
||||
# for some reason, this causes the loss to be off by an order of magnitude
|
||||
# but deepspeed needs this still in bfloat16
|
||||
@@ -604,9 +611,21 @@ class ModelLoader:
|
||||
def _build_model(self) -> bool:
|
||||
"""Load model, with load strategy depending on config."""
|
||||
skip_move_to_device = False
|
||||
if self.is_fsdp_enabled:
|
||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||
skip_move_to_device = True
|
||||
# Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
|
||||
if (
|
||||
"device_map" in self.model_kwargs
|
||||
and not self.is_qlora_and_fsdp_enabled
|
||||
):
|
||||
del self.model_kwargs["device_map"]
|
||||
elif self.is_qlora_and_fsdp_enabled:
|
||||
skip_move_to_device = True
|
||||
|
||||
if (
|
||||
self.qlora_fsdp
|
||||
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
self.is_qlora_and_fsdp_enabled
|
||||
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
and (
|
||||
self.cfg.model_config_type == "dbrx"
|
||||
or self.cfg.qlora_sharded_model_loading
|
||||
@@ -632,12 +651,6 @@ class ModelLoader:
|
||||
and not self.cfg.trust_remote_code
|
||||
and not self.cfg.gptq
|
||||
):
|
||||
# TODO: Do we need to open this up for all models?
|
||||
if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||
skip_move_to_device = True
|
||||
if "device_map" in self.model_kwargs:
|
||||
del self.model_kwargs["device_map"]
|
||||
|
||||
# Please don't remove underscore binding without reading the fn docstring.
|
||||
_ = self._configure_zero3_memory_efficient_loading()
|
||||
|
||||
@@ -691,33 +704,22 @@ class ModelLoader:
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
elif self.cfg.gptq:
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
else:
|
||||
if self.cfg.gptq:
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
else:
|
||||
if (
|
||||
self.cfg.fsdp
|
||||
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
):
|
||||
# disabling either of these two still leads to VRAM spike before setting back down
|
||||
skip_move_to_device = True
|
||||
if "device_map" in self.model_kwargs:
|
||||
del self.model_kwargs["device_map"]
|
||||
|
||||
# Please don't remove underscore binding without reading the fn docstring.
|
||||
_ = self._configure_zero3_memory_efficient_loading()
|
||||
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
# Please don't remove underscore binding without reading the fn docstring.
|
||||
_ = self._configure_zero3_memory_efficient_loading()
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
skip_move_to_device = True
|
||||
|
||||
@@ -753,8 +755,8 @@ class ModelLoader:
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if (
|
||||
self.qlora_fsdp
|
||||
or (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
|
||||
self.is_qlora_and_fsdp_enabled
|
||||
or (self.is_fsdp_enabled and self.cfg.fsdp_config.cpu_ram_efficient_loading)
|
||||
or is_deepspeed_zero3_enabled()
|
||||
):
|
||||
# Make sure everything is in the same dtype
|
||||
|
||||
@@ -94,10 +94,14 @@ class PatchManager:
|
||||
|
||||
def _apply_fsdp_patches(self):
|
||||
"""Apply patches for FSDP configurations."""
|
||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
||||
|
||||
patch_accelerate_fsdp2()
|
||||
if self.cfg.rl:
|
||||
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
|
||||
|
||||
patch_trl_prepare_fsdp2()
|
||||
|
||||
# if self.cfg.fsdp_config:
|
||||
# # see transformers#39152
|
||||
|
||||
@@ -195,9 +195,11 @@ def ensure_dtype(model: PreTrainedModel, dtype: torch.dtype = torch.bfloat16):
|
||||
bias_mismatch = module.bias.dtype != dtype
|
||||
|
||||
if weight_mismatch:
|
||||
print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
|
||||
LOG.debug(
|
||||
f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}"
|
||||
)
|
||||
if bias_mismatch:
|
||||
print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
|
||||
LOG.debug(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
|
||||
if weight_mismatch or bias_mismatch:
|
||||
module.to(dtype)
|
||||
|
||||
|
||||
@@ -2,102 +2,65 @@
|
||||
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts
|
||||
"""
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
|
||||
def fsdp2_load_full_state_dict(
|
||||
_accelerator, model: torch.nn.Module, full_sd: dict, offload_to_cpu: bool = False
|
||||
):
|
||||
"""
|
||||
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
|
||||
parameters from rank 0 to all other ranks. This function modifies the model in-place.
|
||||
|
||||
Args:
|
||||
accelerator (`Accelerator`): The accelerator instance
|
||||
model (`torch.nn.Module`):
|
||||
The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
|
||||
full_sd (`dict`): The full state dict to load, can only be on rank 0
|
||||
"""
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.tensor import distribute_tensor
|
||||
|
||||
# Model was previously copied to meta device
|
||||
LOG.info("Broadcasting full state dict to all ranks...")
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
meta_sharded_sd = model.state_dict()
|
||||
sharded_sd = {}
|
||||
|
||||
# Rank 0 distributes the full state dict to other ranks
|
||||
def _infer_parameter_dtype(model, param_name, empty_param):
|
||||
try:
|
||||
old_param = model.get_parameter_or_buffer(param_name)
|
||||
except AttributeError:
|
||||
# Need this for LORA, as there some params are not *parameters* of sorts
|
||||
base_param_name, local_param_name = param_name.rsplit(".", 1)
|
||||
submodule = model.get_submodule(base_param_name)
|
||||
old_param = getattr(submodule, local_param_name)
|
||||
|
||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||
casting_dtype = None
|
||||
is_param_float8_e4m3fn = (
|
||||
is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
||||
casting_dtype = old_param.dtype
|
||||
|
||||
return old_param is not None and old_param.is_contiguous(), casting_dtype
|
||||
|
||||
def _cast_and_contiguous(tensor, to_contiguous, dtype):
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
if to_contiguous:
|
||||
tensor = tensor.contiguous()
|
||||
return tensor
|
||||
|
||||
param_names = sorted(meta_sharded_sd.keys())
|
||||
|
||||
for param_name in param_names:
|
||||
mesh = meta_sharded_sd[param_name].device_mesh
|
||||
if accelerator.is_main_process:
|
||||
full_param = full_sd[param_name].detach().cuda()
|
||||
dist.broadcast(full_param, src=0, group=mesh.get_group())
|
||||
sharded_tensor = distribute_tensor(
|
||||
full_param, mesh, sharded_sd[param_name].placements
|
||||
)
|
||||
to_contiguous, casting_dtype = _infer_parameter_dtype(
|
||||
model,
|
||||
param_name,
|
||||
full_param,
|
||||
)
|
||||
sharded_tensor = _cast_and_contiguous(
|
||||
sharded_tensor, to_contiguous, casting_dtype
|
||||
)
|
||||
sharded_sd[param_name] = sharded_tensor
|
||||
else:
|
||||
full_tensor = torch.empty(
|
||||
sharded_sd[param_name].size(),
|
||||
device="cuda",
|
||||
dtype=sharded_sd[param_name].dtype,
|
||||
)
|
||||
dist.broadcast(full_tensor, src=0, group=mesh.get_group())
|
||||
sharded_tensor = distribute_tensor(
|
||||
full_tensor, mesh, sharded_sd[param_name].placements
|
||||
)
|
||||
to_contiguous, casting_dtype = _infer_parameter_dtype(
|
||||
model,
|
||||
param_name,
|
||||
for param_name, full_tensor in full_sd.items():
|
||||
sharded_meta_param = meta_sharded_sd.get(param_name)
|
||||
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
|
||||
if hasattr(sharded_meta_param, "device_mesh"):
|
||||
sharded_param = distribute_tensor(
|
||||
full_tensor,
|
||||
sharded_meta_param.device_mesh,
|
||||
sharded_meta_param.placements,
|
||||
src_data_rank=0,
|
||||
)
|
||||
sharded_tensor = _cast_and_contiguous(
|
||||
sharded_tensor, to_contiguous, casting_dtype
|
||||
)
|
||||
sharded_sd[param_name] = sharded_tensor
|
||||
else:
|
||||
sharded_param = full_tensor
|
||||
|
||||
# we set `assign=True` because our params are on meta device
|
||||
model.load_state_dict(sharded_sd, assign=True)
|
||||
if offload_to_cpu:
|
||||
sharded_param = sharded_param.cpu()
|
||||
|
||||
sharded_sd[param_name] = nn.Parameter(sharded_param)
|
||||
del full_tensor
|
||||
full_sd[param_name] = None
|
||||
model.load_state_dict(sharded_sd, assign=True, strict=True)
|
||||
end_time = time.time()
|
||||
LOG.debug(
|
||||
f"Time taken to load full state dict: {(end_time - start_time):.2f} seconds"
|
||||
)
|
||||
log_gpu_memory_usage(LOG, "Memory usage after broadcasting full state dict", 0)
|
||||
return model
|
||||
|
||||
|
||||
@@ -191,17 +154,195 @@ def get_state_dict(self, model, unwrap=True):
|
||||
return state_dict
|
||||
|
||||
|
||||
def patch_accelerate_fsdp2():
|
||||
import accelerate
|
||||
from accelerate.utils import fsdp_utils
|
||||
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||
"""Helper function to process LoRA modules for FSDP2."""
|
||||
from torch.distributed.fsdp import fully_shard
|
||||
|
||||
fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
|
||||
setattr(
|
||||
sys.modules["accelerate.utils.fsdp_utils"],
|
||||
"fsdp2_load_full_state_dict",
|
||||
fsdp2_load_full_state_dict,
|
||||
log_bias_dtype_mismatch = False
|
||||
|
||||
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
|
||||
if module.base_layer.bias is not None:
|
||||
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
|
||||
log_bias_dtype_mismatch = True
|
||||
module.base_layer.bias.data = module.base_layer.bias.data.to(
|
||||
module.base_layer.weight.dtype
|
||||
)
|
||||
|
||||
for active_adapter in module.active_adapters:
|
||||
if module.lora_A:
|
||||
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_B:
|
||||
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_embedding_A:
|
||||
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_embedding_B:
|
||||
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
|
||||
if module.lora_magnitude_vector:
|
||||
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
|
||||
return log_bias_dtype_mismatch
|
||||
|
||||
|
||||
def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
"""Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
|
||||
|
||||
Args:
|
||||
accelerator (`Accelerator`): The accelerator instance
|
||||
model (`torch.nn.Module`): The model to prepare
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: Prepared model
|
||||
"""
|
||||
from accelerate.utils import get_module_children_bottom_up, is_compiled_module
|
||||
from accelerate.utils.fsdp_utils import fsdp2_prepare_auto_wrap_policy
|
||||
from accelerate.utils.modeling import get_non_persistent_buffers
|
||||
from peft import PeftModel
|
||||
from peft.tuners.lora import LoraLayer
|
||||
from torch.distributed.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
FSDPModule,
|
||||
MixedPrecisionPolicy,
|
||||
fully_shard,
|
||||
)
|
||||
|
||||
is_type_fsdp = isinstance(model, FSDPModule) or (
|
||||
is_compiled_module(model)
|
||||
and isinstance(model._orig_mod, FSDPModule) # pylint: disable=protected-access
|
||||
)
|
||||
if is_type_fsdp:
|
||||
return model
|
||||
|
||||
fsdp2_plugin = accelerator.state.fsdp_plugin
|
||||
|
||||
original_sd = model.state_dict()
|
||||
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
size_based_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
|
||||
# We need the `auto_wrap_policy` original type to create a custom poilicy function for sharding
|
||||
# This is because `fully_shard` doesn't support old auto wrap policies, rather we have to imitate the behaviour
|
||||
if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy:
|
||||
pass # auto_wrap_policy_type = "transformer"
|
||||
elif fsdp2_plugin.auto_wrap_policy is size_based_auto_wrap_policy:
|
||||
pass # auto_wrap_policy_type = "size"
|
||||
|
||||
# We set `auto_wrap_policy` to `functools.partial` to avoid creating it again
|
||||
# This is because of `apply_activation_checkpointing` which will can reuse this function
|
||||
fsdp2_plugin.set_auto_wrap_policy(model)
|
||||
|
||||
if fsdp2_plugin.activation_checkpointing:
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
CheckpointImpl,
|
||||
apply_activation_checkpointing,
|
||||
checkpoint_wrapper,
|
||||
)
|
||||
|
||||
# Apply activation checkpointing before applying `fully_shard`
|
||||
apply_activation_checkpointing(
|
||||
model,
|
||||
checkpoint_wrapper_fn=functools.partial(
|
||||
checkpoint_wrapper,
|
||||
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
|
||||
),
|
||||
auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
|
||||
)
|
||||
|
||||
fsdp2_kwargs = {
|
||||
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
|
||||
"offload_policy": fsdp2_plugin.cpu_offload,
|
||||
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
||||
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
||||
}
|
||||
|
||||
model_has_params4bit = False
|
||||
for _, param in model.named_parameters():
|
||||
# this is a temporary fix whereby loading models with bnb params cannot be moved from
|
||||
# GPU to a meta device due with FSDP2 because torch operations don't return the original class type
|
||||
# bypassing the move to meta will still cause the VRAM spike, but at least it still will load
|
||||
if param.__class__.__name__ == "Params4bit":
|
||||
model_has_params4bit = True
|
||||
break
|
||||
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
|
||||
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
|
||||
# For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
|
||||
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.emtpy`), `fully_shard` would move it to GPU
|
||||
# Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
|
||||
|
||||
# We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
|
||||
# Also, these buffers aren't getting sharded by default
|
||||
# We get the FQNs of all non-persistent buffers, to re-register them after
|
||||
non_persistent_buffer_fqns = get_non_persistent_buffers(
|
||||
model, recurse=True, fqns=True
|
||||
)
|
||||
original_non_persistent_buffers = copy.deepcopy(
|
||||
{k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
|
||||
)
|
||||
# We move the model to meta device, as then sharding happens on meta device
|
||||
model = model.to(torch.device("meta"))
|
||||
# We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage
|
||||
# We assume `transformers` models have a `tie_weights` method if they support it
|
||||
if hasattr(model, "tie_weights"):
|
||||
model.tie_weights()
|
||||
|
||||
is_peft_model = isinstance(model, PeftModel)
|
||||
|
||||
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
||||
log_bias_dtype_mismatch = False
|
||||
if auto_wrap_policy is not None:
|
||||
for module in get_module_children_bottom_up(model)[:-1]:
|
||||
if is_peft_model and isinstance(module, LoraLayer):
|
||||
module_log_bias_mismatch = _process_lora_module_for_fsdp(
|
||||
module, fsdp2_kwargs
|
||||
)
|
||||
log_bias_dtype_mismatch |= module_log_bias_mismatch
|
||||
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
|
||||
fully_shard(module, **fsdp2_kwargs)
|
||||
|
||||
fully_shard(model, **fsdp2_kwargs)
|
||||
|
||||
if log_bias_dtype_mismatch:
|
||||
LOG.warning(
|
||||
"Bias dtype mismatch detected in LoRA base linear layer. Bias parameters have been cast to weight dtype."
|
||||
)
|
||||
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading:
|
||||
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
|
||||
fsdp2_load_full_state_dict(
|
||||
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
|
||||
)
|
||||
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
|
||||
# We re-register the buffers, as they may not be in the state_dict
|
||||
for fqn, buffer_tensor in original_non_persistent_buffers.items():
|
||||
buffer_tensor = buffer_tensor.to(accelerator.device)
|
||||
|
||||
if "." in fqn:
|
||||
parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
|
||||
parent_module = model.get_submodule(parent_fqn)
|
||||
else:
|
||||
local_buffer_name = fqn
|
||||
parent_module = model
|
||||
|
||||
parent_module.register_buffer(
|
||||
local_buffer_name, buffer_tensor, persistent=False
|
||||
)
|
||||
|
||||
# We need to tie the weights again, as call to `load_full_state_dict` breaks the tie
|
||||
# Needs to be called both here and above
|
||||
# removing this call makes the have slightly different loss
|
||||
# removing the call above leads to extra memory usage as explained in the comment above
|
||||
if hasattr(model, "tie_weights"):
|
||||
model.tie_weights()
|
||||
return model
|
||||
|
||||
|
||||
def patch_accelerate_fsdp2():
|
||||
import accelerate
|
||||
|
||||
accelerate.accelerator.fsdp2_prepare_model = fsdp2_prepare_model
|
||||
accelerate.Accelerator.get_state_dict = get_state_dict
|
||||
setattr(
|
||||
sys.modules["accelerate"],
|
||||
|
||||
@@ -6,6 +6,10 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def patch_flex_wrapper(**flex_attn_compile_kwargs):
|
||||
# TODO remove this patch when transformers#37285 is merged and in a release
|
||||
@@ -46,10 +50,15 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
|
||||
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
|
||||
# see https://github.com/pytorch/pytorch/issues/146260 for training
|
||||
self.training = training
|
||||
LOG.info(
|
||||
"Compiling flex attention with kwargs: %s. This may take a while...",
|
||||
flex_attn_compile_kwargs,
|
||||
)
|
||||
self._compiled_flex_attention = torch.compile(
|
||||
flex_attention,
|
||||
**flex_attn_compile_kwargs,
|
||||
)
|
||||
LOG.info("Flex attention compiled successfully.")
|
||||
self._is_flex_compiled = True
|
||||
|
||||
def __call__(self):
|
||||
|
||||
13
src/axolotl/monkeypatch/trainer/trl.py
Normal file
13
src/axolotl/monkeypatch/trainer/trl.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Monkeypatch for TRL trainer FSDP preparation."""
|
||||
|
||||
|
||||
def prepare_fsdp(model, accelerator):
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import fsdp2_prepare_model
|
||||
|
||||
return fsdp2_prepare_model(accelerator, model)
|
||||
|
||||
|
||||
def patch_trl_prepare_fsdp2():
|
||||
import trl.models.utils
|
||||
|
||||
trl.models.utils.prepare_fsdp = prepare_fsdp
|
||||
@@ -15,7 +15,6 @@ from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.utils import save_fsdp_model
|
||||
from datasets import Dataset
|
||||
from huggingface_hub.errors import OfflineModeIsEnabled
|
||||
from peft import PeftConfig, PeftModel
|
||||
@@ -68,7 +67,7 @@ def setup_model_and_tokenizer(
|
||||
`None`), and processor (if multimodal, else `None`).
|
||||
"""
|
||||
# Load tokenizer
|
||||
LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
LOG.debug(f"Loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
# Load processor for multimodal models if needed
|
||||
@@ -76,11 +75,8 @@ def setup_model_and_tokenizer(
|
||||
if cfg.is_multimodal:
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
# Load the model and peft_config
|
||||
msg = "loading model"
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
# Load the model
|
||||
LOG.debug("Loading model")
|
||||
|
||||
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
||||
model, peft_config = model_loader.load()
|
||||
@@ -264,15 +260,6 @@ def save_trained_model(
|
||||
"QAT modules have been converted for PTQ. Please ensure you quantize "
|
||||
"your model weights with `axolotl quantize`."
|
||||
)
|
||||
|
||||
# Handle FSDP state dict type
|
||||
state_dict_type = "FULL_STATE_DICT"
|
||||
if trainer.is_fsdp_enabled and str(cfg.fsdp_config.fsdp_version) != "2":
|
||||
if cfg.fsdp_final_state_dict_type:
|
||||
state_dict_type = cfg.fsdp_final_state_dict_type
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
|
||||
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
|
||||
|
||||
# Handle ReLoRA early return case
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
||||
@@ -281,22 +268,19 @@ def save_trained_model(
|
||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||
return
|
||||
|
||||
if cfg.fsdp:
|
||||
# TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple
|
||||
# processes attempt to write the same file
|
||||
if (
|
||||
state_dict_type == "SHARDED_STATE_DICT"
|
||||
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
|
||||
):
|
||||
save_fsdp_model(
|
||||
trainer.accelerator.state.fsdp_plugin,
|
||||
trainer.accelerator,
|
||||
trainer.model,
|
||||
cfg.output_dir,
|
||||
if trainer.is_fsdp_enabled:
|
||||
if cfg.fsdp_config or cfg.fsdp:
|
||||
if cfg.fsdp_config.final_state_dict_type:
|
||||
state_dict_type = cfg.fsdp_config.final_state_dict_type
|
||||
else:
|
||||
state_dict_type = cfg.fsdp_config.state_dict_type
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
|
||||
trainer.save_model(cfg.output_dir)
|
||||
if state_dict_type == "SHARDED_STATE_DICT":
|
||||
LOG.info(
|
||||
"The final model was saved with a sharded state dict. Please ensure you merge "
|
||||
"the sharded weights with `merge-sharded-fsdp-weights`."
|
||||
)
|
||||
elif state_dict_type == "FULL_STATE_DICT":
|
||||
trainer.save_model(cfg.output_dir)
|
||||
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
||||
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Benchmarking and measurement utilities"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
@@ -91,21 +92,27 @@ def gpu_memory_usage_smi(device=0):
|
||||
return 0.0
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
cur_device = get_device_type()
|
||||
def log_gpu_memory_usage(
|
||||
log: logging.Logger | logging.LoggerAdapter,
|
||||
msg: str = "",
|
||||
device: int | torch.device = 0,
|
||||
):
|
||||
cur_device_type = str(get_device_type())
|
||||
if torch.backends.mps.is_available():
|
||||
usage, cache, misc = mps_memory_usage_all()
|
||||
elif "npu" in str(cur_device) and is_torch_npu_available():
|
||||
elif "npu" in cur_device_type and is_torch_npu_available():
|
||||
usage, cache, misc = npu_memory_usage_all(device)
|
||||
else:
|
||||
elif "gpu" in cur_device_type and torch.cuda.is_available():
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
else:
|
||||
return
|
||||
extras = []
|
||||
if cache > 0:
|
||||
extras.append(f"+{cache:.03f}GB cache")
|
||||
if misc > 0:
|
||||
extras.append(f"+{misc:.03f}GB misc")
|
||||
msg = f"{cur_device_type} memory usage:" if not msg else msg
|
||||
log.info(
|
||||
f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})",
|
||||
f"{msg} {usage:.03f}GB ({', '.join(extras)})",
|
||||
stacklevel=2,
|
||||
)
|
||||
return usage, cache, misc
|
||||
|
||||
@@ -115,10 +115,10 @@ def normalize_config(cfg):
|
||||
"chrf",
|
||||
]
|
||||
choose_device(cfg)
|
||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||
if cfg.ddp:
|
||||
if cfg.world_size != 1:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
|
||||
if not cfg.use_ray:
|
||||
# delay resolving dtype until on worker node when launching with ray
|
||||
@@ -313,3 +313,16 @@ def prepare_plugins(cfg):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
for plugin_name in cfg["plugins"]:
|
||||
plugin_manager.register(plugin_name)
|
||||
|
||||
|
||||
# TODO @SalmanMohammadi remove this function in 0.12
|
||||
def migrate_fsdp_config(cfg):
|
||||
if cfg.get("fsdp_config"):
|
||||
fsdp_config_keys = cfg.fsdp_config.keys()
|
||||
if "fsdp_version" in fsdp_config_keys:
|
||||
cfg.fsdp_version = cfg.fsdp_config.pop("fsdp_version")
|
||||
|
||||
for key in list(fsdp_config_keys):
|
||||
if key.startswith("fsdp_") and key != "fsdp_version":
|
||||
cfg.fsdp_config[key.replace("fsdp_", "")] = cfg.fsdp_config[key]
|
||||
del cfg.fsdp_config[key]
|
||||
|
||||
@@ -203,7 +203,9 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
dataset_processes: int | None = Field(
|
||||
default=min(int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()), # type: ignore[type-var]
|
||||
default=min(
|
||||
int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()
|
||||
), # type: ignore[type-var]
|
||||
json_schema_extra={
|
||||
"description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set."
|
||||
},
|
||||
@@ -572,14 +574,24 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
fsdp: list[str] | None = Field(
|
||||
default=None, json_schema_extra={"description": "FSDP configuration"}
|
||||
default=None,
|
||||
json_schema_extra={"description": "FSDP configuration"},
|
||||
deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ",
|
||||
)
|
||||
# TODO @SalmanMohammadi strongly type this as its own schema
|
||||
fsdp_config: dict[str, Any] | None = Field(
|
||||
default=None, json_schema_extra={"description": "FSDP configuration options"}
|
||||
)
|
||||
fsdp_version: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "FSDP version"},
|
||||
)
|
||||
fsdp_final_state_dict_type: (
|
||||
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
|
||||
) = None
|
||||
) = Field(
|
||||
default=None,
|
||||
deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.",
|
||||
)
|
||||
|
||||
val_set_size: float | None = Field(
|
||||
default=0.0,
|
||||
@@ -949,11 +961,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
or data.get("lora_o_kernel")
|
||||
):
|
||||
capabilities = data.get("capabilities")
|
||||
is_fsdp = data.get("fsdp") is not None
|
||||
is_fsdp2 = (
|
||||
data.get("fsdp_config") is not None
|
||||
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
|
||||
)
|
||||
is_fsdp = data.get("fsdp_config") is not None
|
||||
is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
|
||||
|
||||
if capabilities and capabilities.get("n_gpu", 0) > 1 and not is_fsdp2:
|
||||
if is_fsdp:
|
||||
raise ValueError(
|
||||
@@ -987,11 +997,8 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
# Check multi-GPU compatibility
|
||||
capabilities = data.get("capabilities")
|
||||
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
||||
is_fsdp = data.get("fsdp") is not None
|
||||
is_fsdp2 = (
|
||||
data.get("fsdp_config") is not None
|
||||
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
|
||||
)
|
||||
is_fsdp = data.get("fsdp_config") is not None
|
||||
is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
|
||||
|
||||
if (
|
||||
not is_multi_gpu
|
||||
@@ -1114,21 +1121,94 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
|
||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
|
||||
if (
|
||||
data.get("fsdp")
|
||||
and data.get("fsdp_config")
|
||||
and str(data["fsdp_config"].get("fsdp_version")) == "2"
|
||||
):
|
||||
if version.parse(torch_version) < version.parse("2.7.0"):
|
||||
raise ValueError(
|
||||
"FSDP2 and QAT are not supported on torch version < 2.7.0"
|
||||
)
|
||||
|
||||
if version.parse(torch_version) < version.parse("2.6.0"):
|
||||
raise ValueError("QAT is not supported on torch version < 2.6.0")
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_torch_version(cls, data):
|
||||
env_capabilities = data.get("env_capabilities", {})
|
||||
torch_version = env_capabilities.get("torch_version")
|
||||
|
||||
if torch_version is None:
|
||||
import torch
|
||||
|
||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
|
||||
if data.get("fsdp_config") and str(data.get("fsdp_version")) == "2":
|
||||
if version.parse(torch_version) < version.parse("2.7.0"):
|
||||
raise ValueError("FSDP2 is not supported on torch version < 2.7.0")
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_version(cls, data):
|
||||
fsdp_config = data.get("fsdp_config", {})
|
||||
if fsdp_config and str(data.get("fsdp_version")) != "2":
|
||||
LOG.info(
|
||||
"FSDP1 will be deprecated in an upcoming release of Axolotl."
|
||||
"We recommend that you use FSDP version 2 for better performance and compatibility. "
|
||||
"Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
|
||||
"For more details on migrating your config. "
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp2_base_model_quant_ram_efficient_loading(cls, data):
|
||||
fsdp_config = data.get("fsdp_config")
|
||||
if fsdp_config and data.get("fsdp_version") == 2:
|
||||
if fsdp_config.get("cpu_ram_efficient_loading") and (
|
||||
data.get("load_in_8bit") or data.get("load_in_4bit")
|
||||
):
|
||||
raise ValueError(
|
||||
"FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
|
||||
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp2_base_model_quant_dpo(cls, data):
|
||||
if data.get("fsdp_version") == 2 and data.get("rl") in [
|
||||
RLType.DPO,
|
||||
RLType.KTO,
|
||||
RLType.ORPO,
|
||||
RLType.IPO,
|
||||
]:
|
||||
if data.get("load_in_8bit") or data.get("load_in_4bit"):
|
||||
raise ValueError(
|
||||
"FSDP2 does not support load_in_8bit or load_in_4bit with DPO. Please use DeepSpeed or set `fsdp_version` to 1."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_version_in_fsdp_config(cls, data):
|
||||
if fsdp_config := data.get("fsdp_config"):
|
||||
if fsdp_config.get("fsdp_version"):
|
||||
LOG.warning(
|
||||
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
||||
"Please configure `fsdp_version` as a top-level field."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_config_kwargs_prefix(cls, data):
|
||||
if fsdp_config := data.get("fsdp_config"):
|
||||
for key, _ in fsdp_config.items():
|
||||
if key.startswith("fsdp_"):
|
||||
LOG.warning_once(
|
||||
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
|
||||
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def default_dataloader_opts(cls, data):
|
||||
|
||||
@@ -574,15 +574,6 @@ class LoRAValidationMixin:
|
||||
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def hint_lora_8bit(self):
|
||||
loftq = (
|
||||
self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
|
||||
)
|
||||
if not self.load_in_8bit and self.adapter == "lora" and not loftq:
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def warn_qlora_zero3_w_use_reentrant(cls, data):
|
||||
@@ -786,7 +777,7 @@ class OptimizationValidationMixin:
|
||||
@classmethod
|
||||
def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
|
||||
if (
|
||||
data.get("fsdp")
|
||||
data.get("fsdp_config")
|
||||
and data.get("save_safetensors")
|
||||
and data.get("fsdp_config")
|
||||
and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
|
||||
@@ -1000,7 +991,7 @@ class ComplexValidationMixin:
|
||||
if self.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
|
||||
if self.fsdp:
|
||||
if self.fsdp or self.fsdp_config:
|
||||
raise ValueError("fsdp not supported with ReLoRA")
|
||||
|
||||
if self.deepspeed:
|
||||
|
||||
@@ -563,37 +563,39 @@ def setup_deepspeed_env(cfg, stage=None):
|
||||
|
||||
def setup_fsdp_envs(cfg):
|
||||
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||
if str(cfg.fsdp_config.fsdp_version) == "2":
|
||||
|
||||
# TODO @SalmanMohammadi remove FSDP1 args in 0.12
|
||||
if str(cfg.fsdp_version) == "2":
|
||||
os.environ["FSDP_VERSION"] = "2"
|
||||
if cfg.fsdp_config.fsdp_activation_checkpointing:
|
||||
if cfg.fsdp_config.activation_checkpointing:
|
||||
os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
|
||||
if cfg.fsdp_config.fsdp_offload_params:
|
||||
if cfg.fsdp_config.offload_params:
|
||||
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
|
||||
if cfg.fsdp_config.fsdp_sync_module_states:
|
||||
if cfg.fsdp_config.sync_module_states:
|
||||
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
|
||||
if cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||
if cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
|
||||
if cfg.fsdp_config.fsdp_use_orig_params:
|
||||
if cfg.fsdp_config.use_orig_params:
|
||||
os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
|
||||
if cfg.fsdp_config.fsdp_state_dict_type:
|
||||
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
|
||||
if cfg.fsdp_config.fsdp_auto_wrap_policy:
|
||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.fsdp_auto_wrap_policy
|
||||
if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
|
||||
if cfg.fsdp_config.state_dict_type:
|
||||
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.state_dict_type
|
||||
if cfg.fsdp_config.auto_wrap_policy:
|
||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.auto_wrap_policy
|
||||
if cfg.fsdp_config.transformer_layer_cls_to_wrap:
|
||||
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = (
|
||||
cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||
)
|
||||
if cfg.fsdp_config.fsdp_reshard_after_forward is not None:
|
||||
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = (
|
||||
"true" if cfg.fsdp_config.fsdp_reshard_after_forward else "false"
|
||||
cfg.fsdp_config.transformer_layer_cls_to_wrap
|
||||
)
|
||||
if cfg.fsdp_config.reshard_after_forward:
|
||||
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"
|
||||
|
||||
|
||||
def prepare_optim_env(cfg):
|
||||
if not check_cuda_p2p_ib_support():
|
||||
if os.getenv("NCCL_P2P_DISABLE") is None:
|
||||
os.environ["NCCL_P2P_DISABLE"] = "1"
|
||||
if cfg.fsdp:
|
||||
# TODO @SalmanMohammadi remove the cfg.fsdp check in 0.12
|
||||
if cfg.fsdp or cfg.fsdp_config:
|
||||
cfg.fsdp = True if not cfg.fsdp else cfg.fsdp
|
||||
setup_fsdp_envs(cfg)
|
||||
elif cfg.deepspeed:
|
||||
stage = None
|
||||
@@ -657,11 +659,7 @@ def setup_trainer(
|
||||
"""
|
||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
|
||||
if (
|
||||
cfg.torch_compile
|
||||
and cfg.fsdp_config
|
||||
and str(cfg.fsdp_config.fsdp_version) == "2"
|
||||
):
|
||||
if cfg.torch_compile and cfg.fsdp_config and cfg.fsdp_version == 2:
|
||||
patch_evaluation_loop_for_fsdp2()
|
||||
if cfg.rl:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
|
||||
|
||||
Reference in New Issue
Block a user