Tensor parallel w DeepSpeed AutoTP (#2574)
* support for deepspeed AutoTP * bump to latest deepspeed that supports deepcompile too * add deepcompile support too * fix total steps calculation for TP * setup fixture for tp * update ds config to ensure weights are gathered for checkpoint * fix duplicate validation names * chore: lint
This commit is contained in:
2
setup.py
2
setup.py
@@ -121,7 +121,7 @@ extras_require = {
|
|||||||
"yunchang==0.6.0",
|
"yunchang==0.6.0",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.17.1",
|
"deepspeed==0.17.2",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
],
|
],
|
||||||
"mamba-ssm": [
|
"mamba-ssm": [
|
||||||
|
|||||||
@@ -584,6 +584,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json"
|
"description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
deepcompile: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to use deepcompile for faster training with deepspeed"
|
||||||
|
},
|
||||||
|
)
|
||||||
fsdp: list[str] | None = Field(
|
fsdp: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "FSDP configuration"},
|
json_schema_extra={"description": "FSDP configuration"},
|
||||||
@@ -629,7 +635,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case."
|
"description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
tensor_parallel_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
||||||
|
},
|
||||||
|
)
|
||||||
special_tokens: SpecialTokensConfig | None = Field(
|
special_tokens: SpecialTokensConfig | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
"""Module with validation methods for config pydantic model."""
|
"""Module with validation methods for config pydantic model."""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines,too-many-boolean-expressions
|
# pylint: disable=too-many-boolean-expressions
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
field_validator,
|
field_validator,
|
||||||
@@ -12,6 +15,8 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
|
|
||||||
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
|
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
|
||||||
|
|
||||||
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||||
@@ -872,6 +877,59 @@ class OptimizationValidationMixin:
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="before")
@classmethod
def check_tensor_parallel_size_update_ds_json(cls, data):
    """Validate tensor parallelism and patch the DeepSpeed JSON config for it.

    Runs only when ``tensor_parallel_size`` is set and greater than 1. In that
    case a DeepSpeed config path must be present under ``deepspeed``. The JSON
    file is loaded and, where missing, the following entries are added:

    * ``tensor_parallel`` -> ``{"autotp_size": tensor_parallel_size}``
    * ``zero_optimization.gather_16bit_weights_on_model_save`` -> ``True``

    Existing values for either entry are left untouched. If anything was
    added, the patched config is written to ``autotp_ds.json`` inside a fresh
    temporary directory and ``data["deepspeed"]`` is repointed to that file;
    the original file is never modified.

    Args:
        data: Raw (pre-validation) config mapping.

    Returns:
        The same ``data`` mapping, possibly with ``deepspeed`` updated.

    Raises:
        ValueError: If tensor parallelism is requested without a DeepSpeed
            config.
    """
    tensor_parallel_size = data.get("tensor_parallel_size")
    if tensor_parallel_size is None or tensor_parallel_size <= 1:
        return data

    ds_path = data.get("deepspeed")
    if not ds_path:
        raise ValueError("Tensor parallelism (TP) is only supported with DeepSpeed")

    with open(ds_path, "r", encoding="utf-8") as ds_fin:
        ds_config = json.load(ds_fin)

    should_save = False
    if "tensor_parallel" not in ds_config:
        ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
        should_save = True

    # A DeepSpeed config is not required to carry a "zero_optimization"
    # section; create it on demand instead of failing with a KeyError.
    zero_config = ds_config.setdefault("zero_optimization", {})
    if "gather_16bit_weights_on_model_save" not in zero_config:
        zero_config["gather_16bit_weights_on_model_save"] = True
        should_save = True

    if should_save:
        # The temp directory is intentionally left in place: the rewritten
        # path stored in the config must remain readable after this returns.
        patched_path = Path(tempfile.mkdtemp()) / "autotp_ds.json"
        with open(patched_path, "w", encoding="utf-8") as ds_fout:
            json.dump(ds_config, ds_fout, indent=4)
        data["deepspeed"] = str(patched_path)

    return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
@classmethod
def check_deepcompile(cls, data):
    """Enable DeepCompile in the DeepSpeed JSON config when requested.

    When ``deepcompile`` is truthy, a DeepSpeed config path must be set under
    ``deepspeed``. The JSON file is loaded and, if it has no ``compile``
    section, one is added as ``{"deepcompile": True}``; the patched config is
    written to ``deepcompile_ds.json`` in a fresh temporary directory and
    ``data["deepspeed"]`` is repointed to it. A config that already has a
    ``compile`` section is left as-is.

    Args:
        data: Raw (pre-validation) config mapping.

    Returns:
        The same ``data`` mapping, possibly with ``deepspeed`` updated.

    Raises:
        ValueError: If DeepCompile is requested without a DeepSpeed config.
    """
    if not data.get("deepcompile"):
        return data

    ds_path = data.get("deepspeed")
    if not ds_path:
        raise ValueError("DeepCompile is only supported with DeepSpeed")

    with open(ds_path, "r", encoding="utf-8") as source_file:
        loaded_config = json.load(source_file)

    # Respect whatever the user already configured under "compile".
    if "compile" in loaded_config:
        return data

    loaded_config["compile"] = {"deepcompile": True}
    patched_path = Path(tempfile.mkdtemp()) / "deepcompile_ds.json"
    with open(patched_path, "w", encoding="utf-8") as target_file:
        json.dump(loaded_config, target_file, indent=4)
    data["deepspeed"] = str(patched_path)

    return data
|
||||||
|
|
||||||
|
|
||||||
class SystemValidationMixin:
|
class SystemValidationMixin:
|
||||||
"""Validation methods related to system and hardware configuration."""
|
"""Validation methods related to system and hardware configuration."""
|
||||||
@@ -1126,6 +1184,12 @@ class ComplexValidationMixin:
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
def check_tensor_parallel_size(self):
    """Default ``tensor_parallel_size`` to 1 when it is unset or falsy.

    Returns:
        The validated model instance.
    """
    if self.tensor_parallel_size:
        # An explicit, non-zero value is kept unchanged.
        return self

    self.tensor_parallel_size = 1
    return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_sequence_parallel_degree(self):
|
def check_sequence_parallel_degree(self):
|
||||||
if not self.sequence_parallel_degree:
|
if not self.sequence_parallel_degree:
|
||||||
|
|||||||
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.sequence_parallel_degree
|
* cfg.sequence_parallel_degree
|
||||||
|
* cfg.tensor_parallel_size
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
||||||
@@ -481,7 +482,10 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
# on the agreed on value for sample_packing_eff_est
|
# on the agreed on value for sample_packing_eff_est
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.floor(
|
math.floor(
|
||||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
data_loader_len
|
||||||
|
* cfg.num_epochs
|
||||||
|
* cfg.sequence_parallel_degree
|
||||||
|
* cfg.tensor_parallel_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if cfg.dataloader_drop_last:
|
if cfg.dataloader_drop_last:
|
||||||
@@ -508,6 +512,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
len(train_dataset)
|
len(train_dataset)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.sequence_parallel_degree
|
* cfg.sequence_parallel_degree
|
||||||
|
* cfg.tensor_parallel_size
|
||||||
/ cfg.batch_size
|
/ cfg.batch_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ def fixture_base_cfg():
|
|||||||
"dataloader_pin_memory": True,
|
"dataloader_pin_memory": True,
|
||||||
"dataloader_prefetch_factor": 2,
|
"dataloader_prefetch_factor": 2,
|
||||||
"sequence_parallel_degree": 1,
|
"sequence_parallel_degree": 1,
|
||||||
|
"tensor_parallel_size": 1,
|
||||||
# Dtype
|
# Dtype
|
||||||
"fp16": False,
|
"fp16": False,
|
||||||
"bf16": False,
|
"bf16": False,
|
||||||
|
|||||||
Reference in New Issue
Block a user