This commit is contained in:
Dan Saunders
2025-03-14 01:42:10 +00:00
parent cb3a9e99a3
commit a6ef6c7764
3 changed files with 5 additions and 4 deletions

View File

@@ -125,6 +125,9 @@ def normalize_config(cfg):
with open(ds_config_path, encoding="utf-8") as f:
cfg.deepspeed = json.load(f)
if cfg.sequence_parallel_size is None:
cfg.sequence_parallel_size = 1
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step

View File

@@ -1107,10 +1107,7 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_sequence_parallel_config(cls, data):
if data.get("sequence_parallel_degree") is None:
data["sequence_parallel_degree"] = 1
if data.get("sequence_parallel_degree") > 1:
if data.get("sequence_parallel_degree", 1) > 1:
if not data.get("flash_attention"):
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"

View File

@@ -280,6 +280,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
"batch_size": 10,
"micro_batch_size": 10,
"num_epochs": 1,
"sequence_parallel_degree": 1,
}
)