fix axolotl training args dataclass annotation
This commit is contained in:
@@ -5,7 +5,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -29,6 +29,7 @@ from axolotl.utils.schedulers import (
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class AxolotlTrainingArguments(TrainingArguments):
|
class AxolotlTrainingArguments(TrainingArguments):
|
||||||
"""
|
"""
|
||||||
Extend the base TrainingArguments for axolotl helpers
|
Extend the base TrainingArguments for axolotl helpers
|
||||||
@@ -188,7 +189,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
if cfg.save_safetensors:
|
if cfg.save_safetensors:
|
||||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments(
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
per_device_train_batch_size=cfg.micro_batch_size,
|
per_device_train_batch_size=cfg.micro_batch_size,
|
||||||
per_device_eval_batch_size=cfg.eval_batch_size
|
per_device_eval_batch_size=cfg.eval_batch_size
|
||||||
if cfg.eval_batch_size is not None
|
if cfg.eval_batch_size is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user