FEAT: add tagging support to axolotl (#1004)

* add tagging support to axolotl

* chore: lint

* fix method w self

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Younes Belkada
2023-12-27 23:25:20 +01:00
committed by GitHub
parent 6ef46f8dca
commit db9094df0f

View File

@@ -9,7 +9,7 @@ import math
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import partial from functools import partial, wraps
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -120,6 +120,7 @@ class AxolotlTrainer(Trainer):
""" """
args = None # type: AxolotlTrainingArguments args = None # type: AxolotlTrainingArguments
tag_names = ["axolotl"]
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs): def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
self.num_epochs = num_epochs self.num_epochs = num_epochs
@@ -290,12 +291,41 @@ class AxolotlTrainer(Trainer):
# return (loss, outputs) if return_outputs else loss # return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs) return super().compute_loss(model, inputs, return_outputs=return_outputs)
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = self._sanitize_kwargs_for_tagging(
tag_names=self.tag_names, kwargs=kwargs
)
return super().push_to_hub(*args, **kwargs)
class AxolotlMambaTrainer(AxolotlTrainer): class AxolotlMambaTrainer(AxolotlTrainer):
""" """
Mamba specific trainer to handle loss calculation Mamba specific trainer to handle loss calculation
""" """
tag_names = ["axolotl", "mamba"]
def compute_loss( def compute_loss(
self, self,
model, model,
@@ -322,6 +352,8 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
Trainer subclass that uses the OneCycleLR scheduler Trainer subclass that uses the OneCycleLR scheduler
""" """
tag_names = ["axolotl", "onecycle"]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.lr_scheduler = None self.lr_scheduler = None
@@ -351,6 +383,8 @@ class ReLoRATrainer(AxolotlTrainer):
Trainer subclass that uses the OneCycleLR scheduler Trainer subclass that uses the OneCycleLR scheduler
""" """
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.lr_scheduler = None self.lr_scheduler = None