allow overriding of model_config parameters from the YML (#853)

* allow overriding of model_config parameters from the YML

* remove old logging, update readme

* move the updating of model config to the load_model_config function

* add warning for deprecated rope_scaling in the root of the YML config
This commit is contained in:
Wing Lian
2023-11-15 23:47:08 -05:00
committed by GitHub
parent b3a61e8ce2
commit 1bc11868eb
3 changed files with 32 additions and 40 deletions

View File

@@ -489,6 +489,14 @@ is_llama_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default # Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model: is_mistral_derived_model:
# optional overrides to the base model configuration
model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# Whether you are training a 4-bit GPTQ quantized model # Whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
gptq_groupsize: 128 # group size gptq_groupsize: 128 # group size
@@ -756,10 +764,6 @@ landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# LLaMA only # LLaMA only
xpos_rope: xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# Resume from a specific checkpoint dir # Resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:

View File

@@ -369,6 +369,9 @@ def validate_config(cfg):
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit." "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
) )
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be a key under `model_config`")
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -17,7 +17,6 @@ from transformers import ( # noqa: F401
AutoTokenizer, AutoTokenizer,
BitsAndBytesConfig, BitsAndBytesConfig,
GPTQConfig, GPTQConfig,
LlamaConfig,
PreTrainedModel, PreTrainedModel,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
@@ -32,9 +31,14 @@ LOG = logging.getLogger("axolotl")
def load_model_config(cfg): def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True trust_remote_code = cfg.trust_remote_code is True
return AutoConfig.from_pretrained( model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code model_config_name, trust_remote_code=trust_remote_code
) )
if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)
return model_config
def load_tokenizer(cfg): def load_tokenizer(cfg):
@@ -51,7 +55,7 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type: if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type) tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
tokenizer = tokenizer_cls.from_pretrained( tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config, tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
@@ -110,7 +114,6 @@ def load_model(
Load a model for a given configuration and tokenizer. Load a model for a given configuration and tokenizer.
""" """
base_model = cfg.base_model base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type model_type = cfg.model_type
model_config = load_model_config(cfg) model_config = load_model_config(cfg)
@@ -238,16 +241,9 @@ def load_model(
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
config_kwargs = {}
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config,
**config_kwargs,
)
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs, **model_kwargs,
@@ -305,66 +301,55 @@ def load_model(
if cfg.gptq: if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
else: else:
model = getattr(transformers, model_type).from_pretrained( model = getattr(transformers, model_type).from_pretrained(
base_model, base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
else: else:
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=cfg.trust_remote_code or False,
)
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts # when training starts
if ( if (
hasattr(config, "max_seq_len") hasattr(model_config, "max_seq_len")
and config.max_seq_len and model_config.max_seq_len
and cfg.sequence_len > config.max_seq_len and cfg.sequence_len > model_config.max_seq_len
): ):
config.max_seq_len = cfg.sequence_len model_config.max_seq_len = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}") LOG.warning(f"increasing context length to {cfg.sequence_len}")
elif ( elif (
hasattr(config, "max_sequence_length") hasattr(model_config, "max_sequence_length")
and config.max_sequence_length and model_config.max_sequence_length
and cfg.sequence_len > config.max_sequence_length and cfg.sequence_len > model_config.max_sequence_length
): ):
config.max_sequence_length = cfg.sequence_len model_config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}") LOG.warning(f"increasing context length to {cfg.sequence_len}")
if cfg.gptq: if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=model_config,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
else: else:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
except Exception as err: # pylint: disable=broad-exception-caught except Exception as err: # pylint: disable=broad-exception-caught
LOG.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
)
LOG.exception(err) LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained( raise err
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
embeddings_len = ( embeddings_len = (
math.ceil(len(tokenizer) / 32) * 32 math.ceil(len(tokenizer) / 32) * 32