style correction
This commit is contained in:
@@ -252,11 +252,16 @@ def load_model(
|
|||||||
)
|
)
|
||||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||||
# when training starts
|
# when training starts
|
||||||
if hasattr(config, "max_seq_len") and config.max_seq_len and cfg.sequence_len > config.max_seq_len:
|
if (
|
||||||
|
hasattr(config, "max_seq_len")
|
||||||
|
and config.max_seq_len
|
||||||
|
and cfg.sequence_len > config.max_seq_len
|
||||||
|
):
|
||||||
config.max_seq_len = cfg.sequence_len
|
config.max_seq_len = cfg.sequence_len
|
||||||
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
||||||
elif (
|
elif (
|
||||||
hasattr(config, "max_sequence_length") and config.max_sequence_length
|
hasattr(config, "max_sequence_length")
|
||||||
|
and config.max_sequence_length
|
||||||
and cfg.sequence_len > config.max_sequence_length
|
and cfg.sequence_len > config.max_sequence_length
|
||||||
):
|
):
|
||||||
config.max_sequence_length = cfg.sequence_len
|
config.max_sequence_length = cfg.sequence_len
|
||||||
@@ -289,7 +294,8 @@ def load_model(
|
|||||||
model.resize_token_embeddings(embeddings_len)
|
model.resize_token_embeddings(embeddings_len)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model.config, "max_position_embeddings") and model.config.max_position_embeddings
|
hasattr(model.config, "max_position_embeddings")
|
||||||
|
and model.config.max_position_embeddings
|
||||||
and cfg.sequence_len >= model.config.max_position_embeddings
|
and cfg.sequence_len >= model.config.max_position_embeddings
|
||||||
):
|
):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
|
|||||||
Reference in New Issue
Block a user