fix
This commit is contained in:
@@ -125,6 +125,9 @@ def normalize_config(cfg):
|
|||||||
with open(ds_config_path, encoding="utf-8") as f:
|
with open(ds_config_path, encoding="utf-8") as f:
|
||||||
cfg.deepspeed = json.load(f)
|
cfg.deepspeed = json.load(f)
|
||||||
|
|
||||||
|
if cfg.sequence_parallel_size is None:
|
||||||
|
cfg.sequence_parallel_size = 1
|
||||||
|
|
||||||
if cfg.saves_per_epoch:
|
if cfg.saves_per_epoch:
|
||||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||||
if save_steps < 1.0: # prevent saves on every step
|
if save_steps < 1.0: # prevent saves on every step
|
||||||
|
|||||||
@@ -1107,10 +1107,7 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sequence_parallel_config(cls, data):
|
def check_sequence_parallel_config(cls, data):
|
||||||
if data.get("sequence_parallel_degree") is None:
|
if data.get("sequence_parallel_degree", 1) > 1:
|
||||||
data["sequence_parallel_degree"] = 1
|
|
||||||
|
|
||||||
if data.get("sequence_parallel_degree") > 1:
|
|
||||||
if not data.get("flash_attention"):
|
if not data.get("flash_attention"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||||
|
|||||||
@@ -280,6 +280,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
"batch_size": 10,
|
"batch_size": 10,
|
||||||
"micro_batch_size": 10,
|
"micro_batch_size": 10,
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
|
"sequence_parallel_degree": 1,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user