update tp validation
This commit is contained in:
@@ -913,31 +913,30 @@ class OptimizationValidationMixin:
|
|||||||
def check_tensor_parallel_size_update_ds_json(cls, data):
|
def check_tensor_parallel_size_update_ds_json(cls, data):
|
||||||
tensor_parallel_size = data.get("tensor_parallel_size")
|
tensor_parallel_size = data.get("tensor_parallel_size")
|
||||||
if tensor_parallel_size is not None and tensor_parallel_size > 1:
|
if tensor_parallel_size is not None and tensor_parallel_size > 1:
|
||||||
if not data.get("deepspeed"):
|
if data.get("deepspeed"):
|
||||||
raise ValueError(
|
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
||||||
"Tensor parallelism (TP) is only supported with DeepSpeed"
|
ds_config = json.load(ds_fin)
|
||||||
)
|
should_save = False
|
||||||
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
if "tensor_parallel" not in ds_config:
|
||||||
ds_config = json.load(ds_fin)
|
ds_config["tensor_parallel"] = {
|
||||||
should_save = False
|
"autotp_size": tensor_parallel_size
|
||||||
if "tensor_parallel" not in ds_config:
|
}
|
||||||
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
|
should_save = True
|
||||||
should_save = True
|
if (
|
||||||
if (
|
|
||||||
"gather_16bit_weights_on_model_save"
|
|
||||||
not in ds_config["zero_optimization"]
|
|
||||||
):
|
|
||||||
ds_config["zero_optimization"][
|
|
||||||
"gather_16bit_weights_on_model_save"
|
"gather_16bit_weights_on_model_save"
|
||||||
] = True
|
not in ds_config["zero_optimization"]
|
||||||
should_save = True
|
):
|
||||||
if should_save:
|
ds_config["zero_optimization"][
|
||||||
temp_dir = tempfile.mkdtemp()
|
"gather_16bit_weights_on_model_save"
|
||||||
with open(
|
] = True
|
||||||
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
|
should_save = True
|
||||||
) as ds_fout:
|
if should_save:
|
||||||
json.dump(ds_config, ds_fout, indent=4)
|
temp_dir = tempfile.mkdtemp()
|
||||||
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
|
with open(
|
||||||
|
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
|
||||||
|
) as ds_fout:
|
||||||
|
json.dump(ds_config, ds_fout, indent=4)
|
||||||
|
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user