diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0dcb4cd77..253bdcbd8 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -333,7 +333,7 @@ def load_model(
         model, use_gradient_checkpointing=cfg.gradient_checkpointing
     )
 
-    # LlamaRMSNorm layers are in fp32 after kit call, so we need to
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
     if cfg.flash_attention and cfg.is_llama_derived_model:
         for name, module in model.named_modules():
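
For context on the interaction this comment fix describes: `prepare_model_for_kbit_training` (from `peft`) upcasts norm layers such as `LlamaRMSNorm` to fp32 for training stability, while flash-attn kernels require fp16/bf16 activations, hence the cast back. Below is a minimal standalone sketch of that pattern; the loop body is truncated in the hunk above, so the `"norm"` name filter and the model id are illustrative assumptions, not axolotl's verbatim code.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

# Load a Llama-family model in 4-bit; the model id is a placeholder.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)

# prepare_model_for_kbit_training upcasts norm layers (e.g. LlamaRMSNorm)
# to fp32 for numerical stability during training.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# flash-attn expects fp16/bf16, so cast the norm layers back down.
# The "norm" substring match is an assumed filter, standing in for the
# truncated loop body in the hunk above.
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch.bfloat16)
```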