Compare commits
6 Commits
maverick-e
...
zero3-8bit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
afb8218c67 | ||
|
|
1ff78d6347 | ||
|
|
613a217142 | ||
|
|
127953af4e | ||
|
|
920ea77bdf | ||
|
|
ef60e3e851 |
83
src/axolotl/monkeypatch/modeling_zero3_int8_lora.py
Normal file
83
src/axolotl/monkeypatch/modeling_zero3_int8_lora.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
fix for zero3 8-bit lora
|
||||||
|
see https://github.com/huggingface/transformers/pull/32943/files
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from transformers import modeling_utils
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.modeling_zero3_int8_lora")
|
||||||
|
|
||||||
|
ORIGINAL_LOAD_CODE = """
|
||||||
|
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
|
||||||
|
module, tensor_name = get_module_from_name(model, param_name)
|
||||||
|
value = getattr(module, tensor_name)
|
||||||
|
param_to = "cpu"
|
||||||
|
if is_fsdp_enabled() and not is_local_dist_rank_0():
|
||||||
|
param_to = "meta"
|
||||||
|
value = type(value)(value.data.to(param_to), **value.__dict__)
|
||||||
|
setattr(module, tensor_name, value)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_LOAD_CODE = """
|
||||||
|
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
|
||||||
|
module, tensor_name = get_module_from_name(model, param_name)
|
||||||
|
value = getattr(module, tensor_name)
|
||||||
|
param_to = "cpu"
|
||||||
|
if is_fsdp_enabled() and not is_local_dist_rank_0():
|
||||||
|
param_to = "meta"
|
||||||
|
val_kwargs = {}
|
||||||
|
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
|
||||||
|
val_kwargs["requires_grad"] = False
|
||||||
|
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
|
||||||
|
setattr(module, tensor_name, value)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_modeling_state_dict_code() -> str:
|
||||||
|
load_code = inspect.getsource(
|
||||||
|
modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
return load_code
|
||||||
|
|
||||||
|
|
||||||
|
def check_modeling_state_dict_code_is_patchable() -> bool:
|
||||||
|
load_code = get_modeling_state_dict_code()
|
||||||
|
return ORIGINAL_LOAD_CODE in load_code
|
||||||
|
|
||||||
|
|
||||||
|
def patch_modeling_state_dict_code():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the meta model loader for zero3 8-bit lora
|
||||||
|
"""
|
||||||
|
|
||||||
|
load_code = get_modeling_state_dict_code()
|
||||||
|
modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access
|
||||||
|
load_code
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
ORIGINAL_LOAD_CODE in load_code
|
||||||
|
), "Original _load_state_dict_into_meta_model code not found"
|
||||||
|
|
||||||
|
load_code = load_code.replace(ORIGINAL_LOAD_CODE, PATCHED_LOAD_CODE)
|
||||||
|
load_code = load_code.replace(
|
||||||
|
"def _load_state_dict_into_meta_model(",
|
||||||
|
"def _fixed_load_state_dict_into_meta_model(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(modeling_utils):
|
||||||
|
if item in load_code:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.modeling_utils import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(load_code, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching _load_state_dict_into_meta_model")
|
||||||
|
modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821
|
||||||
@@ -437,6 +437,15 @@ def setup_deepspeed_env(cfg, stage=None):
|
|||||||
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
||||||
if stage == 3:
|
if stage == 3:
|
||||||
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
||||||
|
if cfg.adapter and cfg.load_in_8bit:
|
||||||
|
from axolotl.monkeypatch.modeling_zero3_int8_lora import (
|
||||||
|
patch_modeling_state_dict_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
patch_modeling_state_dict_code()
|
||||||
|
except AssertionError:
|
||||||
|
LOG.warning("Failed to patch the meta model loading code")
|
||||||
# If we don't assign this, it doesn't actually get set in the accelerate weakref
|
# If we don't assign this, it doesn't actually get set in the accelerate weakref
|
||||||
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
|
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
|
||||||
|
|
||||||
|
|||||||
@@ -601,3 +601,61 @@ class TestMultiGPULlama:
|
|||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_8bit_lora_ds_zero3(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"sample_packing": True,
|
||||||
|
"eval_sample_packing": False,
|
||||||
|
"pad_to_sequence_len": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.05,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 15,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"deepspeed": "deepspeed_configs/zero3_bf16_cpuoffload_all.json",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"accelerate",
|
||||||
|
"launch",
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"-m",
|
||||||
|
"axolotl.cli.train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user