issue #205 bugfix

commit 556fe408b3 (parent 16bb6276a5)
Author: maciej.karasek
Date:   2023-06-14 16:59:57 +02:00


@@ -252,11 +252,11 @@ def load_model(
     )
     # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
     # when training starts
-    if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
+    if hasattr(config, "max_seq_len") and config.max_seq_len and cfg.sequence_len > config.max_seq_len:
         config.max_seq_len = cfg.sequence_len
         logging.warning(f"increasing context length to {cfg.sequence_len}")
     elif (
-        hasattr(config, "max_sequence_length")
+        hasattr(config, "max_sequence_length") and config.max_sequence_length
         and cfg.sequence_len > config.max_sequence_length
     ):
         config.max_sequence_length = cfg.sequence_len
@@ -289,7 +289,7 @@ def load_model(
     model.resize_token_embeddings(embeddings_len)
     if (
-        hasattr(model.config, "max_position_embeddings")
+        hasattr(model.config, "max_position_embeddings") and model.config.max_position_embeddings
         and cfg.sequence_len >= model.config.max_position_embeddings
     ):
         logging.warning(
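In short, both hunks add a truthiness check alongside `hasattr`: some model configs define `max_seq_len`, `max_sequence_length`, or `max_position_embeddings` but leave the value set to None, and comparing an int against None raises a TypeError. A minimal runnable sketch of the pattern, using a hypothetical `maybe_extend_context` helper and SimpleNamespace stand-ins for axolotl's real `config`/`cfg` objects:

import logging
from types import SimpleNamespace

def maybe_extend_context(config, cfg):
    # Guard on truthiness as well as presence: `hasattr` alone lets a
    # None-valued attribute through, and `int > None` raises TypeError.
    if hasattr(config, "max_seq_len") and config.max_seq_len and cfg.sequence_len > config.max_seq_len:
        config.max_seq_len = cfg.sequence_len
        logging.warning(f"increasing context length to {cfg.sequence_len}")
    elif (
        hasattr(config, "max_sequence_length")
        and config.max_sequence_length
        and cfg.sequence_len > config.max_sequence_length
    ):
        config.max_sequence_length = cfg.sequence_len
        logging.warning(f"increasing context length to {cfg.sequence_len}")

# Before this commit, a config exposing max_seq_len=None crashed on the
# comparison; with the extra check the branch is simply skipped.
maybe_extend_context(
    SimpleNamespace(max_seq_len=None),
    SimpleNamespace(sequence_len=4096),
)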