diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 103c707f2..c6d380267 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -252,11 +252,16 @@ def load_model( ) # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this # when training starts - if hasattr(config, "max_seq_len") and config.max_seq_len and cfg.sequence_len > config.max_seq_len: + if ( + hasattr(config, "max_seq_len") + and config.max_seq_len + and cfg.sequence_len > config.max_seq_len + ): config.max_seq_len = cfg.sequence_len logging.warning(f"increasing context length to {cfg.sequence_len}") elif ( - hasattr(config, "max_sequence_length") and config.max_sequence_length + hasattr(config, "max_sequence_length") + and config.max_sequence_length and cfg.sequence_len > config.max_sequence_length ): config.max_sequence_length = cfg.sequence_len @@ -289,7 +294,8 @@ def load_model( model.resize_token_embeddings(embeddings_len) if ( - hasattr(model.config, "max_position_embeddings") and model.config.max_position_embeddings + hasattr(model.config, "max_position_embeddings") + and model.config.max_position_embeddings and cfg.sequence_len >= model.config.max_position_embeddings ): logging.warning(