diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index ed2b70307..241d6e387 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -913,31 +913,30 @@ class OptimizationValidationMixin: def check_tensor_parallel_size_update_ds_json(cls, data): tensor_parallel_size = data.get("tensor_parallel_size") if tensor_parallel_size is not None and tensor_parallel_size > 1: - if not data.get("deepspeed"): - raise ValueError( - "Tensor parallelism (TP) is only supported with DeepSpeed" - ) - with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin: - ds_config = json.load(ds_fin) - should_save = False - if "tensor_parallel" not in ds_config: - ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size} - should_save = True - if ( - "gather_16bit_weights_on_model_save" - not in ds_config["zero_optimization"] - ): - ds_config["zero_optimization"][ + if data.get("deepspeed"): + with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin: + ds_config = json.load(ds_fin) + should_save = False + if "tensor_parallel" not in ds_config: + ds_config["tensor_parallel"] = { + "autotp_size": tensor_parallel_size + } + should_save = True + if ( "gather_16bit_weights_on_model_save" - ] = True - should_save = True - if should_save: - temp_dir = tempfile.mkdtemp() - with open( - Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8" - ) as ds_fout: - json.dump(ds_config, ds_fout, indent=4) - data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json") + not in ds_config["zero_optimization"] + ): + ds_config["zero_optimization"][ + "gather_16bit_weights_on_model_save" + ] = True + should_save = True + if should_save: + temp_dir = tempfile.mkdtemp() + with open( + Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8" + ) as ds_fout: + json.dump(ds_config, ds_fout, indent=4) + data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json") return data