chore: lint
This commit is contained in:
@@ -13,7 +13,6 @@ from functools import wraps
|
|||||||
from typing import Any, Dict, Literal, Optional, Union
|
from typing import Any, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -76,7 +75,7 @@ class SchedulerMixin(Trainer):
|
|||||||
Mixin class for scheduler setup in CausalTrainer.
|
Mixin class for scheduler setup in CausalTrainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
args = None # type: AxolotlTrainingArguments
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
@@ -162,7 +161,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
Extend the base Trainer for axolotl helpers
|
Extend the base Trainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
args = None # type: AxolotlTrainingArguments
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
tag_names = ["axolotl"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -202,12 +201,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
and self.args.embedding_lr is None
|
and self.args.embedding_lr is None
|
||||||
and self.args.alternate_optimizer
|
and self.args.alternate_optimizer
|
||||||
not in [
|
not in [
|
||||||
"optimi_adamw",
|
"optimi_adamw",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
"adopt_adamw",
|
"adopt_adamw",
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user