log warning re: logged losses / gradient scaling per rank
This commit is contained in:
@@ -1171,6 +1171,18 @@ class AxolotlInputConfig(
|
|||||||
"or `pip install ring-flash-attn>=0.1.4`."
|
"or `pip install ring-flash-attn>=0.1.4`."
|
||||||
) from exception
|
) from exception
|
||||||
|
|
||||||
|
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
||||||
|
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
||||||
|
# according to the proportion of non-padding tokens per rank.
|
||||||
|
LOG.warning(
|
||||||
|
"Sequence parallelism (SP) is enabled with "
|
||||||
|
f"sequence_parallel_degree={value}. Please note that logged losses may "
|
||||||
|
"differ slightly to the non-SP losses due to transformers Trainer "
|
||||||
|
"implementation details. Please see "
|
||||||
|
"https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||||
|
"for more details."
|
||||||
|
)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
Reference in New Issue
Block a user