better handling of lora w bias with fsdp2 and handling of files when saving model checkpoint (#3090)
This commit is contained in:
@@ -200,7 +200,7 @@ class ModalCloud(Cloud):
|
|||||||
if family in ["a10", "a10g"]:
|
if family in ["a10", "a10g"]:
|
||||||
return modal.gpu.A10G(count=count)
|
return modal.gpu.A10G(count=count)
|
||||||
if family == "h100":
|
if family == "h100":
|
||||||
return modal.gpu.H100(count=count)
|
return f"H100:{count}"
|
||||||
if family == "t4":
|
if family == "t4":
|
||||||
return modal.gpu.T4(count=count)
|
return modal.gpu.T4(count=count)
|
||||||
if family == "l4":
|
if family == "l4":
|
||||||
|
|||||||
@@ -187,7 +187,7 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
|||||||
|
|
||||||
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||||
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
|
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
|
||||||
if module.base_layer.bias is not None:
|
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
|
||||||
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
|
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
|
||||||
log_bias_dtype_mismatch = True
|
log_bias_dtype_mismatch = True
|
||||||
module.base_layer.bias.data = module.base_layer.bias.data.to(
|
module.base_layer.bias.data = module.base_layer.bias.data.to(
|
||||||
|
|||||||
@@ -253,7 +253,9 @@ def save_trained_model(
|
|||||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||||
return
|
return
|
||||||
|
|
||||||
if trainer.is_fsdp_enabled or cfg.fsdp_config:
|
if ( # pylint: disable=too-many-nested-blocks
|
||||||
|
trainer.is_fsdp_enabled or cfg.fsdp_config
|
||||||
|
):
|
||||||
if cfg.fsdp_config or cfg.fsdp:
|
if cfg.fsdp_config or cfg.fsdp:
|
||||||
if cfg.fsdp_config.final_state_dict_type:
|
if cfg.fsdp_config.final_state_dict_type:
|
||||||
state_dict_type = cfg.fsdp_config.final_state_dict_type
|
state_dict_type = cfg.fsdp_config.final_state_dict_type
|
||||||
@@ -285,6 +287,8 @@ def save_trained_model(
|
|||||||
if trainer.accelerator.is_main_process:
|
if trainer.accelerator.is_main_process:
|
||||||
# move all files in merged_path to cfg.output_dir
|
# move all files in merged_path to cfg.output_dir
|
||||||
for merged_file in Path(merged_path).iterdir():
|
for merged_file in Path(merged_path).iterdir():
|
||||||
|
if (Path(cfg.output_dir) / merged_file.name).exists():
|
||||||
|
(Path(cfg.output_dir) / merged_file.name).unlink()
|
||||||
shutil.move(str(merged_file), cfg.output_dir)
|
shutil.move(str(merged_file), cfg.output_dir)
|
||||||
shutil.rmtree(merged_path) # remove what should be an empty dir
|
shutil.rmtree(merged_path) # remove what should be an empty dir
|
||||||
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
|
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
|
||||||
|
|||||||
Reference in New Issue
Block a user