add tabs back to code check

This commit is contained in:
Wing Lian
2025-05-03 02:46:50 -04:00
parent 140083a828
commit 99095573c3

View File

@@ -12,19 +12,19 @@ from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
ORIGINAL_PREPARE_CODE = """ ORIGINAL_PREPARE_CODE = """
for param in model.parameters(): for param in model.parameters():
if ( if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16) (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit": ) and param.__class__.__name__ != "Params4bit":
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)
""" """
PATCHED_PREPARE_CODE = """ PATCHED_PREPARE_CODE = """
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if ( if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16) (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit" and "norm" in name: ) and param.__class__.__name__ != "Params4bit" and "norm" in name:
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)
""" """