chore: lint
@@ -13,7 +13,6 @@ from functools import wraps
 from typing import Any, Dict, Literal, Optional, Union

 import torch
 import transformers
 from datasets import Dataset
 from peft.optimizers import create_loraplus_optimizer
 from torch import nn
@@ -76,7 +75,7 @@ class SchedulerMixin(Trainer):
     Mixin class for scheduler setup in CausalTrainer.
     """

-    args = None  # type: AxolotlTrainingArguments
+    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]

     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
@@ -162,7 +161,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
     Extend the base Trainer for axolotl helpers
     """

-    args = None  # type: AxolotlTrainingArguments
+    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
     tag_names = ["axolotl"]

     def __init__(
@@ -202,12 +201,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             and self.args.embedding_lr is None
             and self.args.alternate_optimizer
             not in [
-                "optimi_adamw",
-                "ao_adamw_8bit",
-                "ao_adamw_4bit",
-                "ao_adamw_fp8",
-                "adopt_adamw",
-            ]
+                "optimi_adamw",
+                "ao_adamw_8bit",
+                "ao_adamw_4bit",
+                "ao_adamw_fp8",
+                "adopt_adamw",
+            ]
         ):
             return super().create_optimizer()

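The two class-level hunks quote the name inside the PEP 484 type comment and append `# type: ignore[name-defined]`, so the attribute keeps its annotation even when `AxolotlTrainingArguments` is not importable in that module at runtime. A minimal sketch of the same pattern, assuming a hypothetical TYPE_CHECKING-only import path (not the actual axolotl module layout):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Hypothetical import path, used only for illustration.
    from axolotl.core.training_args import AxolotlTrainingArguments


class SchedulerMixinSketch:
    # Quoting the type makes it a forward reference; the trailing
    # `# type: ignore[name-defined]` silences mypy when the name is not
    # resolvable in this module (e.g. the import above is absent).
    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]

    def create_scheduler(self, num_training_steps, optimizer=None):
        # Placeholder body; the real mixin builds an LR scheduler here.
        raise NotImplementedError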
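The last hunk appears to be a formatting-only change to the list of alternate optimizer names; the surrounding logic is an early return to the stock transformers.Trainer optimizer when none of the custom optimizer paths apply. A rough sketch of that guard, with the set of names taken from the diff and the args fields assumed rather than copied from axolotl:

from transformers import Trainer

# Optimizer names copied from the diff; membership in this set routes
# optimizer construction away from the default Trainer implementation.
ALTERNATE_OPTIMIZERS = {
    "optimi_adamw",
    "ao_adamw_8bit",
    "ao_adamw_4bit",
    "ao_adamw_fp8",
    "adopt_adamw",
}


class SketchTrainer(Trainer):
    def create_optimizer(self):
        # `embedding_lr` and `alternate_optimizer` are fields on the training
        # args in the diff's context; their exact definitions are assumed here.
        if (
            self.args.embedding_lr is None
            and self.args.alternate_optimizer not in ALTERNATE_OPTIMIZERS
        ):
            return super().create_optimizer()
        # Otherwise a custom optimizer (e.g. via create_loraplus_optimizer)
        # would be built here.
        raise NotImplementedError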