updates
This commit is contained in:
@@ -227,8 +227,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||||
"""
|
"""
|
||||||
Build the HuggingFace training args/trainer for causal models
|
Build the HuggingFace training args/trainer for causal models and reward modeling
|
||||||
and reward modelling using TRL.
|
using TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from peft import PeftConfig, PeftModel
|
|||||||
from pkg_resources import get_distribution # type: ignore
|
from pkg_resources import get_distribution # type: ignore
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
||||||
@@ -278,7 +279,7 @@ def save_trained_model(
|
|||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
|
|
||||||
def create_model_card(cfg: DictDefault, trainer: Any):
|
def create_model_card(cfg: DictDefault, trainer: Trainer):
|
||||||
"""
|
"""
|
||||||
Create a model card for the trained model if needed.
|
Create a model card for the trained model if needed.
|
||||||
|
|
||||||
@@ -287,6 +288,7 @@ def create_model_card(cfg: DictDefault, trainer: Any):
|
|||||||
trainer: The trainer object with model card creation capabilities
|
trainer: The trainer object with model card creation capabilities
|
||||||
"""
|
"""
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
|
# Guard since create_model_card may fail if dataset_tags is empty list
|
||||||
try:
|
try:
|
||||||
model_card_kwarg = {
|
model_card_kwarg = {
|
||||||
"model_name": cfg.output_dir.lstrip("./")
|
"model_name": cfg.output_dir.lstrip("./")
|
||||||
@@ -294,26 +296,22 @@ def create_model_card(cfg: DictDefault, trainer: Any):
|
|||||||
.decode("utf-8")
|
.decode("utf-8")
|
||||||
}
|
}
|
||||||
if cfg.datasets is not None:
|
if cfg.datasets is not None:
|
||||||
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
|
dataset_tags = [
|
||||||
dataset_tags = [
|
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
]
|
||||||
]
|
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
|
||||||
dataset_tags = [
|
|
||||||
d for d in dataset_tags if not d.startswith("https://")
|
if dataset_tags:
|
||||||
]
|
param_name = (
|
||||||
if dataset_tags:
|
"dataset_name"
|
||||||
# guard as create_model_card may fail if dataset_tags is empty list
|
if (
|
||||||
model_card_kwarg["dataset_name"] = dataset_tags
|
cfg.rl is not None
|
||||||
else:
|
or cfg.reward_model
|
||||||
dataset_tags = [
|
or cfg.process_reward_model
|
||||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
)
|
||||||
]
|
else "dataset_tags"
|
||||||
dataset_tags = [
|
)
|
||||||
d for d in dataset_tags if not d.startswith("https://")
|
model_card_kwarg[param_name] = dataset_tags
|
||||||
]
|
|
||||||
if dataset_tags:
|
|
||||||
# guard as create_model_card may fail if dataset_tags is empty list
|
|
||||||
model_card_kwarg["dataset_tags"] = dataset_tags
|
|
||||||
|
|
||||||
trainer.create_model_card(**model_card_kwarg)
|
trainer.create_model_card(**model_card_kwarg)
|
||||||
except (AttributeError, UnicodeDecodeError):
|
except (AttributeError, UnicodeDecodeError):
|
||||||
|
|||||||
Reference in New Issue
Block a user