Update get_unpad_data patching for multipack (#2013)
* Update `get_unpad_data` patching for multipack
* Update src/axolotl/utils/models.py
* Update src/axolotl/utils/models.py
* Add test case

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
"""multipack patching for v2 of sample packing"""
|
"""multipack patching for v2 of sample packing"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
@@ -27,71 +28,28 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
|
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
||||||
if model_type == "gemmoe":
|
if has_remote_code:
|
||||||
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
patch_remote(model_name)
|
||||||
elif model_type == "deepseek_v2":
|
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||||
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
|
|
||||||
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
|
|
||||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
|
||||||
patch_mixtral_moe_forward_zero3()
|
|
||||||
return
|
|
||||||
|
|
||||||
# retain for legacy
|
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||||
if model_type == "mixtral":
|
patch_mixtral_moe_forward_zero3()
|
||||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
if is_deepspeed_zero3_enabled():
|
|
||||||
patch_mixtral_moe_forward_zero3()
|
|
||||||
elif model_type == "llama":
|
|
||||||
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
|
|
||||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "mistral":
|
|
||||||
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
|
|
||||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "qwen2":
|
|
||||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "qwen2_moe":
|
|
||||||
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "falcon":
|
|
||||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "phi":
|
|
||||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "gemma":
|
|
||||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "gemma2":
|
|
||||||
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "starcoder2":
|
|
||||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_remote(model_name, config_name, modeling_name):
|
def patch_remote(model_name):
|
||||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||||
# we need to load the model here in order for modeling_* to be available
|
# we need to load the model here in order for modeling_* to be available
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||||
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
|
parts = model_config.__class__.__module__.split(".")
|
||||||
|
parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
|
||||||
|
module_name = ".".join(parts)
|
||||||
modeling_arch = importlib.import_module(module_name)
|
modeling_arch = importlib.import_module(module_name)
|
||||||
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
|
if hasattr(modeling_arch, "_get_unpad_data"):
|
||||||
|
modeling_arch._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
|||||||
@@ -395,10 +395,17 @@ class ModelLoader:
|
|||||||
and self.cfg.flash_attention
|
and self.cfg.flash_attention
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
|
has_remote_code = (
|
||||||
|
"auto_map" in self.model_config
|
||||||
|
and "AutoModelForCausalLM" in self.model_config["auto_map"]
|
||||||
|
)
|
||||||
|
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||||
|
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
|
||||||
|
has_remote_code = self.cfg.trust_remote_code
|
||||||
patch_for_multipack(
|
patch_for_multipack(
|
||||||
self.cfg.model_config_type,
|
self.cfg.model_config_type,
|
||||||
model_name=self.cfg.base_model,
|
model_name=self.cfg.base_model,
|
||||||
is_remote_code=self.cfg.trust_remote_code,
|
has_remote_code=has_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.is_llama_derived_model:
|
if self.cfg.is_llama_derived_model:
|
||||||
|
|||||||
66
tests/e2e/test_llama.py
Normal file
66
tests/e2e/test_llama.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestLlama(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_fft_trust_remote_code(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"trust_remote_code": True,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"bf16": True,
|
||||||
|
"save_safetensors": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
Reference in New Issue
Block a user