fix axolotl training args dataclass annotation

This commit is contained in:
Wing Lian
2023-07-17 04:57:02 -04:00
parent 73e70e3996
commit ebaec3c406

View File

@@ -5,7 +5,7 @@ import logging
import math
import os
import sys
-from dataclasses import field
+from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
@@ -29,6 +29,7 @@ from axolotl.utils.schedulers import (
LOG = logging.getLogger("axolotl")
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Extend the base TrainingArguments for axolotl helpers
@@ -188,7 +189,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
-    training_args = AxolotlTrainingArguments(
+    training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
if cfg.eval_batch_size is not None