fix: 'NoneType' object has no attribute 'column_names' (#2822) [skip ci]

* fix: 'NoneType' object has no attribute 'column_names'

* chore: typing
Author: NanoCode012
Date: 2025-06-25 20:49:55 +07:00
Committed by: GitHub
Parent: a27c4f8771
Commit: 20106116da


@@ -20,7 +20,7 @@ from torch.utils.data import (
     SequentialSampler,
 )
 from transformers import Trainer
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
 from trl.trainer.utils import pad_to_length
 from typing_extensions import override
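
For context, `has_length` is the upstream transformers helper imported here; it treats any dataset whose `__len__` is missing or raises (e.g. a streaming `IterableDataset`) as unsized. A rough sketch of its semantics, not the verbatim transformers source:

    def has_length(dataset) -> bool:
        # True only when len(dataset) succeeds; iterable/streaming datasets
        # without __len__ raise TypeError and are reported as unsized.
        try:
            return len(dataset) is not None
        except TypeError:
            return False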
@@ -122,8 +122,8 @@ class AxolotlTrainer(
         return sampler

     def _get_train_sampler(
-        self, train_dataset: Optional[Dataset] = None
-    ) -> Optional[Sampler]:
+        self, train_dataset: Dataset | None = None
+    ) -> Sampler | None:
         """
         Helper method to get the sampler for training. Handles cases for sample packing
         and curriculum sampling (sequential).
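
The `chore: typing` half of the commit swaps `typing.Optional[X]` for the PEP 604 spelling `X | None`. The two annotations are equivalent; the union syntax just requires Python 3.10+ (or `from __future__ import annotations` on older interpreters). A minimal illustration:

    from typing import Optional

    def before(x: Optional[int] = None) -> Optional[int]:  # pre-PEP 604 spelling
        return x

    def after(x: int | None = None) -> int | None:  # Python 3.10+ spelling
        return x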
@@ -132,16 +132,22 @@ class AxolotlTrainer(
         If the dataset is non-empty, a sampler is returned, the type of which
         depends on the passed training args.
         """
+        # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24
+        if train_dataset is None:
+            train_dataset = self.train_dataset
+        if train_dataset is None or not has_length(train_dataset):
+            return None
+
         use_sample_packing = self.args.sample_packing and not self.args.pretraining

         # Determine the base sampler first
         if self.args.curriculum_sampling:
-            base_sampler = SequentialSampler(self.train_dataset)
+            base_sampler = SequentialSampler(train_dataset)
         elif use_sample_packing:
-            base_sampler = RandomSampler(self.train_dataset)
+            base_sampler = RandomSampler(train_dataset)
         else:
             # Default to parent class implementation for standard random sampling
-            return super()._get_train_sampler()
+            return super()._get_train_sampler(train_dataset)

         # Apply multipack wrapper if needed
         if use_sample_packing:
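
The guard matters because recent transformers releases (as pinned by the URL in the diff) pass the dataset into `_get_train_sampler` and expect overrides to forward it; if nothing is forwarded, length-grouping code downstream can end up dereferencing `dataset.column_names` on `None`, which is exactly the reported AttributeError. A standalone repro sketch, with the hypothetical helper `pick_lengths` standing in for the upstream length-grouping lookup:

    from transformers.trainer_utils import has_length

    def pick_lengths(dataset, length_column_name="length"):
        # Hypothetical stand-in for upstream group-by-length code, which
        # reads dataset.column_names before building a length-aware sampler.
        if length_column_name in dataset.column_names:
            return dataset[length_column_name]
        return None

    train_dataset = None  # e.g. nothing forwarded and self.train_dataset unset

    # Without the guard: pick_lengths(train_dataset) raises
    #   AttributeError: 'NoneType' object has no attribute 'column_names'
    # With the guard, the sampler lookup short-circuits instead:
    if train_dataset is None or not has_length(train_dataset):
        sampler = None
    else:
        lengths = pick_lengths(train_dataset)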
@@ -160,6 +166,10 @@ class AxolotlTrainer(
         If the dataset is non-empty, a sampler is returned, the type of which
         depends on the passed training args.
         """
+        # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24
+        if eval_dataset is None or not has_length(eval_dataset):
+            return None
+
         # Multipacking enabled if training is enabled and eval is not explicitly disabled
         use_multipack = (
             self.args.sample_packing and self.args.eval_sample_packing is not False