diff --git a/src/axolotl/monkeypatch/peft/utils.py b/src/axolotl/monkeypatch/peft/utils.py index fed88a0ed..d777128b1 100644 --- a/src/axolotl/monkeypatch/peft/utils.py +++ b/src/axolotl/monkeypatch/peft/utils.py @@ -12,19 +12,19 @@ from axolotl.monkeypatch.utils import detab_code LOG = logging.getLogger(__name__) ORIGINAL_PREPARE_CODE = """ - for param in model.parameters(): - if ( - (param.dtype == torch.float16) or (param.dtype == torch.bfloat16) - ) and param.__class__.__name__ != "Params4bit": - param.data = param.data.to(torch.float32) + for param in model.parameters(): + if ( + (param.dtype == torch.float16) or (param.dtype == torch.bfloat16) + ) and param.__class__.__name__ != "Params4bit": + param.data = param.data.to(torch.float32) """ PATCHED_PREPARE_CODE = """ - for name, param in model.named_parameters(): - if ( - (param.dtype == torch.float16) or (param.dtype == torch.bfloat16) - ) and param.__class__.__name__ != "Params4bit" and "norm" in name: - param.data = param.data.to(torch.float32) + for name, param in model.named_parameters(): + if ( + (param.dtype == torch.float16) or (param.dtype == torch.bfloat16) + ) and param.__class__.__name__ != "Params4bit" and "norm" in name: + param.data = param.data.to(torch.float32) """