FSDP1 -> FSDP2 (#2760)

* FSDP2 args migration implementation

This commit implements the migration to FSDP2 arguments including:
- FSDP2 support with LoRA training
- DPO integration with FSDP2
- Model loading fixes and refactoring
- CPU offloading and PEFT handling
- Test updates and CI improvements
- Bug fixes for dtype errors and various edge cases
This commit is contained in:
salman
2025-07-12 15:18:01 +01:00
committed by GitHub
parent eb662557a7
commit d6e4a611e5
27 changed files with 1357 additions and 436 deletions

View File

@@ -23,8 +23,6 @@ Axolotl supports several methods for multi-GPU training:
## DeepSpeed {#sec-deepspeed} ## DeepSpeed {#sec-deepspeed}
DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages.
### Configuration {#sec-deepspeed-config} ### Configuration {#sec-deepspeed-config}
Add to your YAML config: Add to your YAML config:
@@ -32,7 +30,6 @@ Add to your YAML config:
```{.yaml} ```{.yaml}
deepspeed: deepspeed_configs/zero1.json deepspeed: deepspeed_configs/zero1.json
``` ```
### Usage {#sec-deepspeed-usage} ### Usage {#sec-deepspeed-usage}
```{.bash} ```{.bash}
@@ -75,9 +72,66 @@ ZeRO Stage 3 can be used for training on a single GPU by manually setting the en
::: :::
## FSDP {#sec-fsdp} ## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
### Basic FSDP Configuration {#sec-fsdp-config} ::: {.callout-note}
FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl.
:::
### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}
To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and
also follow the config field mapping below to update field names.
#### Config mapping
FSDP1 | FSDP2
-------- | --------
fsdp_sharding_strategy | reshard_after_forward
fsdp_backward_prefetch_policy | **REMOVED**
fsdp_backward_prefetch | **REMOVED**
fsdp_forward_prefetch | **REMOVED**
fsdp_sync_module_states | **REMOVED**
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type | state_dict_type
fsdp_use_orig_params | **REMOVED**
For example, if you were using the following FSDP1 config:
```{.yaml}
fsdp_version: 1
fsdp_config:
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
```
You can migrate to the following FSDP2 config:
```{.yaml}
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
state_dict_type: FULL_STATE_DICT
reshard_after_forward: true
```
### FSDP1 (deprecated) {#sec-fsdp-config}
::: {.callout-note}
Using `fsdp` to configure FSDP is deprecated and will be removed in an upcoming release of Axolotl. Please use `fsdp_config` as above instead.
:::
```{.yaml} ```{.yaml}
fsdp: fsdp:
@@ -89,6 +143,7 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
``` ```
## Sequence parallelism {#sec-sequence-parallelism} ## Sequence parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the We support sequence parallelism (SP) via the

View File

@@ -40,13 +40,13 @@ use_cpu: false
Configure your model to use FSDP in the Axolotl yaml. For example: Configure your model to use FSDP in the Axolotl yaml. For example:
```yaml ```yaml
fsdp: fsdp_version: 2
- full_shard
- auto_wrap
fsdp_config: fsdp_config:
fsdp_offload_params: true offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
``` ```
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine. All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.

View File

@@ -17,7 +17,6 @@ feedback. Various methods include, but not limited to:
- [Kahneman-Tversky Optimization (KTO)](#kto) - [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo) - [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo) - [Group Relative Policy Optimization (GRPO)](#grpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl, if you're interested in contributing, please reach out!)
## RLHF using Axolotl ## RLHF using Axolotl

View File

@@ -16,6 +16,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import ( from axolotl.utils.config import (
migrate_fsdp_config,
normalize_cfg_datasets, normalize_cfg_datasets,
normalize_config, normalize_config,
validate_config, validate_config,
@@ -226,6 +227,7 @@ def load_cfg(
}, },
) )
migrate_fsdp_config(cfg)
prepare_optim_env(cfg) prepare_optim_env(cfg)
prepare_opinionated_env(cfg) prepare_opinionated_env(cfg)
normalize_config(cfg) normalize_config(cfg)

View File

@@ -501,6 +501,10 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.reward_model or self.cfg.rl: if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.fsdp_config or self.cfg.fsdp:
training_args_kwargs["fsdp_config"] = self.cfg.fsdp_config
training_args_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp else True
self._configure_reporting(training_args_kwargs) self._configure_reporting(training_args_kwargs)
self._configure_hub_parameters(training_args_kwargs) self._configure_hub_parameters(training_args_kwargs)
self._configure_scheduler(training_args_kwargs) self._configure_scheduler(training_args_kwargs)

View File

@@ -151,14 +151,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args( training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps total_num_steps
) )
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
if self.cfg.adapter == "qlora": if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True training_arguments_kwargs["qlora"] = True

View File

@@ -208,7 +208,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
callbacks=self.get_callbacks(), callbacks=self.get_callbacks(),
**trainer_kwargs, **trainer_kwargs,
) )
if self.cfg.fsdp: if self.cfg.fsdp_config or self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype) ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model: if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype) ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
@@ -218,21 +218,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer.add_callback(callback) trainer.add_callback(callback)
return trainer return trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):
# TODO: build PPOConfig
raise NotImplementedError("PPO trainer builder is not implemented yet.")

View File

@@ -14,5 +14,4 @@ from .trl import (
AxolotlORPOTrainer, AxolotlORPOTrainer,
AxolotlPRMTrainer, AxolotlPRMTrainer,
AxolotlRewardTrainer, AxolotlRewardTrainer,
TRLPPOTrainer,
) )

View File

@@ -1,12 +1,9 @@
"""Module for TRL PPO trainer""" """Module for TRL RL trainers"""
import torch
from tqdm import tqdm
from trl import ( from trl import (
CPOTrainer, CPOTrainer,
KTOTrainer, KTOTrainer,
ORPOTrainer, ORPOTrainer,
PPOTrainer,
PRMTrainer, PRMTrainer,
RewardTrainer, RewardTrainer,
) )
@@ -16,64 +13,6 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, Optimizer
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class TRLPPOTrainer(PPOTrainer):
"""Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"]
def train(
self,
reward_pipe,
resume_from_checkpoint=None, # pylint: disable=unused-argument
):
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": 32,
}
sent_kwargs = {
"return_all_scores": True,
"function_to_apply": "none",
"batch_size": 16,
}
for _, batch in tqdm(enumerate(self.dataloader)):
query_tensors = batch["input_ids"]
# generate model response
response_tensors, ref_response_tensors = self.generate(
query_tensors,
return_prompt=False,
generate_ref_response=True,
**generation_kwargs,
)
batch["response"] = self.tokenizer.batch_decode(response_tensors)
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
# Compute sentiment score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = reward_pipe(texts, **sent_kwargs)
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
ref_rewards = [
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
]
batch["ref_rewards"] = ref_rewards
# Run PPO step
stats = self.step(query_tensors, response_tensors, rewards)
self.log_stats(
stats,
batch,
rewards,
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
)
class AxolotlORPOTrainer( class AxolotlORPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
): ):

View File

@@ -122,9 +122,9 @@ def load_lora(
rank = int(os.environ.get("LOCAL_RANK", 0)) rank = int(os.environ.get("LOCAL_RANK", 0))
if ( if (
cfg.fsdp cfg.fsdp_config
and cfg.adapter and cfg.adapter
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading and cfg.fsdp_config.cpu_ram_efficient_loading
and rank != 0 and rank != 0
): ):
setup_quantized_meta_for_peft(model) setup_quantized_meta_for_peft(model)
@@ -152,9 +152,9 @@ def load_lora(
"Exception caught during model.print_trainable_parameters(): %s", exc "Exception caught during model.print_trainable_parameters(): %s", exc
) )
elif ( elif (
cfg.fsdp cfg.fsdp_config
and cfg.adapter and cfg.adapter
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading and cfg.fsdp_config.cpu_ram_efficient_loading
and rank != 0 and rank != 0
): ):
setup_quantized_peft_meta_for_training(model) setup_quantized_peft_meta_for_training(model)

View File

@@ -140,10 +140,15 @@ class ModelLoader:
"""Check if flash attention is installed.""" """Check if flash attention is installed."""
return find_spec("flash_attn") is not None return find_spec("flash_attn") is not None
@cached_property @property
def qlora_fsdp(self): def is_fsdp_enabled(self):
"""Property that determines if FSDP is enabled."""
return self.cfg.fsdp_config is not None or self.cfg.fsdp is not None
@property
def is_qlora_and_fsdp_enabled(self):
"""Property that determines if FSDP with QLoRA is enabled.""" """Property that determines if FSDP with QLoRA is enabled."""
return self.cfg.fsdp and self.cfg.adapter == "qlora" return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]: def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches. """Load and prepare the model with all configurations and patches.
@@ -189,15 +194,15 @@ class ModelLoader:
# Handle PeftModel if needed # Handle PeftModel if needed
if ( if (
isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM)) isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM))
and not self.qlora_fsdp and not self.is_qlora_and_fsdp_enabled
): ):
self.model = self.model.merge_and_unload() self.model = self.model.merge_and_unload()
self._resize_token_embeddings() self._resize_token_embeddings()
self._adjust_model_config() self._adjust_model_config()
self._log_memory_usage()
self._configure_embedding_dtypes() self._configure_embedding_dtypes()
self._configure_qat() self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
def _resize_token_embeddings(self): def _resize_token_embeddings(self):
"""Resize token embeddings if needed.""" """Resize token embeddings if needed."""
@@ -251,22 +256,13 @@ class ModelLoader:
): ):
self.model.config.eos_token_id = self.tokenizer.eos_token_id self.model.config.eos_token_id = self.tokenizer.eos_token_id
def _log_memory_usage(self):
"""Log device memory usage after model load."""
if hasattr(self.model, "device") and self.model.device.type in (
"cuda",
"mps",
"npu",
):
log_gpu_memory_usage(LOG, "after model load", self.model.device)
def _configure_embedding_dtypes(self): def _configure_embedding_dtypes(self):
"""Configure embedding module dtypes.""" """Configure embedding module dtypes."""
# Get embedding modules # Get embedding modules
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type) embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
# Initial dtype conversion # Initial dtype conversion
if not self.cfg.fsdp: if not self.is_fsdp_enabled:
# We don't run this during FSDP because this will leave mixed and bfloat16 # We don't run this during FSDP because this will leave mixed and bfloat16
# dtypes in the model which FSDP doesn't like # dtypes in the model which FSDP doesn't like
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast: if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
@@ -282,7 +278,7 @@ class ModelLoader:
self._set_z3_leaf_modules() self._set_z3_leaf_modules()
# Apply gradient checkpointing if needed # Apply gradient checkpointing if needed
needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp needs_fa2_dtype = self.cfg.adapter or self.is_fsdp_enabled
if self.cfg.adapter in ["lora", "qlora"]: if self.cfg.adapter in ["lora", "qlora"]:
needs_fa2_dtype = True needs_fa2_dtype = True
if self.cfg.gradient_checkpointing: if self.cfg.gradient_checkpointing:
@@ -298,10 +294,12 @@ class ModelLoader:
# we need to convert them back to fp16/bf16 for flash-attn compatibility. # we need to convert them back to fp16/bf16 for flash-attn compatibility.
( (
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
and not self.qlora_fsdp and not self.is_qlora_and_fsdp_enabled
)
or (
# CCE requires embedding layers to be in fp16/bf16 for backward pass
self.cfg.cut_cross_entropy
) )
# CCE requires embedding layers to be in fp16/bf16 for backward pass
or self.cfg.cut_cross_entropy
) )
if should_convert: if should_convert:
@@ -357,7 +355,6 @@ class ModelLoader:
and not (self.cfg.rl and self.cfg.load_in_4bit) and not (self.cfg.rl and self.cfg.load_in_4bit)
and not skip_move_to_device and not skip_move_to_device
): ):
# TODO: validate this conditional
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}") self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
@@ -430,7 +427,17 @@ class ModelLoader:
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
if not is_deepspeed_zero3_enabled(): is_ds_zero3 = is_deepspeed_zero3_enabled()
# FSDP requires control over device placement, so don't set device_map when FSDP is enabled
if self.is_fsdp_enabled:
# For QLoRA + FSDP, we still need to set device_map to "auto" for proper initialization
if self.is_qlora_and_fsdp_enabled:
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
# For other FSDP cases, don't set device_map at all
elif not is_ds_zero3:
self.model_kwargs["device_map"] = device_map self.model_kwargs["device_map"] = device_map
cur_device = get_device_type() cur_device = get_device_type()
@@ -499,7 +506,7 @@ class ModelLoader:
"bnb_4bit_quant_storage": torch.bfloat16, "bnb_4bit_quant_storage": torch.bfloat16,
} }
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
self.cfg.deepspeed or self.cfg.fsdp self.cfg.deepspeed or self.is_fsdp_enabled
): ):
# for some reason, this causes the loss to be off by an order of magnitude # for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16 # but deepspeed needs this still in bfloat16
@@ -604,9 +611,21 @@ class ModelLoader:
def _build_model(self) -> bool: def _build_model(self) -> bool:
"""Load model, with load strategy depending on config.""" """Load model, with load strategy depending on config."""
skip_move_to_device = False skip_move_to_device = False
if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True
# Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
if (
"device_map" in self.model_kwargs
and not self.is_qlora_and_fsdp_enabled
):
del self.model_kwargs["device_map"]
elif self.is_qlora_and_fsdp_enabled:
skip_move_to_device = True
if ( if (
self.qlora_fsdp self.is_qlora_and_fsdp_enabled
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading and self.cfg.fsdp_config.cpu_ram_efficient_loading
and ( and (
self.cfg.model_config_type == "dbrx" self.cfg.model_config_type == "dbrx"
or self.cfg.qlora_sharded_model_loading or self.cfg.qlora_sharded_model_loading
@@ -632,12 +651,6 @@ class ModelLoader:
and not self.cfg.trust_remote_code and not self.cfg.trust_remote_code
and not self.cfg.gptq and not self.cfg.gptq
): ):
# TODO: Do we need to open this up for all models?
if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"]
# Please don't remove underscore binding without reading the fn docstring. # Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading() _ = self._configure_zero3_memory_efficient_loading()
@@ -691,33 +704,22 @@ class ModelLoader:
trust_remote_code=self.cfg.trust_remote_code or False, trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs, **self.model_kwargs,
) )
elif self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else: else:
if self.cfg.gptq: # Please don't remove underscore binding without reading the fn docstring.
self.model = self.auto_model_loader.from_pretrained( _ = self._configure_zero3_memory_efficient_loading()
self.base_model, self.model = self.auto_model_loader.from_pretrained(
config=self.model_config, self.base_model,
trust_remote_code=self.cfg.trust_remote_code or False, config=self.model_config,
**self.model_kwargs, trust_remote_code=self.cfg.trust_remote_code or False,
) **self.model_kwargs,
else: )
if (
self.cfg.fsdp
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
):
# disabling either of these two still leads to VRAM spike before setting back down
skip_move_to_device = True
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"]
# Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading()
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
skip_move_to_device = True skip_move_to_device = True
@@ -753,8 +755,8 @@ class ModelLoader:
skip_prepare_model_for_kbit_training = True skip_prepare_model_for_kbit_training = True
if ( if (
self.qlora_fsdp self.is_qlora_and_fsdp_enabled
or (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) or (self.is_fsdp_enabled and self.cfg.fsdp_config.cpu_ram_efficient_loading)
or is_deepspeed_zero3_enabled() or is_deepspeed_zero3_enabled()
): ):
# Make sure everything is in the same dtype # Make sure everything is in the same dtype

View File

@@ -94,10 +94,14 @@ class PatchManager:
def _apply_fsdp_patches(self): def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations.""" """Apply patches for FSDP configurations."""
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2": if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2 from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
patch_accelerate_fsdp2() patch_accelerate_fsdp2()
if self.cfg.rl:
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
patch_trl_prepare_fsdp2()
# if self.cfg.fsdp_config: # if self.cfg.fsdp_config:
# # see transformers#39152 # # see transformers#39152

View File

@@ -195,9 +195,11 @@ def ensure_dtype(model: PreTrainedModel, dtype: torch.dtype = torch.bfloat16):
bias_mismatch = module.bias.dtype != dtype bias_mismatch = module.bias.dtype != dtype
if weight_mismatch: if weight_mismatch:
print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}") LOG.debug(
f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}"
)
if bias_mismatch: if bias_mismatch:
print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}") LOG.debug(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
if weight_mismatch or bias_mismatch: if weight_mismatch or bias_mismatch:
module.to(dtype) module.to(dtype)

View File

@@ -2,102 +2,65 @@
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts
""" """
import copy
import functools
import sys import sys
import torch import torch
from torch import nn
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
LOG = get_logger(__name__) LOG = get_logger(__name__)
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict): def fsdp2_load_full_state_dict(
_accelerator, model: torch.nn.Module, full_sd: dict, offload_to_cpu: bool = False
):
""" """
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
parameters from rank 0 to all other ranks. This function modifies the model in-place. parameters from rank 0 to all other ranks. This function modifies the model in-place.
Args: Args:
accelerator (`Accelerator`): The accelerator instance accelerator (`Accelerator`): The accelerator instance
model (`torch.nn.Module`): model (`torch.nn.Module`):
The model to load the state dict into, expected to be on meta device or a VRAM spike can occur The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
full_sd (`dict`): The full state dict to load, can only be on rank 0 full_sd (`dict`): The full state dict to load, can only be on rank 0
""" """
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor from torch.distributed.tensor import distribute_tensor
# Model was previously copied to meta device LOG.info("Broadcasting full state dict to all ranks...")
import time
start_time = time.time()
meta_sharded_sd = model.state_dict() meta_sharded_sd = model.state_dict()
sharded_sd = {} sharded_sd = {}
for param_name, full_tensor in full_sd.items():
# Rank 0 distributes the full state dict to other ranks sharded_meta_param = meta_sharded_sd.get(param_name)
def _infer_parameter_dtype(model, param_name, empty_param): full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
try: if hasattr(sharded_meta_param, "device_mesh"):
old_param = model.get_parameter_or_buffer(param_name) sharded_param = distribute_tensor(
except AttributeError:
# Need this for LORA, as there some params are not *parameters* of sorts
base_param_name, local_param_name = param_name.rsplit(".", 1)
submodule = model.get_submodule(base_param_name)
old_param = getattr(submodule, local_param_name)
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
casting_dtype = None
is_param_float8_e4m3fn = (
is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
)
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
casting_dtype = old_param.dtype
return old_param is not None and old_param.is_contiguous(), casting_dtype
def _cast_and_contiguous(tensor, to_contiguous, dtype):
if dtype is not None:
tensor = tensor.to(dtype=dtype)
if to_contiguous:
tensor = tensor.contiguous()
return tensor
param_names = sorted(meta_sharded_sd.keys())
for param_name in param_names:
mesh = meta_sharded_sd[param_name].device_mesh
if accelerator.is_main_process:
full_param = full_sd[param_name].detach().cuda()
dist.broadcast(full_param, src=0, group=mesh.get_group())
sharded_tensor = distribute_tensor(
full_param, mesh, sharded_sd[param_name].placements
)
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
full_param,
)
sharded_tensor = _cast_and_contiguous(
sharded_tensor, to_contiguous, casting_dtype
)
sharded_sd[param_name] = sharded_tensor
else:
full_tensor = torch.empty(
sharded_sd[param_name].size(),
device="cuda",
dtype=sharded_sd[param_name].dtype,
)
dist.broadcast(full_tensor, src=0, group=mesh.get_group())
sharded_tensor = distribute_tensor(
full_tensor, mesh, sharded_sd[param_name].placements
)
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
full_tensor, full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
src_data_rank=0,
) )
sharded_tensor = _cast_and_contiguous( else:
sharded_tensor, to_contiguous, casting_dtype sharded_param = full_tensor
)
sharded_sd[param_name] = sharded_tensor
# we set `assign=True` because our params are on meta device if offload_to_cpu:
model.load_state_dict(sharded_sd, assign=True) sharded_param = sharded_param.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_param)
del full_tensor
full_sd[param_name] = None
model.load_state_dict(sharded_sd, assign=True, strict=True)
end_time = time.time()
LOG.debug(
f"Time taken to load full state dict: {(end_time - start_time):.2f} seconds"
)
log_gpu_memory_usage(LOG, "Memory usage after broadcasting full state dict", 0)
return model return model
@@ -191,17 +154,195 @@ def get_state_dict(self, model, unwrap=True):
return state_dict return state_dict
def patch_accelerate_fsdp2(): def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
import accelerate """Helper function to process LoRA modules for FSDP2."""
from accelerate.utils import fsdp_utils from torch.distributed.fsdp import fully_shard
fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict log_bias_dtype_mismatch = False
setattr(
sys.modules["accelerate.utils.fsdp_utils"], # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
"fsdp2_load_full_state_dict", # wrap this. Therefore we must ensure the bias has the same dtype as the weight
fsdp2_load_full_state_dict, if module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to(
module.base_layer.weight.dtype
)
for active_adapter in module.active_adapters:
if module.lora_A:
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
if module.lora_B:
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_A:
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_B:
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
if module.lora_magnitude_vector:
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
return log_bias_dtype_mismatch
def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
"""Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
Args:
accelerator (`Accelerator`): The accelerator instance
model (`torch.nn.Module`): The model to prepare
Returns:
`torch.nn.Module`: Prepared model
"""
from accelerate.utils import get_module_children_bottom_up, is_compiled_module
from accelerate.utils.fsdp_utils import fsdp2_prepare_auto_wrap_policy
from accelerate.utils.modeling import get_non_persistent_buffers
from peft import PeftModel
from peft.tuners.lora import LoraLayer
from torch.distributed.fsdp import (
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
fully_shard,
) )
is_type_fsdp = isinstance(model, FSDPModule) or (
is_compiled_module(model)
and isinstance(model._orig_mod, FSDPModule) # pylint: disable=protected-access
)
if is_type_fsdp:
return model
fsdp2_plugin = accelerator.state.fsdp_plugin
original_sd = model.state_dict()
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
# We need the `auto_wrap_policy` original type to create a custom poilicy function for sharding
# This is because `fully_shard` doesn't support old auto wrap policies, rather we have to imitate the behaviour
if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy:
pass # auto_wrap_policy_type = "transformer"
elif fsdp2_plugin.auto_wrap_policy is size_based_auto_wrap_policy:
pass # auto_wrap_policy_type = "size"
# We set `auto_wrap_policy` to `functools.partial` to avoid creating it again
# This is because of `apply_activation_checkpointing` which will can reuse this function
fsdp2_plugin.set_auto_wrap_policy(model)
if fsdp2_plugin.activation_checkpointing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)
# Apply activation checkpointing before applying `fully_shard`
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
),
auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
)
fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload,
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
}
model_has_params4bit = False
for _, param in model.named_parameters():
# this is a temporary fix whereby loading models with bnb params cannot be moved from
# GPU to a meta device due with FSDP2 because torch operations don't return the original class type
# bypassing the move to meta will still cause the VRAM spike, but at least it still will load
if param.__class__.__name__ == "Params4bit":
model_has_params4bit = True
break
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
# For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.emtpy`), `fully_shard` would move it to GPU
# Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
# We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
# Also, these buffers aren't getting sharded by default
# We get the FQNs of all non-persistent buffers, to re-register them after
non_persistent_buffer_fqns = get_non_persistent_buffers(
model, recurse=True, fqns=True
)
original_non_persistent_buffers = copy.deepcopy(
{k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
)
# We move the model to meta device, as then sharding happens on meta device
model = model.to(torch.device("meta"))
# We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage
# We assume `transformers` models have a `tie_weights` method if they support it
if hasattr(model, "tie_weights"):
model.tie_weights()
is_peft_model = isinstance(model, PeftModel)
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
log_bias_dtype_mismatch = False
if auto_wrap_policy is not None:
for module in get_module_children_bottom_up(model)[:-1]:
if is_peft_model and isinstance(module, LoraLayer):
module_log_bias_mismatch = _process_lora_module_for_fsdp(
module, fsdp2_kwargs
)
log_bias_dtype_mismatch |= module_log_bias_mismatch
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
fully_shard(module, **fsdp2_kwargs)
fully_shard(model, **fsdp2_kwargs)
if log_bias_dtype_mismatch:
LOG.warning(
"Bias dtype mismatch detected in LoRA base linear layer. Bias parameters have been cast to weight dtype."
)
if fsdp2_plugin.cpu_ram_efficient_loading:
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
fsdp2_load_full_state_dict(
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
)
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# We re-register the buffers, as they may not be in the state_dict
for fqn, buffer_tensor in original_non_persistent_buffers.items():
buffer_tensor = buffer_tensor.to(accelerator.device)
if "." in fqn:
parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
parent_module = model.get_submodule(parent_fqn)
else:
local_buffer_name = fqn
parent_module = model
parent_module.register_buffer(
local_buffer_name, buffer_tensor, persistent=False
)
# We need to tie the weights again, as call to `load_full_state_dict` breaks the tie
# Needs to be called both here and above
# removing this call makes the have slightly different loss
# removing the call above leads to extra memory usage as explained in the comment above
if hasattr(model, "tie_weights"):
model.tie_weights()
return model
def patch_accelerate_fsdp2():
import accelerate
accelerate.accelerator.fsdp2_prepare_model = fsdp2_prepare_model
accelerate.Accelerator.get_state_dict = get_state_dict accelerate.Accelerator.get_state_dict = get_state_dict
setattr( setattr(
sys.modules["accelerate"], sys.modules["accelerate"],

View File

@@ -6,6 +6,10 @@ from typing import Optional, Tuple, Union
import torch import torch
import transformers import transformers
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patch_flex_wrapper(**flex_attn_compile_kwargs): def patch_flex_wrapper(**flex_attn_compile_kwargs):
# TODO remove this patch when transformers#37285 is merged and in a release # TODO remove this patch when transformers#37285 is merged and in a release
@@ -46,10 +50,15 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs" # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
# see https://github.com/pytorch/pytorch/issues/146260 for training # see https://github.com/pytorch/pytorch/issues/146260 for training
self.training = training self.training = training
LOG.info(
"Compiling flex attention with kwargs: %s. This may take a while...",
flex_attn_compile_kwargs,
)
self._compiled_flex_attention = torch.compile( self._compiled_flex_attention = torch.compile(
flex_attention, flex_attention,
**flex_attn_compile_kwargs, **flex_attn_compile_kwargs,
) )
LOG.info("Flex attention compiled successfully.")
self._is_flex_compiled = True self._is_flex_compiled = True
def __call__(self): def __call__(self):

View File

@@ -0,0 +1,13 @@
"""Monkeypatch for TRL trainer FSDP preparation."""
def prepare_fsdp(model, accelerator):
from axolotl.monkeypatch.accelerate.fsdp2 import fsdp2_prepare_model
return fsdp2_prepare_model(accelerator, model)
def patch_trl_prepare_fsdp2():
import trl.models.utils
trl.models.utils.prepare_fsdp = prepare_fsdp

View File

@@ -15,7 +15,6 @@ from typing import Any, Dict
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate.utils import save_fsdp_model
from datasets import Dataset from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled from huggingface_hub.errors import OfflineModeIsEnabled
from peft import PeftConfig, PeftModel from peft import PeftConfig, PeftModel
@@ -68,7 +67,7 @@ def setup_model_and_tokenizer(
`None`), and processor (if multimodal, else `None`). `None`), and processor (if multimodal, else `None`).
""" """
# Load tokenizer # Load tokenizer
LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") LOG.debug(f"Loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed # Load processor for multimodal models if needed
@@ -76,11 +75,8 @@ def setup_model_and_tokenizer(
if cfg.is_multimodal: if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer) processor = load_processor(cfg, tokenizer)
# Load the model and peft_config # Load the model
msg = "loading model" LOG.debug("Loading model")
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model_loader = ModelLoader(cfg, tokenizer, processor=processor) model_loader = ModelLoader(cfg, tokenizer, processor=processor)
model, peft_config = model_loader.load() model, peft_config = model_loader.load()
@@ -264,15 +260,6 @@ def save_trained_model(
"QAT modules have been converted for PTQ. Please ensure you quantize " "QAT modules have been converted for PTQ. Please ensure you quantize "
"your model weights with `axolotl quantize`." "your model weights with `axolotl quantize`."
) )
# Handle FSDP state dict type
state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled and str(cfg.fsdp_config.fsdp_version) != "2":
if cfg.fsdp_final_state_dict_type:
state_dict_type = cfg.fsdp_final_state_dict_type
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
# Handle ReLoRA early return case # Handle ReLoRA early return case
if cfg.relora_steps: if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
@@ -281,22 +268,19 @@ def save_trained_model(
# final model weights have already been saved by `ReLoRACallback.on_train_end` # final model weights have already been saved by `ReLoRACallback.on_train_end`
return return
if cfg.fsdp: if trainer.is_fsdp_enabled:
# TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading if cfg.fsdp_config or cfg.fsdp:
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple if cfg.fsdp_config.final_state_dict_type:
# processes attempt to write the same file state_dict_type = cfg.fsdp_config.final_state_dict_type
if ( else:
state_dict_type == "SHARDED_STATE_DICT" state_dict_type = cfg.fsdp_config.state_dict_type
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT" trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
): trainer.save_model(cfg.output_dir)
save_fsdp_model( if state_dict_type == "SHARDED_STATE_DICT":
trainer.accelerator.state.fsdp_plugin, LOG.info(
trainer.accelerator, "The final model was saved with a sharded state dict. Please ensure you merge "
trainer.model, "the sharded weights with `merge-sharded-fsdp-weights`."
cfg.output_dir,
) )
elif state_dict_type == "FULL_STATE_DICT":
trainer.save_model(cfg.output_dir)
elif cfg.deepspeed and is_deepspeed_zero3_enabled(): elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone() trainer.accelerator.wait_for_everyone()

View File

@@ -1,6 +1,7 @@
"""Benchmarking and measurement utilities""" """Benchmarking and measurement utilities"""
import functools import functools
import logging
import torch import torch
from transformers.utils.import_utils import is_torch_npu_available from transformers.utils.import_utils import is_torch_npu_available
@@ -91,21 +92,27 @@ def gpu_memory_usage_smi(device=0):
return 0.0 return 0.0
def log_gpu_memory_usage(log, msg, device): def log_gpu_memory_usage(
cur_device = get_device_type() log: logging.Logger | logging.LoggerAdapter,
msg: str = "",
device: int | torch.device = 0,
):
cur_device_type = str(get_device_type())
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all() usage, cache, misc = mps_memory_usage_all()
elif "npu" in str(cur_device) and is_torch_npu_available(): elif "npu" in cur_device_type and is_torch_npu_available():
usage, cache, misc = npu_memory_usage_all(device) usage, cache, misc = npu_memory_usage_all(device)
else: elif "gpu" in cur_device_type and torch.cuda.is_available():
usage, cache, misc = gpu_memory_usage_all(device) usage, cache, misc = gpu_memory_usage_all(device)
else:
return
extras = [] extras = []
if cache > 0: if cache > 0:
extras.append(f"+{cache:.03f}GB cache") extras.append(f"+{cache:.03f}GB cache")
if misc > 0: if misc > 0:
extras.append(f"+{misc:.03f}GB misc") extras.append(f"+{misc:.03f}GB misc")
msg = f"{cur_device_type} memory usage:" if not msg else msg
log.info( log.info(
f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", f"{msg} {usage:.03f}GB ({', '.join(extras)})",
stacklevel=2, stacklevel=2,
) )
return usage, cache, misc

View File

@@ -115,10 +115,10 @@ def normalize_config(cfg):
"chrf", "chrf",
] ]
choose_device(cfg) choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 if cfg.world_size != 1:
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))} cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
cfg.batch_size = cfg.batch_size * cfg.world_size
if not cfg.use_ray: if not cfg.use_ray:
# delay resolving dtype until on worker node when launching with ray # delay resolving dtype until on worker node when launching with ray
@@ -313,3 +313,16 @@ def prepare_plugins(cfg):
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]: for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name) plugin_manager.register(plugin_name)
# TODO @SalmanMohammadi remove this function in 0.12
def migrate_fsdp_config(cfg):
if cfg.get("fsdp_config"):
fsdp_config_keys = cfg.fsdp_config.keys()
if "fsdp_version" in fsdp_config_keys:
cfg.fsdp_version = cfg.fsdp_config.pop("fsdp_version")
for key in list(fsdp_config_keys):
if key.startswith("fsdp_") and key != "fsdp_version":
cfg.fsdp_config[key.replace("fsdp_", "")] = cfg.fsdp_config[key]
del cfg.fsdp_config[key]

View File

@@ -203,7 +203,9 @@ class AxolotlInputConfig(
}, },
) )
dataset_processes: int | None = Field( dataset_processes: int | None = Field(
default=min(int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()), # type: ignore[type-var] default=min(
int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()
), # type: ignore[type-var]
json_schema_extra={ json_schema_extra={
"description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set." "description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set."
}, },
@@ -572,14 +574,24 @@ class AxolotlInputConfig(
}, },
) )
fsdp: list[str] | None = Field( fsdp: list[str] | None = Field(
default=None, json_schema_extra={"description": "FSDP configuration"} default=None,
json_schema_extra={"description": "FSDP configuration"},
deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ",
) )
# TODO @SalmanMohammadi strongly type this as its own schema
fsdp_config: dict[str, Any] | None = Field( fsdp_config: dict[str, Any] | None = Field(
default=None, json_schema_extra={"description": "FSDP configuration options"} default=None, json_schema_extra={"description": "FSDP configuration options"}
) )
fsdp_version: int | None = Field(
default=None,
json_schema_extra={"description": "FSDP version"},
)
fsdp_final_state_dict_type: ( fsdp_final_state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = None ) = Field(
default=None,
deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.",
)
val_set_size: float | None = Field( val_set_size: float | None = Field(
default=0.0, default=0.0,
@@ -949,11 +961,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
or data.get("lora_o_kernel") or data.get("lora_o_kernel")
): ):
capabilities = data.get("capabilities") capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp") is not None is_fsdp = data.get("fsdp_config") is not None
is_fsdp2 = ( is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
data.get("fsdp_config") is not None
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
)
if capabilities and capabilities.get("n_gpu", 0) > 1 and not is_fsdp2: if capabilities and capabilities.get("n_gpu", 0) > 1 and not is_fsdp2:
if is_fsdp: if is_fsdp:
raise ValueError( raise ValueError(
@@ -987,11 +997,8 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
# Check multi-GPU compatibility # Check multi-GPU compatibility
capabilities = data.get("capabilities") capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1 is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
is_fsdp = data.get("fsdp") is not None is_fsdp = data.get("fsdp_config") is not None
is_fsdp2 = ( is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
data.get("fsdp_config") is not None
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
)
if ( if (
not is_multi_gpu not is_multi_gpu
@@ -1114,21 +1121,94 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
torch_version = str(torch.__version__).split("+", maxsplit=1)[0] torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if (
data.get("fsdp")
and data.get("fsdp_config")
and str(data["fsdp_config"].get("fsdp_version")) == "2"
):
if version.parse(torch_version) < version.parse("2.7.0"):
raise ValueError(
"FSDP2 and QAT are not supported on torch version < 2.7.0"
)
if version.parse(torch_version) < version.parse("2.6.0"): if version.parse(torch_version) < version.parse("2.6.0"):
raise ValueError("QAT is not supported on torch version < 2.6.0") raise ValueError("QAT is not supported on torch version < 2.6.0")
return data return data
@model_validator(mode="before")
@classmethod
def check_fsdp_torch_version(cls, data):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if data.get("fsdp_config") and str(data.get("fsdp_version")) == "2":
if version.parse(torch_version) < version.parse("2.7.0"):
raise ValueError("FSDP2 is not supported on torch version < 2.7.0")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version(cls, data):
fsdp_config = data.get("fsdp_config", {})
if fsdp_config and str(data.get("fsdp_version")) != "2":
LOG.info(
"FSDP1 will be deprecated in an upcoming release of Axolotl."
"We recommend that you use FSDP version 2 for better performance and compatibility. "
"Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
"For more details on migrating your config. "
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_ram_efficient_loading(cls, data):
fsdp_config = data.get("fsdp_config")
if fsdp_config and data.get("fsdp_version") == 2:
if fsdp_config.get("cpu_ram_efficient_loading") and (
data.get("load_in_8bit") or data.get("load_in_4bit")
):
raise ValueError(
"FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_dpo(cls, data):
if data.get("fsdp_version") == 2 and data.get("rl") in [
RLType.DPO,
RLType.KTO,
RLType.ORPO,
RLType.IPO,
]:
if data.get("load_in_8bit") or data.get("load_in_4bit"):
raise ValueError(
"FSDP2 does not support load_in_8bit or load_in_4bit with DPO. Please use DeepSpeed or set `fsdp_version` to 1."
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
if fsdp_config := data.get("fsdp_config"):
if fsdp_config.get("fsdp_version"):
LOG.warning(
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
"Please configure `fsdp_version` as a top-level field."
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
if fsdp_config := data.get("fsdp_config"):
for key, _ in fsdp_config.items():
if key.startswith("fsdp_"):
LOG.warning_once(
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def default_dataloader_opts(cls, data): def default_dataloader_opts(cls, data):

View File

@@ -574,15 +574,6 @@ class LoRAValidationMixin:
raise ValueError("Fused modules are not supported with LoRA/QLoRA") raise ValueError("Fused modules are not supported with LoRA/QLoRA")
return self return self
@model_validator(mode="after")
def hint_lora_8bit(self):
loftq = (
self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
)
if not self.load_in_8bit and self.adapter == "lora" and not loftq:
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
return self
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def warn_qlora_zero3_w_use_reentrant(cls, data): def warn_qlora_zero3_w_use_reentrant(cls, data):
@@ -786,7 +777,7 @@ class OptimizationValidationMixin:
@classmethod @classmethod
def check_fsdp_sharded_state_dict_w_safetensors(cls, data): def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
if ( if (
data.get("fsdp") data.get("fsdp_config")
and data.get("save_safetensors") and data.get("save_safetensors")
and data.get("fsdp_config") and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
@@ -1000,7 +991,7 @@ class ComplexValidationMixin:
if self.adapter not in ("lora", "qlora"): if self.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
if self.fsdp: if self.fsdp or self.fsdp_config:
raise ValueError("fsdp not supported with ReLoRA") raise ValueError("fsdp not supported with ReLoRA")
if self.deepspeed: if self.deepspeed:

View File

@@ -563,37 +563,39 @@ def setup_deepspeed_env(cfg, stage=None):
def setup_fsdp_envs(cfg): def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true" os.environ["ACCELERATE_USE_FSDP"] = "true"
if str(cfg.fsdp_config.fsdp_version) == "2":
# TODO @SalmanMohammadi remove FSDP1 args in 0.12
if str(cfg.fsdp_version) == "2":
os.environ["FSDP_VERSION"] = "2" os.environ["FSDP_VERSION"] = "2"
if cfg.fsdp_config.fsdp_activation_checkpointing: if cfg.fsdp_config.activation_checkpointing:
os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true" os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
if cfg.fsdp_config.fsdp_offload_params: if cfg.fsdp_config.offload_params:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true" os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
if cfg.fsdp_config.fsdp_sync_module_states: if cfg.fsdp_config.sync_module_states:
os.environ["FSDP_SYNC_MODULE_STATES"] = "true" os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
if cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: if cfg.fsdp_config.cpu_ram_efficient_loading:
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true" os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
if cfg.fsdp_config.fsdp_use_orig_params: if cfg.fsdp_config.use_orig_params:
os.environ["FSDP_USE_ORIG_PARAMS"] = "true" os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
if cfg.fsdp_config.fsdp_state_dict_type: if cfg.fsdp_config.state_dict_type:
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.state_dict_type
if cfg.fsdp_config.fsdp_auto_wrap_policy: if cfg.fsdp_config.auto_wrap_policy:
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.fsdp_auto_wrap_policy os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.auto_wrap_policy
if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap: if cfg.fsdp_config.transformer_layer_cls_to_wrap:
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ( os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = (
cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap cfg.fsdp_config.transformer_layer_cls_to_wrap
)
if cfg.fsdp_config.fsdp_reshard_after_forward is not None:
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = (
"true" if cfg.fsdp_config.fsdp_reshard_after_forward else "false"
) )
if cfg.fsdp_config.reshard_after_forward:
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"
def prepare_optim_env(cfg): def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support(): if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None: if os.getenv("NCCL_P2P_DISABLE") is None:
os.environ["NCCL_P2P_DISABLE"] = "1" os.environ["NCCL_P2P_DISABLE"] = "1"
if cfg.fsdp: # TODO @SalmanMohammadi remove the cfg.fsdp check in 0.12
if cfg.fsdp or cfg.fsdp_config:
cfg.fsdp = True if not cfg.fsdp else cfg.fsdp
setup_fsdp_envs(cfg) setup_fsdp_envs(cfg)
elif cfg.deepspeed: elif cfg.deepspeed:
stage = None stage = None
@@ -657,11 +659,7 @@ def setup_trainer(
""" """
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
if ( if cfg.torch_compile and cfg.fsdp_config and cfg.fsdp_version == 2:
cfg.torch_compile
and cfg.fsdp_config
and str(cfg.fsdp_config.fsdp_version) == "2"
):
patch_evaluation_loop_for_fsdp2() patch_evaluation_loop_for_fsdp2()
if cfg.rl: if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor) trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)

View File

@@ -0,0 +1,326 @@
"""Test module for FSDP1 multi-GPU functionality."""
# pylint: disable=duplicate-code
import os
from pathlib import Path
import pytest
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully by checking artifacts and loss."""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
output_path.glob("*.safetensors")
)
assert len(model_files) > 0, "No model files found - training may have failed"
checkpoint_files = list(output_path.glob("checkpoint-*"))
assert (
len(checkpoint_files) > 0
), "No checkpoint files found - training may have failed"
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(
torch.tensor(final_loss)
), f"Training loss is NaN: {final_loss}"
class TestFSDP1:
"""Test class for FSDP1 functionality."""
@pytest.mark.parametrize(
"fsdp_cpu_ram_efficient_loading",
[True, False],
)
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": "1",
"fsdp_config": {
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": fsdp_cpu_ram_efficient_loading,
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": "FULL_SHARD",
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@pytest.mark.parametrize(
"adapter_config",
[
{
"adapter": "lora",
"load_in_4bit": False,
},
{
"adapter": "qlora",
"load_in_4bit": True,
},
],
)
def test_lora_sft(self, temp_dir, adapter_config):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"adapter": adapter_config["adapter"],
"load_in_4bit": adapter_config["load_in_4bit"],
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": "1",
"fsdp_config": {
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": "FULL_SHARD",
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"split": "train",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": "1",
"fsdp_config": {
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": "FULL_SHARD",
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@pytest.mark.parametrize(
"adapter_config",
[
{
"adapter": "lora",
"load_in_4bit": False,
},
{
"adapter": "qlora",
"load_in_4bit": True,
},
],
)
def test_dpo_lora(self, temp_dir, adapter_config):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"load_in_4bit": adapter_config["load_in_4bit"],
"rl": "dpo",
"chat_template": "chatml",
"sequence_len": 2048,
"adapter": adapter_config["adapter"],
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.01,
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"split": "train",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": "1",
"fsdp_config": {
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": "FULL_SHARD",
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"bf16": "auto",
"tf32": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)

View File

@@ -0,0 +1,355 @@
"""Test module for FSDP2 multi-GPU functionality."""
# pylint: disable=duplicate-code
import os
from pathlib import Path
import pytest
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully by checking artifacts and loss."""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
output_path.glob("*.safetensors")
)
assert len(model_files) > 0, "No model files found - training may have failed"
checkpoint_files = list(output_path.glob("checkpoint-*"))
assert (
len(checkpoint_files) > 0
), "No checkpoint files found - training may have failed"
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(
torch.tensor(final_loss)
), f"Training loss is NaN: {final_loss}"
class TestFSDP2:
"""Test class for FSDP2 functionality."""
@require_torch_2_7_0
@pytest.mark.parametrize(
"fsdp_cpu_ram_efficient_loading",
[True, False],
)
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": fsdp_cpu_ram_efficient_loading,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
@pytest.mark.parametrize("peft_use_dora", [True, False])
def test_lora_sft(self, temp_dir, peft_use_dora):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"peft_use_dora": peft_use_dora,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_qlora_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"split": "train",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_dpo_lora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"split": "train",
"type": "chatml.intel",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)

View File

@@ -1,93 +0,0 @@
"""
E2E tests for multigpu qwen2
"""
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
class TestMultiGPUQwen2:
"""
Test case for Llama models using LoRA
"""
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
def test_qlora_fsdp_dpo(self, base_model, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": base_model,
"load_in_4bit": True,
"rl": "dpo",
"chat_template": "chatml",
"sequence_len": 2048,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.01,
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"split": "train",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"max_steps": 2,
"warmup_steps": 20,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"bf16": "auto",
"tf32": True,
# "gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {
"use_reentrant": False,
},
"fsdp": [
"full_shard",
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
"fsdp_cpu_ram_efficient_loading": False,
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": "FULL_SHARD",
},
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)

View File

@@ -77,6 +77,18 @@ def require_torch_2_6_0(test_case):
return unittest.skipUnless(is_min_2_6_0(), "test requires torch>=2.6.0")(test_case) return unittest.skipUnless(is_min_2_6_0(), "test requires torch>=2.6.0")(test_case)
def require_torch_2_7_0(test_case):
"""
Decorator marking a test that requires torch >= 2.7.0
"""
def is_min_2_7_0():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.7.0")
return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
def require_torch_lt_2_6_0(test_case): def require_torch_lt_2_6_0(test_case):
""" """
Decorator marking a test that requires torch < 2.6.0 Decorator marking a test that requires torch < 2.6.0

View File

@@ -5,7 +5,11 @@ Test classes for checking functionality of the cfg normalization
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
from axolotl.utils.config import normalize_cfg_datasets, normalize_config from axolotl.utils.config import (
migrate_fsdp_config,
normalize_cfg_datasets,
normalize_config,
)
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -90,3 +94,104 @@ class NormalizeConfigTestCase(unittest.TestCase):
self.assertTrue(cfg.bf16) self.assertTrue(cfg.bf16)
self.assertFalse(cfg.fp16) self.assertFalse(cfg.fp16)
def test_migrate_fsdp_config(self):
"""Test basic FSDP config migration with and without fsdp_version"""
cfg_with_version = DictDefault(
{
"fsdp_config": {
"fsdp_version": 2,
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"regular_param": "value",
}
}
)
migrate_fsdp_config(cfg_with_version)
self.assertEqual(cfg_with_version.fsdp_version, 2)
self.assertEqual(
cfg_with_version.fsdp_config.auto_wrap_policy, "TRANSFORMER_BASED_WRAP"
)
self.assertEqual(cfg_with_version.fsdp_config.offload_params, False)
self.assertEqual(cfg_with_version.fsdp_config.cpu_ram_efficient_loading, True)
self.assertEqual(cfg_with_version.fsdp_config.regular_param, "value")
self.assertNotIn("fsdp_auto_wrap_policy", cfg_with_version.fsdp_config)
self.assertNotIn("fsdp_offload_params", cfg_with_version.fsdp_config)
self.assertNotIn("fsdp_cpu_ram_efficient_loading", cfg_with_version.fsdp_config)
self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config)
self.assertNotIn("version", cfg_with_version.fsdp_config)
cfg_without_version = DictDefault(
{
"fsdp_config": {
"fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
"fsdp_offload_params": True,
"regular_param": "value",
}
}
)
migrate_fsdp_config(cfg_without_version)
self.assertNotIn("fsdp_version", cfg_without_version)
self.assertEqual(
cfg_without_version.fsdp_config.auto_wrap_policy, "SIZE_BASED_WRAP"
)
self.assertEqual(cfg_without_version.fsdp_config.offload_params, True)
self.assertEqual(cfg_without_version.fsdp_config.regular_param, "value")
self.assertNotIn("fsdp_auto_wrap_policy", cfg_without_version.fsdp_config)
self.assertNotIn("fsdp_offload_params", cfg_without_version.fsdp_config)
def test_migrate_fsdp_config_no_fsdp_config(self):
"""Test that function doesn't crash when no fsdp_config is present"""
cfg = DictDefault({"some_other_config": "value"})
migrate_fsdp_config(cfg)
self.assertNotIn("fsdp_config", cfg)
self.assertNotIn("fsdp_version", cfg)
self.assertEqual(cfg.some_other_config, "value")
def test_migrate_fsdp_config_empty_fsdp_config(self):
"""Test migration with empty fsdp_config"""
cfg = DictDefault({"fsdp_config": {}})
migrate_fsdp_config(cfg)
self.assertNotIn("fsdp_version", cfg)
self.assertEqual(cfg.fsdp_config, {})
def test_migrate_fsdp_config_mixed_keys(self):
"""Test migration with a mix of fsdp_ and non-fsdp_ keys"""
cfg = DictDefault(
{
"fsdp_config": {
"fsdp_version": 1,
"fsdp_state_dict_type": "FULL_STATE_DICT",
"mixed_precision_policy": "fp16",
"activation_checkpointing": True,
"fsdp_reshard_after_forward": False,
}
}
)
migrate_fsdp_config(cfg)
self.assertEqual(cfg.fsdp_version, 1)
self.assertEqual(cfg.fsdp_config.state_dict_type, "FULL_STATE_DICT")
self.assertEqual(cfg.fsdp_config.reshard_after_forward, False)
self.assertEqual(cfg.fsdp_config.mixed_precision_policy, "fp16")
self.assertEqual(cfg.fsdp_config.activation_checkpointing, True)
# Check original fsdp_ keys are removed
self.assertNotIn("fsdp_version", cfg.fsdp_config)
self.assertNotIn("fsdp_state_dict_type", cfg.fsdp_config)
self.assertNotIn("fsdp_reshard_after_forward", cfg.fsdp_config)
# Ensure no duplicate version key
self.assertNotIn("version", cfg.fsdp_config)