From ebaec3c406c7a0670ff1ee911e0abc70972b3f67 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 17 Jul 2023 04:57:02 -0400
Subject: [PATCH] fix axolotl training args dataclass annotation

---
 src/axolotl/utils/trainer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index bdd760526..2144e6b02 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -5,7 +5,7 @@ import logging
 import math
 import os
 import sys
-from dataclasses import field
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Optional
 
@@ -29,6 +29,7 @@ from axolotl.utils.schedulers import (
 
 LOG = logging.getLogger("axolotl")
 
+@dataclass
 class AxolotlTrainingArguments(TrainingArguments):
     """
     Extend the base TrainingArguments for axolotl helpers
@@ -188,7 +189,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.save_safetensors:
         training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
 
-    training_args = AxolotlTrainingArguments(
+    training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
         if cfg.eval_batch_size is not None