fix axolotl training args dataclass annotation

This commit is contained in:
Wing Lian
2023-07-17 04:57:02 -04:00
parent 73e70e3996
commit ebaec3c406

View File

@@ -5,7 +5,7 @@ import logging
import math import math
import os import os
import sys import sys
from dataclasses import field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -29,6 +29,7 @@ from axolotl.utils.schedulers import (
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@dataclass
class AxolotlTrainingArguments(TrainingArguments): class AxolotlTrainingArguments(TrainingArguments):
""" """
Extend the base TrainingArguments for axolotl helpers Extend the base TrainingArguments for axolotl helpers
@@ -188,7 +189,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.save_safetensors: if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
training_args = AxolotlTrainingArguments( training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
per_device_train_batch_size=cfg.micro_batch_size, per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size per_device_eval_batch_size=cfg.eval_batch_size
if cfg.eval_batch_size is not None if cfg.eval_batch_size is not None