Tensor parallel with DeepSpeed AutoTP (#2574)

* support for DeepSpeed AutoTP

* bump to the latest DeepSpeed, which also supports DeepCompile

* add DeepCompile support

* fix total steps calculation for TP

* set up fixture for TP

* update DeepSpeed config to ensure weights are gathered for checkpointing

* fix duplicate validation names

* chore: lint
Author: Wing Lian
Date: 2025-07-14 21:33:48 -04:00
Committed by: GitHub
Parent: 5cc16040a8
Commit: cd079b5536

5 changed files with 85 additions and 4 deletions
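The commit introduces two user-facing options. A minimal sketch of a config dict exercising both (the DeepSpeed JSON path is the example path from the schema description below; the values are hypothetical):

```python
# Hypothetical axolotl config fragment using the fields this commit introduces.
cfg = {
    "deepspeed": "deepspeed_configs/zero3.json",  # required for both options
    "tensor_parallel_size": 2,  # size of each DeepSpeed AutoTP group
    "deepcompile": True,  # opt in to DeepSpeed's DeepCompile backend
}
```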

View File

@@ -121,7 +121,7 @@ extras_require = {
         "yunchang==0.6.0",
     ],
     "deepspeed": [
-        "deepspeed==0.17.1",
+        "deepspeed==0.17.2",
         "deepspeed-kernels",
     ],
     "mamba-ssm": [

View File

@@ -584,6 +584,12 @@ class AxolotlInputConfig(
             "description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json"
         },
     )
+    deepcompile: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to use deepcompile for faster training with deepspeed"
+        },
+    )
     fsdp: list[str] | None = Field(
         default=None,
         json_schema_extra={"description": "FSDP configuration"},
@@ -629,7 +635,12 @@ class AxolotlInputConfig(
             "description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case."
         },
     )
+    tensor_parallel_size: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
+        },
+    )
     special_tokens: SpecialTokensConfig | None = Field(
         default=None,
         json_schema_extra={
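Both new schema entries follow the same optional-with-metadata pattern as the rest of the config model. A self-contained sketch of how such a field behaves (`TPConfig` is a stand-in for `AxolotlInputConfig`, not part of the commit):

```python
from pydantic import BaseModel, Field


class TPConfig(BaseModel):
    # Stand-in model using the same Field(...) pattern as the schema above.
    tensor_parallel_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of tensor parallel processes in TP group."
        },
    )


print(TPConfig(tensor_parallel_size=2).tensor_parallel_size)  # 2
print(TPConfig().tensor_parallel_size)  # None; a later validator defaults it to 1
```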

View File

@@ -1,8 +1,11 @@
 """Module with validation methods for config pydantic model."""
-# pylint: disable=too-many-boolean-expressions
+# pylint: disable=too-many-lines,too-many-boolean-expressions
+import json
 import logging
+import tempfile
+from pathlib import Path

 from pydantic import (
     field_validator,
@@ -12,6 +15,8 @@ from transformers.utils.import_utils import is_torch_npu_available
 from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType

-# pylint: disable=too-many-lines
 LOG = logging.getLogger(__name__)

 SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
@@ -872,6 +877,59 @@ class OptimizationValidationMixin:
         )
         return self

+    @model_validator(mode="before")
+    @classmethod
+    def check_tensor_parallel_size_update_ds_json(cls, data):
+        # TP is only wired up through DeepSpeed AutoTP, so rewrite the user's
+        # DeepSpeed JSON to carry the TP group size and to gather full 16-bit
+        # weights when saving checkpoints.
+        tensor_parallel_size = data.get("tensor_parallel_size")
+        if tensor_parallel_size is not None and tensor_parallel_size > 1:
+            if not data.get("deepspeed"):
+                raise ValueError(
+                    "Tensor parallelism (TP) is only supported with DeepSpeed"
+                )
+            with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
+                ds_config = json.load(ds_fin)
+            should_save = False
+            if "tensor_parallel" not in ds_config:
+                ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
+                should_save = True
+            if (
+                "gather_16bit_weights_on_model_save"
+                not in ds_config["zero_optimization"]
+            ):
+                ds_config["zero_optimization"][
+                    "gather_16bit_weights_on_model_save"
+                ] = True
+                should_save = True
+            if should_save:
+                # Point the run at a patched copy rather than mutating the
+                # user's original config file.
+                temp_dir = tempfile.mkdtemp()
+                with open(
+                    Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
+                ) as ds_fout:
+                    json.dump(ds_config, ds_fout, indent=4)
+                data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_deepcompile(cls, data):
+        # DeepCompile rides on DeepSpeed; inject its compile block if absent.
+        deepcompile = data.get("deepcompile")
+        if deepcompile:
+            if not data.get("deepspeed"):
+                raise ValueError("DeepCompile is only supported with DeepSpeed")
+            with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
+                ds_config = json.load(ds_fin)
+            if "compile" not in ds_config:
+                ds_config["compile"] = {"deepcompile": True}
+                temp_dir = tempfile.mkdtemp()
+                with open(
+                    Path(temp_dir) / "deepcompile_ds.json", "w", encoding="utf-8"
+                ) as ds_fout:
+                    json.dump(ds_config, ds_fout, indent=4)
+                data["deepspeed"] = str(Path(temp_dir) / "deepcompile_ds.json")
+        return data
+

 class SystemValidationMixin:
     """Validation methods related to system and hardware configuration."""
@@ -1126,6 +1184,12 @@ class ComplexValidationMixin:
         )
         return self

+    @model_validator(mode="after")
+    def check_tensor_parallel_size(self):
+        if not self.tensor_parallel_size:
+            self.tensor_parallel_size = 1
+        return self
+
     @model_validator(mode="after")
     def check_sequence_parallel_degree(self):
         if not self.sequence_parallel_degree:

View File

@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             )
             * cfg.num_epochs
             * cfg.sequence_parallel_degree
+            * cfg.tensor_parallel_size
         )
         LOG.debug(
             f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -481,7 +482,10 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             # on the agreed on value for sample_packing_eff_est
             total_num_steps = int(
                 math.floor(
-                    data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
+                    data_loader_len
+                    * cfg.num_epochs
+                    * cfg.sequence_parallel_degree
+                    * cfg.tensor_parallel_size
                 )
             )
             if cfg.dataloader_drop_last:
@@ -508,6 +512,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 len(train_dataset)
                 * cfg.num_epochs
                 * cfg.sequence_parallel_degree
+                * cfg.tensor_parallel_size
                 / cfg.batch_size
            )
        )
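Why the step counts multiply by `tensor_parallel_size`: ranks in the same TP group consume identical batches, so they add no data parallelism, while `cfg.batch_size` folds in the full world size (an assumption here about axolotl's config normalization). Multiplying by the SP and TP degrees undoes that overcount. A worked example with hypothetical numbers:

```python
import math

# Hypothetical setup: 8 GPUs, TP=2, SP=1 -> only 4 true data-parallel replicas.
world_size, tensor_parallel_size, sequence_parallel_degree = 8, 2, 1
micro_batch_size, gradient_accumulation_steps, num_epochs = 4, 1, 1
dataset_len = 10_000

# Assumption: cfg.batch_size counts every rank as if it were data-parallel.
batch_size = micro_batch_size * gradient_accumulation_steps * world_size  # 32

# The corrected step count, mirroring the last hunk above.
total_num_steps = int(
    math.floor(
        dataset_len
        * num_epochs
        * sequence_parallel_degree
        * tensor_parallel_size
        / batch_size
    )
)
print(total_num_steps)  # 625 == dataset_len / (micro_bsz 4 * 4 DP replicas)
```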

View File

@@ -65,6 +65,7 @@ def fixture_base_cfg():
         "dataloader_pin_memory": True,
         "dataloader_prefetch_factor": 2,
         "sequence_parallel_degree": 1,
+        "tensor_parallel_size": 1,
         # Dtype
         "fp16": False,
         "bf16": False,