log warning re: logged losses / gradient scaling per rank

Dan Saunders
2025-04-07 18:46:58 +00:00
parent c64c881460
commit 954b989e88

@@ -1171,6 +1171,18 @@ class AxolotlInputConfig(
                 "or `pip install ring-flash-attn>=0.1.4`."
             ) from exception
 
+        # TODO: monkeypatch / callback to average losses correctly across SP
+        # ranks / fix gradient scaling across SP ranks. Losses and grads should
+        # be scaled according to the proportion of non-padding tokens per rank.
+        LOG.warning(
+            "Sequence parallelism (SP) is enabled with "
+            f"sequence_parallel_degree={value}. Note that logged losses may "
+            "differ slightly from non-SP losses due to transformers Trainer "
+            "implementation details. See "
+            "https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
+            "for more details."
+        )
+
         return value
 
     @model_validator(mode="before")