fix: 'NoneType' object has no attribute 'column_names' (#2822) [skip ci]

* fix: 'NoneType' object has no attribute 'column_names'

* chore: typing
This commit is contained in:
NanoCode012
2025-06-25 20:49:55 +07:00
committed by GitHub
parent a27c4f8771
commit 20106116da

View File

@@ -20,7 +20,7 @@ from torch.utils.data import (
SequentialSampler, SequentialSampler,
) )
from transformers import Trainer from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from typing_extensions import override from typing_extensions import override
@@ -122,8 +122,8 @@ class AxolotlTrainer(
return sampler return sampler
def _get_train_sampler( def _get_train_sampler(
self, train_dataset: Optional[Dataset] = None self, train_dataset: Dataset | None = None
) -> Optional[Sampler]: ) -> Sampler | None:
""" """
Helper method to get the sampler for training. Handles cases for sample packing Helper method to get the sampler for training. Handles cases for sample packing
and curriculum sampling (sequential). and curriculum sampling (sequential).
@@ -132,16 +132,22 @@ class AxolotlTrainer(
If the dataset is non-empty, a sampler is returned, the type of which If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args. depends on the passed training args.
""" """
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24
if train_dataset is None:
train_dataset = self.train_dataset
if train_dataset is None or not has_length(train_dataset):
return None
use_sample_packing = self.args.sample_packing and not self.args.pretraining use_sample_packing = self.args.sample_packing and not self.args.pretraining
# Determine the base sampler first # Determine the base sampler first
if self.args.curriculum_sampling: if self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset) base_sampler = SequentialSampler(train_dataset)
elif use_sample_packing: elif use_sample_packing:
base_sampler = RandomSampler(self.train_dataset) base_sampler = RandomSampler(train_dataset)
else: else:
# Default to parent class implementation for standard random sampling # Default to parent class implementation for standard random sampling
return super()._get_train_sampler() return super()._get_train_sampler(train_dataset)
# Apply multipack wrapper if needed # Apply multipack wrapper if needed
if use_sample_packing: if use_sample_packing:
@@ -160,6 +166,10 @@ class AxolotlTrainer(
If the dataset is non-empty, a sampler is returned, the type of which If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args. depends on the passed training args.
""" """
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24
if eval_dataset is None or not has_length(eval_dataset):
return None
# Multipacking enabled if training is enabled and eval is not explicitly disabled # Multipacking enabled if training is enabled and eval is not explicitly disabled
use_multipack = ( use_multipack = (
self.args.sample_packing and self.args.eval_sample_packing is not False self.args.sample_packing and self.args.eval_sample_packing is not False