diff --git a/setup.py b/setup.py
index ff8bd2c5c..df9a23154 100644
--- a/setup.py
+++ b/setup.py
@@ -121,7 +121,7 @@ extras_require = {
         "yunchang==0.6.0",
     ],
     "deepspeed": [
-        "deepspeed==0.17.1",
+        "deepspeed==0.17.2",
         "deepspeed-kernels",
     ],
     "mamba-ssm": [
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 1726feb67..909fd637c 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -584,6 +584,12 @@ class AxolotlInputConfig(
             "description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json"
         },
     )
+    deepcompile: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to use deepcompile for faster training with deepspeed"
+        },
+    )
     fsdp: list[str] | None = Field(
         default=None,
         json_schema_extra={"description": "FSDP configuration"},
@@ -629,7 +635,12 @@ class AxolotlInputConfig(
             "description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case."
         },
     )
-
+    tensor_parallel_size: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
+        },
+    )
     special_tokens: SpecialTokensConfig | None = Field(
         default=None,
         json_schema_extra={
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index db3fd0a1c..56a70ec48 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -1,8 +1,11 @@
 """Module with validation methods for config pydantic model."""
 
-# pylint: disable=too-many-lines,too-many-boolean-expressions
+# pylint: disable=too-many-boolean-expressions
 
+import json
 import logging
+import tempfile
+from pathlib import Path
 
 from pydantic import (
     field_validator,
@@ -12,6 +15,8 @@ from transformers.utils.import_utils import is_torch_npu_available
 
 from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
 
+# pylint: disable=too-many-lines
+
 LOG = logging.getLogger(__name__)
 
 SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
@@ -872,6 +877,59 @@ class OptimizationValidationMixin:
             )
         return self
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_tensor_parallel_size_update_ds_json(cls, data):
+        tensor_parallel_size = data.get("tensor_parallel_size")
+        if tensor_parallel_size is not None and tensor_parallel_size > 1:
+            if not data.get("deepspeed"):
+                raise ValueError(
+                    "Tensor parallelism (TP) is only supported with DeepSpeed"
+                )
+            with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
+                ds_config = json.load(ds_fin)
+            should_save = False
+            if "tensor_parallel" not in ds_config:
+                ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
+                should_save = True
+            if (
+                "gather_16bit_weights_on_model_save"
+                not in ds_config["zero_optimization"]
+            ):
+                ds_config["zero_optimization"][
+                    "gather_16bit_weights_on_model_save"
+                ] = True
+                should_save = True
+            if should_save:
+                temp_dir = tempfile.mkdtemp()
+                with open(
+                    Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
+                ) as ds_fout:
+                    json.dump(ds_config, ds_fout, indent=4)
+                data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
+
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_deepcompile(cls, data):
+        deepcompile = data.get("deepcompile")
+        if deepcompile:
+            if not data.get("deepspeed"):
+                raise ValueError("DeepCompile is only supported with DeepSpeed")
+            with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
+                ds_config = json.load(ds_fin)
+            if "compile" not in ds_config:
+                ds_config["compile"] = {"deepcompile": True}
+                temp_dir = tempfile.mkdtemp()
+                with open(
+                    Path(temp_dir) / "deepcompile_ds.json", "w", encoding="utf-8"
+                ) as ds_fout:
+                    json.dump(ds_config, ds_fout, indent=4)
+                data["deepspeed"] = str(Path(temp_dir) / "deepcompile_ds.json")
+
+        return data
+
 
 class SystemValidationMixin:
     """Validation methods related to system and hardware configuration."""
@@ -1126,6 +1184,12 @@ class ComplexValidationMixin:
             )
         return self
 
+    @model_validator(mode="after")
+    def check_tensor_parallel_size(self):
+        if not self.tensor_parallel_size:
+            self.tensor_parallel_size = 1
+        return self
+
     @model_validator(mode="after")
     def check_sequence_parallel_degree(self):
         if not self.sequence_parallel_degree:
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index a512d6400..8371b2dd7 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 )
                 * cfg.num_epochs
                 * cfg.sequence_parallel_degree
+                * cfg.tensor_parallel_size
             )
             LOG.debug(
                 f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -481,7 +482,10 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 # on the agreed on value for sample_packing_eff_est
                 total_num_steps = int(
                     math.floor(
-                        data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
+                        data_loader_len
+                        * cfg.num_epochs
+                        * cfg.sequence_parallel_degree
+                        * cfg.tensor_parallel_size
                     )
                 )
                 if cfg.dataloader_drop_last:
@@ -508,6 +512,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 len(train_dataset)
                 * cfg.num_epochs
                 * cfg.sequence_parallel_degree
+                * cfg.tensor_parallel_size
                 / cfg.batch_size
             )
         )
diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py
index e66b8e009..0053b4d27 100644
--- a/tests/core/test_builders.py
+++ b/tests/core/test_builders.py
@@ -65,6 +65,7 @@ def fixture_base_cfg():
         "dataloader_pin_memory": True,
         "dataloader_prefetch_factor": 2,
         "sequence_parallel_degree": 1,
+        "tensor_parallel_size": 1,
        # Dtype
         "fp16": False,
         "bf16": False,