add hub_revision support for specifying branch when pushing checkpoints (#3387) [skip ci]
This commit is contained in:
@@ -409,6 +409,9 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.hub_strategy:
|
if self.cfg.hub_strategy:
|
||||||
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||||
|
|
||||||
|
if self.cfg.hub_revision:
|
||||||
|
training_args_kwargs["hub_revision"] = self.cfg.hub_revision
|
||||||
|
|
||||||
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
|
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
|
||||||
# save_strategy and save_steps
|
# save_strategy and save_steps
|
||||||
if self.cfg.save_steps:
|
if self.cfg.save_steps:
|
||||||
|
|||||||
@@ -120,6 +120,12 @@ class ModelOutputConfig(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "how to push checkpoints to hub"},
|
json_schema_extra={"description": "how to push checkpoints to hub"},
|
||||||
)
|
)
|
||||||
|
hub_revision: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "branch/revision to push to on hub (default: main)"
|
||||||
|
},
|
||||||
|
)
|
||||||
save_safetensors: bool | None = Field(
|
save_safetensors: bool | None = Field(
|
||||||
default=True,
|
default=True,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
Reference in New Issue
Block a user