add option for resizing embeddings when adding new tokens (#2000)
* add option for resizing embeddings when adding new tokens * let's just be opinionated about this setting and set it to False
This commit is contained in:
@@ -549,6 +549,7 @@ class AxolotlInputConfig(
|
|||||||
resume_from_checkpoint: Optional[str] = None
|
resume_from_checkpoint: Optional[str] = None
|
||||||
auto_resume_from_checkpoints: Optional[bool] = None
|
auto_resume_from_checkpoints: Optional[bool] = None
|
||||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||||
|
mean_resizing_embeddings: Optional[bool] = False
|
||||||
|
|
||||||
rl: Optional[RLType] = None
|
rl: Optional[RLType] = None
|
||||||
reward_model: Optional[bool] = None
|
reward_model: Optional[bool] = None
|
||||||
|
|||||||
@@ -1042,7 +1042,10 @@ class ModelLoader:
|
|||||||
hasattr(self.model, "get_input_embeddings")
|
hasattr(self.model, "get_input_embeddings")
|
||||||
and self.model.get_input_embeddings().num_embeddings < embeddings_len
|
and self.model.get_input_embeddings().num_embeddings < embeddings_len
|
||||||
):
|
):
|
||||||
self.model.resize_token_embeddings(embeddings_len)
|
resize_kwargs = {}
|
||||||
|
if self.cfg.mean_resizing_embeddings is not None:
|
||||||
|
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
||||||
|
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
||||||
else:
|
else:
|
||||||
self.model.tie_weights()
|
self.model.tie_weights()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user