don't shrink embeddings unless told to
This commit is contained in:
@@ -667,6 +667,8 @@ class AxolotlInputConfig(
|
||||
auto_resume_from_checkpoints: Optional[bool] = None
|
||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||
mean_resizing_embeddings: Optional[bool] = False
|
||||
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||
shrink_embeddings: Optional[bool] = None
|
||||
|
||||
rl: Optional[RLType] = None
|
||||
reward_model: Optional[bool] = None
|
||||
|
||||
@@ -1053,9 +1053,12 @@ class ModelLoader:
|
||||
if self.cfg.resize_token_embeddings_to_32x
|
||||
else len(self.tokenizer)
|
||||
)
|
||||
if (
|
||||
hasattr(self.model, "get_input_embeddings")
|
||||
and self.model.get_input_embeddings().num_embeddings != embeddings_len
|
||||
if hasattr(self.model, "get_input_embeddings") and (
|
||||
self.model.get_input_embeddings().num_embeddings < embeddings_len
|
||||
or (
|
||||
self.model.get_input_embeddings().num_embeddings > embeddings_len
|
||||
and self.cfg.shrink_embeddings
|
||||
)
|
||||
):
|
||||
resize_kwargs = {}
|
||||
if self.cfg.mean_resizing_embeddings is not None:
|
||||
|
||||
Reference in New Issue
Block a user