diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 136acc4a0..20d84e165 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -125,6 +125,9 @@ def normalize_config(cfg):
         with open(ds_config_path, encoding="utf-8") as f:
             cfg.deepspeed = json.load(f)
 
+    if cfg.sequence_parallel_degree is None:
+        cfg.sequence_parallel_degree = 1
+
     if cfg.saves_per_epoch:
         save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
         if save_steps < 1.0:  # prevent saves on every step
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index eb9792295..611fc4453 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1107,10 +1107,7 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_sequence_parallel_config(cls, data):
-        if data.get("sequence_parallel_degree") is None:
-            data["sequence_parallel_degree"] = 1
-
-        if data.get("sequence_parallel_degree") > 1:
+        if (data.get("sequence_parallel_degree") or 1) > 1:
             if not data.get("flash_attention"):
                 raise ValueError(
                     "flash_attention: true must be set with sequence_parallel_degree > 1"
diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py
index 3fc315b2e..128d2d05c 100644
--- a/tests/test_exact_deduplication.py
+++ b/tests/test_exact_deduplication.py
@@ -280,6 +280,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
                 "batch_size": 10,
                 "micro_batch_size": 10,
                 "num_epochs": 1,
+                "sequence_parallel_degree": 1,
             }
         )
 