Configurable embeddings upcast (#2621)
* fsdp embeddings should be float32 per comment * patch peft to not upcast everything * add tabs back to code check * fix import * add configurable option and fix check * add check for dtypes * move embeddings test to patch dir * fix test * fix comment and logic
This commit is contained in:
@@ -32,6 +32,8 @@ tokenizer_legacy:
|
|||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||||
shrink_embeddings:
|
shrink_embeddings:
|
||||||
|
# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs
|
||||||
|
embeddings_skip_upcast:
|
||||||
# Whether to load the model with randomly initialized weights. Useful for
|
# Whether to load the model with randomly initialized weights. Useful for
|
||||||
# pre-training a model from scratch or debugging purposes.
|
# pre-training a model from scratch or debugging purposes.
|
||||||
random_init_weights:
|
random_init_weights:
|
||||||
|
|||||||
0
src/axolotl/monkeypatch/peft/__init__.py
Normal file
0
src/axolotl/monkeypatch/peft/__init__.py
Normal file
78
src/axolotl/monkeypatch/peft/utils.py
Normal file
78
src/axolotl/monkeypatch/peft/utils.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""
|
||||||
|
Patch prepare_model_for_kbit_training to not upcast everything
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import peft
|
||||||
|
|
||||||
|
import axolotl
|
||||||
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ORIGINAL_PREPARE_CODE = """
|
||||||
|
for param in model.parameters():
|
||||||
|
if (
|
||||||
|
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
||||||
|
) and param.__class__.__name__ != "Params4bit":
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_PREPARE_CODE = """
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if (
|
||||||
|
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
||||||
|
) and param.__class__.__name__ != "Params4bit" and all(embed_name not in name for embed_name in ["embed_tokens", "lm_head"]):
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_peft_prep_code() -> str:
|
||||||
|
prepare = inspect.getsource(peft.utils.other.prepare_model_for_kbit_training)
|
||||||
|
return prepare
|
||||||
|
|
||||||
|
|
||||||
|
def check_peft_prep_code_is_patchable() -> bool:
|
||||||
|
prep_code = get_peft_prep_code()
|
||||||
|
prep_code, _ = detab_code(prep_code)
|
||||||
|
return ORIGINAL_PREPARE_CODE in prep_code
|
||||||
|
|
||||||
|
|
||||||
|
def patch_peft_prep_code():
|
||||||
|
"""
|
||||||
|
monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
prep_code = get_peft_prep_code()
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
peft.utils.other._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access
|
||||||
|
prep_code
|
||||||
|
)
|
||||||
|
prep_code, _ = detab_code(prep_code)
|
||||||
|
if ORIGINAL_PREPARE_CODE not in prep_code:
|
||||||
|
return
|
||||||
|
|
||||||
|
prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)
|
||||||
|
prep_code = prep_code.replace(
|
||||||
|
"def prepare_model_for_kbit_training(",
|
||||||
|
"def fixed_prepare_model_for_kbit_training(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(peft.utils.other):
|
||||||
|
if item in prep_code:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from peft.utils.other import (" + ", ".join(x for x in items_to_import) + ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(prep_code, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching prepare_model_for_kbit_training to allow for overrides")
|
||||||
|
peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
|
||||||
@@ -566,6 +566,11 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_accelerate_fsdp_utils()
|
patch_accelerate_fsdp_utils()
|
||||||
|
|
||||||
|
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
|
||||||
|
from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
|
||||||
|
|
||||||
|
patch_peft_prep_code()
|
||||||
|
|
||||||
if self.cfg.flex_attention:
|
if self.cfg.flex_attention:
|
||||||
from axolotl.monkeypatch.attention.flex_attn import (
|
from axolotl.monkeypatch.attention.flex_attn import (
|
||||||
patch_flex_make_mask,
|
patch_flex_make_mask,
|
||||||
@@ -1185,7 +1190,7 @@ class ModelLoader:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_model(self, qlora_fsdp) -> None:
|
def prepare_model(self, qlora_fsdp: bool) -> None:
|
||||||
skip_prepare_model_for_kbit_training = False
|
skip_prepare_model_for_kbit_training = False
|
||||||
if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
|
if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
|
||||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||||
@@ -1315,7 +1320,10 @@ class ModelLoader:
|
|||||||
# make sure these are fp32 per Ramesh et al. (2021)
|
# make sure these are fp32 per Ramesh et al. (2021)
|
||||||
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
|
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
|
||||||
if not self.cfg.fsdp:
|
if not self.cfg.fsdp:
|
||||||
# FSDP doesn't like mixed Float and BFloat16
|
# we don't run this during FSDP because this will leave mixed
|
||||||
|
# float and bfloat16 dtypes in the model which FSDP doesn't like
|
||||||
|
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
|
||||||
|
embedding_modules = []
|
||||||
self.convert_embedding_modules_dtype(
|
self.convert_embedding_modules_dtype(
|
||||||
embedding_modules,
|
embedding_modules,
|
||||||
dist_dtype=torch.float32,
|
dist_dtype=torch.float32,
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ class AxolotlInputConfig(
|
|||||||
mean_resizing_embeddings: bool | None = False
|
mean_resizing_embeddings: bool | None = False
|
||||||
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||||
shrink_embeddings: bool | None = None
|
shrink_embeddings: bool | None = None
|
||||||
|
embeddings_skip_upcast: bool | None = None
|
||||||
|
|
||||||
rl: RLType | None = None
|
rl: RLType | None = None
|
||||||
trl: TRLConfig | None = Field(
|
trl: TRLConfig | None = Field(
|
||||||
|
|||||||
63
tests/e2e/patched/test_peft_embeddings.py
Normal file
63
tests/e2e/patched/test_peft_embeddings.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
Test case for handling embeddings when using peft
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.train import setup_model_and_tokenizer
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class TestLlamaPeftEmbeddings:
|
||||||
|
"""
|
||||||
|
test class for handling embeddings when using peft
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_peft_embeddings_upcast(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"adapter": "qlora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"trust_remote_code": True,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 2,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sample_packing": False,
|
||||||
|
"bf16": "auto",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"embeddings_skip_upcast": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
|
||||||
|
model, _, _, _ = setup_model_and_tokenizer(cfg)
|
||||||
|
|
||||||
|
# Check if the embeddings are upcast correctly
|
||||||
|
# only embed_tokens is a parameter that may be upcast
|
||||||
|
assert model.base_model.model.model.embed_tokens.weight.dtype == torch.bfloat16
|
||||||
|
assert model.base_model.model.lm_head.weight.dtype == torch.bfloat16
|
||||||
Reference in New Issue
Block a user