Merge pull request #179 from OpenAccess-AI-Collective/fix-max_seq_len

fix for max sequence len across different model types
This commit is contained in:
Wing Lian
2023-06-09 20:52:03 -04:00
committed by GitHub

View File

@@ -255,8 +255,15 @@ def load_model(
)
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if config.max_seq_len and cfg.sequence_len > config.max_seq_len:
if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
config.max_seq_len = cfg.sequence_len
logging.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(config, "max_sequence_length")
and cfg.sequence_len > config.max_sequence_length
):
config.max_sequence_length = cfg.sequence_len
logging.warning(f"increasing context length to {cfg.sequence_len}")
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,