diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 42ece6023..4083fcc22 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1171,6 +1171,18 @@ class AxolotlInputConfig( "or `pip install ring-flash-attn>=0.1.4`." ) from exception + # TODO: monkeypatch / callback to average losses correctly across SP ranks + # / fix gradient scaling across SP ranks. Losses, grads should be scaled + # according to the proportion of non-padding tokens per rank. + LOG.warning( + "Sequence parallelism (SP) is enabled with " + f"sequence_parallel_degree={value}. Please note that logged losses may " + "differ slightly to the non-SP losses due to transformers Trainer " + "implementation details. Please see " + "https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " + "for more details." + ) + return value @model_validator(mode="before")