fix
This commit is contained in:
@@ -125,6 +125,9 @@ def normalize_config(cfg):
|
||||
with open(ds_config_path, encoding="utf-8") as f:
|
||||
cfg.deepspeed = json.load(f)
|
||||
|
||||
+ if cfg.sequence_parallel_size is None:
|
||||
+ cfg.sequence_parallel_size = 1
|
||||
|
||||
if cfg.saves_per_epoch:
|
||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||
if save_steps < 1.0: # prevent saves on every step
|
||||
|
||||
@@ -1107,10 +1107,7 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sequence_parallel_config(cls, data):
|
||||
- if data.get("sequence_parallel_degree") is None:
|
||||
- data["sequence_parallel_degree"] = 1
|
||||
|
||||
- if data.get("sequence_parallel_degree") > 1:
|
||||
+ if data.get("sequence_parallel_degree", 1) > 1:
|
||||
if not data.get("flash_attention"):
|
||||
raise ValueError(
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
|
||||
@@ -280,6 +280,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
"batch_size": 10,
|
||||
"micro_batch_size": 10,
|
||||
"num_epochs": 1,
|
||||
+ "sequence_parallel_degree": 1,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user