chore: lint

This commit is contained in:
Wing Lian
2025-01-10 02:18:55 -05:00
parent 9db5072407
commit 73b6b0a580

View File

@@ -13,7 +13,6 @@ from functools import wraps
from typing import Any, Dict, Literal, Optional, Union
import torch
import transformers
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
@@ -76,7 +75,7 @@ class SchedulerMixin(Trainer):
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: AxolotlTrainingArguments
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
@@ -162,7 +161,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__(
@@ -202,12 +201,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
and self.args.embedding_lr is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()