log warning re: logged losses / gradient scaling per rank
This commit is contained in:
@@ -1171,6 +1171,18 @@ class AxolotlInputConfig(
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
|
||||
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
||||
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
||||
# according to the proportion of non-padding tokens per rank.
|
||||
LOG.warning(
|
||||
"Sequence parallelism (SP) is enabled with "
|
||||
f"sequence_parallel_degree={value}. Please note that logged losses may "
|
||||
"differ slightly to the non-SP losses due to transformers Trainer "
|
||||
"implementation details. Please see "
|
||||
"https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||
"for more details."
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
Reference in New Issue
Block a user