don't shrink embeddings unless told to
This commit is contained in:
@@ -667,6 +667,8 @@ class AxolotlInputConfig(
|
|||||||
auto_resume_from_checkpoints: Optional[bool] = None
|
auto_resume_from_checkpoints: Optional[bool] = None
|
||||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||||
mean_resizing_embeddings: Optional[bool] = False
|
mean_resizing_embeddings: Optional[bool] = False
|
||||||
|
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||||
|
shrink_embeddings: Optional[bool] = None
|
||||||
|
|
||||||
rl: Optional[RLType] = None
|
rl: Optional[RLType] = None
|
||||||
reward_model: Optional[bool] = None
|
reward_model: Optional[bool] = None
|
||||||
|
|||||||
@@ -1053,9 +1053,12 @@ class ModelLoader:
|
|||||||
if self.cfg.resize_token_embeddings_to_32x
|
if self.cfg.resize_token_embeddings_to_32x
|
||||||
else len(self.tokenizer)
|
else len(self.tokenizer)
|
||||||
)
|
)
|
||||||
if (
|
if hasattr(self.model, "get_input_embeddings") and (
|
||||||
hasattr(self.model, "get_input_embeddings")
|
self.model.get_input_embeddings().num_embeddings < embeddings_len
|
||||||
and self.model.get_input_embeddings().num_embeddings != embeddings_len
|
or (
|
||||||
|
self.model.get_input_embeddings().num_embeddings > embeddings_len
|
||||||
|
and self.cfg.shrink_embeddings
|
||||||
|
)
|
||||||
):
|
):
|
||||||
resize_kwargs = {}
|
resize_kwargs = {}
|
||||||
if self.cfg.mean_resizing_embeddings is not None:
|
if self.cfg.mean_resizing_embeddings is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user