Merge pull request #206 from MaciejKarasek/issue205

issue #205 bugfix
This commit is contained in:
Wing Lian
2023-06-14 14:23:38 -04:00
committed by GitHub

View File

@@ -252,11 +252,16 @@ def load_model(
) )
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts # when training starts
if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len: if (
hasattr(config, "max_seq_len")
and config.max_seq_len
and cfg.sequence_len > config.max_seq_len
):
config.max_seq_len = cfg.sequence_len config.max_seq_len = cfg.sequence_len
logging.warning(f"increasing context length to {cfg.sequence_len}") logging.warning(f"increasing context length to {cfg.sequence_len}")
elif ( elif (
hasattr(config, "max_sequence_length") hasattr(config, "max_sequence_length")
and config.max_sequence_length
and cfg.sequence_len > config.max_sequence_length and cfg.sequence_len > config.max_sequence_length
): ):
config.max_sequence_length = cfg.sequence_len config.max_sequence_length = cfg.sequence_len
@@ -290,6 +295,7 @@ def load_model(
if ( if (
hasattr(model.config, "max_position_embeddings") hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings and cfg.sequence_len >= model.config.max_position_embeddings
): ):
logging.warning( logging.warning(