update fsdp2 patch
This commit is contained in:
@@ -254,9 +254,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
"offload_policy": fsdp2_plugin.cpu_offload,
|
"offload_policy": fsdp2_plugin.cpu_offload,
|
||||||
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
||||||
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
||||||
"mesh": accelerator.state.device_mesh[
|
"mesh": accelerator.torch_device_mesh[tuple(accelerator.parallelism_config.model_shard_dim_names)],
|
||||||
accelerator.state.parallelism_config.model_shard_dim_names
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
model_has_params4bit = False
|
model_has_params4bit = False
|
||||||
|
|||||||
0
src/axolotl/utils/schemas/distributed.py
Normal file
0
src/axolotl/utils/schemas/distributed.py
Normal file
Reference in New Issue
Block a user