update fsdp2 patch
@@ -254,9 +254,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
         "offload_policy": fsdp2_plugin.cpu_offload,
         # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
         "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
-        "mesh": accelerator.state.device_mesh[
-            accelerator.state.parallelism_config.model_shard_dim_names
-        ],
+        "mesh": accelerator.torch_device_mesh[tuple(accelerator.parallelism_config.model_shard_dim_names)],
     }
 
     model_has_params4bit = False
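For context on the "mesh" change above, here is a minimal sketch (not part of the commit; the mesh shape, dim names, and device type are assumptions) of why the lookup converts the dim names to a tuple: torch's DeviceMesh.__getitem__ accepts a single dim name or a tuple of dim names, not a list, so a list-valued model_shard_dim_names has to be converted before slicing out the submesh handed to fully_shard.

# Sketch only -- dim names, mesh shape, and device type are illustrative,
# not taken from the commit. Run under torchrun with 8 ranks (4 * 2).
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cpu", (4, 2), mesh_dim_names=("dp_shard", "tp"))

# Roughly what parallelism_config.model_shard_dim_names might expose: a list
# of mesh dim names identifying the dims the model is sharded over.
model_shard_dim_names = ["dp_shard"]

# DeviceMesh.__getitem__ takes a str or a tuple of strs, not a list, hence
# the tuple(...) conversion in the patched line above.
shard_mesh = mesh[tuple(model_shard_dim_names)]

fsdp2_kwargs = {"mesh": shard_mesh}  # later unpacked into fully_shard(model, **fsdp2_kwargs)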
0  src/axolotl/utils/schemas/distributed.py  Normal file