add tabs back to code check
This commit is contained in:
@@ -12,19 +12,19 @@ from axolotl.monkeypatch.utils import detab_code
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
ORIGINAL_PREPARE_CODE = """
|
ORIGINAL_PREPARE_CODE = """
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
if (
|
if (
|
||||||
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
||||||
) and param.__class__.__name__ != "Params4bit":
|
) and param.__class__.__name__ != "Params4bit":
|
||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PATCHED_PREPARE_CODE = """
|
PATCHED_PREPARE_CODE = """
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if (
|
if (
|
||||||
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
||||||
) and param.__class__.__name__ != "Params4bit" and "norm" in name:
|
) and param.__class__.__name__ != "Params4bit" and "norm" in name:
|
||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user