allow overriding of model_config parameters from the YML (#853)
* allow overriding of model_config parameters from the YML * remove old logging, update readme * move the updating of model config to the load_model_config function * add warning for deprecated rope_scaling in the root of the YML config
This commit is contained in:
12
README.md
12
README.md
@@ -489,6 +489,14 @@ is_llama_derived_model:
|
||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
is_mistral_derived_model:
|
||||
|
||||
# optional overrides to the base model configuration
|
||||
model_config:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
gptq_groupsize: 128 # group size
|
||||
@@ -756,10 +764,6 @@ landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# LLaMA only
|
||||
xpos_rope:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# Resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
|
||||
@@ -369,6 +369,9 @@ def validate_config(cfg):
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
)
|
||||
|
||||
if cfg.rope_scaling:
|
||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
|
||||
@@ -17,7 +17,6 @@ from transformers import ( # noqa: F401
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
GPTQConfig,
|
||||
LlamaConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
@@ -32,9 +31,14 @@ LOG = logging.getLogger("axolotl")
|
||||
def load_model_config(cfg):
|
||||
model_config_name = cfg.base_model_config or cfg.base_model
|
||||
trust_remote_code = cfg.trust_remote_code is True
|
||||
return AutoConfig.from_pretrained(
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_config_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
if cfg.model_config:
|
||||
for key, val in cfg.model_config.items():
|
||||
setattr(model_config, key, val)
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
def load_tokenizer(cfg):
|
||||
@@ -51,7 +55,7 @@ def load_tokenizer(cfg):
|
||||
if cfg.tokenizer_type:
|
||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
@@ -110,7 +114,6 @@ def load_model(
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
base_model = cfg.base_model
|
||||
base_model_config = cfg.base_model_config
|
||||
model_type = cfg.model_type
|
||||
model_config = load_model_config(cfg)
|
||||
|
||||
@@ -238,16 +241,9 @@ def load_model(
|
||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
config_kwargs = {}
|
||||
if cfg.rope_scaling:
|
||||
config_kwargs["rope_scaling"] = cfg.rope_scaling
|
||||
config = LlamaConfig.from_pretrained(
|
||||
base_model_config,
|
||||
**config_kwargs,
|
||||
)
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
**model_kwargs,
|
||||
@@ -305,66 +301,55 @@ def load_model(
|
||||
if cfg.gptq:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = getattr(transformers, model_type).from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(
|
||||
base_model,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
)
|
||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||
# when training starts
|
||||
if (
|
||||
hasattr(config, "max_seq_len")
|
||||
and config.max_seq_len
|
||||
and cfg.sequence_len > config.max_seq_len
|
||||
hasattr(model_config, "max_seq_len")
|
||||
and model_config.max_seq_len
|
||||
and cfg.sequence_len > model_config.max_seq_len
|
||||
):
|
||||
config.max_seq_len = cfg.sequence_len
|
||||
model_config.max_seq_len = cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
elif (
|
||||
hasattr(config, "max_sequence_length")
|
||||
and config.max_sequence_length
|
||||
and cfg.sequence_len > config.max_sequence_length
|
||||
hasattr(model_config, "max_sequence_length")
|
||||
and model_config.max_sequence_length
|
||||
and cfg.sequence_len > model_config.max_sequence_length
|
||||
):
|
||||
config.max_sequence_length = cfg.sequence_len
|
||||
model_config.max_sequence_length = cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
if cfg.gptq:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
config=model_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
except Exception as err: # pylint: disable=broad-exception-caught
|
||||
LOG.error(
|
||||
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
||||
)
|
||||
LOG.exception(err)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
raise err
|
||||
|
||||
embeddings_len = (
|
||||
math.ceil(len(tokenizer) / 32) * 32
|
||||
|
||||
Reference in New Issue
Block a user