Tensor parallel w DeepSpeed AutoTP (#2574)
* support for deepspeed autotup * bump to latest deepspeed that supports deepcompile too * add deepcompile support too * fix total steps calculation for TP * setup fixture for tp * update ds config to ensure weights are gathered for checkpoint * fix duplicate validation names * chore: lint
This commit is contained in:
2
setup.py
2
setup.py
@@ -121,7 +121,7 @@ extras_require = {
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.17.1",
|
||||
"deepspeed==0.17.2",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
|
||||
@@ -584,6 +584,12 @@ class AxolotlInputConfig(
|
||||
"description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json"
|
||||
},
|
||||
)
|
||||
deepcompile: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to use deepcompile for faster training with deepspeed"
|
||||
},
|
||||
)
|
||||
fsdp: list[str] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "FSDP configuration"},
|
||||
@@ -629,7 +635,12 @@ class AxolotlInputConfig(
|
||||
"description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case."
|
||||
},
|
||||
)
|
||||
|
||||
tensor_parallel_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
||||
},
|
||||
)
|
||||
special_tokens: SpecialTokensConfig | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
"""Module with validation methods for config pydantic model."""
|
||||
|
||||
# pylint: disable=too-many-lines,too-many-boolean-expressions
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import (
|
||||
field_validator,
|
||||
@@ -12,6 +15,8 @@ from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||
@@ -872,6 +877,59 @@ class OptimizationValidationMixin:
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_tensor_parallel_size_update_ds_json(cls, data):
|
||||
tensor_parallel_size = data.get("tensor_parallel_size")
|
||||
if tensor_parallel_size is not None and tensor_parallel_size > 1:
|
||||
if not data.get("deepspeed"):
|
||||
raise ValueError(
|
||||
"Tensor parallelism (TP) is only supported with DeepSpeed"
|
||||
)
|
||||
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
||||
ds_config = json.load(ds_fin)
|
||||
should_save = False
|
||||
if "tensor_parallel" not in ds_config:
|
||||
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
|
||||
should_save = True
|
||||
if (
|
||||
"gather_16bit_weights_on_model_save"
|
||||
not in ds_config["zero_optimization"]
|
||||
):
|
||||
ds_config["zero_optimization"][
|
||||
"gather_16bit_weights_on_model_save"
|
||||
] = True
|
||||
should_save = True
|
||||
if should_save:
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(
|
||||
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
|
||||
) as ds_fout:
|
||||
json.dump(ds_config, ds_fout, indent=4)
|
||||
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_deepcompile(cls, data):
|
||||
deepcompile = data.get("deepcompile")
|
||||
if deepcompile:
|
||||
if not data.get("deepspeed"):
|
||||
raise ValueError("DeepCompile is only supported with DeepSpeed")
|
||||
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
||||
ds_config = json.load(ds_fin)
|
||||
if "compile" not in ds_config:
|
||||
ds_config["compile"] = {"deepcompile": True}
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(
|
||||
Path(temp_dir) / "deepcompile_ds.json", "w", encoding="utf-8"
|
||||
) as ds_fout:
|
||||
json.dump(ds_config, ds_fout, indent=4)
|
||||
data["deepspeed"] = str(Path(temp_dir) / "deepcompile_ds.json")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class SystemValidationMixin:
|
||||
"""Validation methods related to system and hardware configuration."""
|
||||
@@ -1126,6 +1184,12 @@ class ComplexValidationMixin:
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tensor_parallel_size(self):
|
||||
if not self.tensor_parallel_size:
|
||||
self.tensor_parallel_size = 1
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_sequence_parallel_degree(self):
|
||||
if not self.sequence_parallel_degree:
|
||||
|
||||
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.tensor_parallel_size
|
||||
)
|
||||
LOG.debug(
|
||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
||||
@@ -481,7 +482,10 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
# on the agreed on value for sample_packing_eff_est
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||
data_loader_len
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.tensor_parallel_size
|
||||
)
|
||||
)
|
||||
if cfg.dataloader_drop_last:
|
||||
@@ -508,6 +512,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.tensor_parallel_size
|
||||
/ cfg.batch_size
|
||||
)
|
||||
)
|
||||
|
||||
@@ -65,6 +65,7 @@ def fixture_base_cfg():
|
||||
"dataloader_pin_memory": True,
|
||||
"dataloader_prefetch_factor": 2,
|
||||
"sequence_parallel_degree": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
# Dtype
|
||||
"fp16": False,
|
||||
"bf16": False,
|
||||
|
||||
Reference in New Issue
Block a user