update tp validation

This commit is contained in:
Wing Lian
2025-07-22 21:20:57 -04:00
parent 8fe4758e94
commit 9a2da4d9f0

View File

@@ -913,31 +913,30 @@ class OptimizationValidationMixin:
def check_tensor_parallel_size_update_ds_json(cls, data): def check_tensor_parallel_size_update_ds_json(cls, data):
tensor_parallel_size = data.get("tensor_parallel_size") tensor_parallel_size = data.get("tensor_parallel_size")
if tensor_parallel_size is not None and tensor_parallel_size > 1: if tensor_parallel_size is not None and tensor_parallel_size > 1:
if not data.get("deepspeed"): if data.get("deepspeed"):
raise ValueError( with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
"Tensor parallelism (TP) is only supported with DeepSpeed" ds_config = json.load(ds_fin)
) should_save = False
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin: if "tensor_parallel" not in ds_config:
ds_config = json.load(ds_fin) ds_config["tensor_parallel"] = {
should_save = False "autotp_size": tensor_parallel_size
if "tensor_parallel" not in ds_config: }
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size} should_save = True
should_save = True if (
if (
"gather_16bit_weights_on_model_save"
not in ds_config["zero_optimization"]
):
ds_config["zero_optimization"][
"gather_16bit_weights_on_model_save" "gather_16bit_weights_on_model_save"
] = True not in ds_config["zero_optimization"]
should_save = True ):
if should_save: ds_config["zero_optimization"][
temp_dir = tempfile.mkdtemp() "gather_16bit_weights_on_model_save"
with open( ] = True
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8" should_save = True
) as ds_fout: if should_save:
json.dump(ds_config, ds_fout, indent=4) temp_dir = tempfile.mkdtemp()
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json") with open(
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
) as ds_fout:
json.dump(ds_config, ds_fout, indent=4)
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
return data return data