fix: 'NoneType' object has no attribute 'column_names' (#2822) [skip ci]
* fix: 'NoneType' object has no attribute 'column_names' * chore: typing
This commit is contained in:
@@ -20,7 +20,7 @@ from torch.utils.data import (
|
|||||||
SequentialSampler,
|
SequentialSampler,
|
||||||
)
|
)
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@@ -122,8 +122,8 @@ class AxolotlTrainer(
|
|||||||
return sampler
|
return sampler
|
||||||
|
|
||||||
def _get_train_sampler(
|
def _get_train_sampler(
|
||||||
self, train_dataset: Optional[Dataset] = None
|
self, train_dataset: Dataset | None = None
|
||||||
) -> Optional[Sampler]:
|
) -> Sampler | None:
|
||||||
"""
|
"""
|
||||||
Helper method to get the sampler for training. Handles cases for sample packing
|
Helper method to get the sampler for training. Handles cases for sample packing
|
||||||
and curriculum sampling (sequential).
|
and curriculum sampling (sequential).
|
||||||
@@ -132,16 +132,22 @@ class AxolotlTrainer(
|
|||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
depends on the passed training args.
|
depends on the passed training args.
|
||||||
"""
|
"""
|
||||||
|
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24
|
||||||
|
if train_dataset is None:
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
if train_dataset is None or not has_length(train_dataset):
|
||||||
|
return None
|
||||||
|
|
||||||
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||||
|
|
||||||
# Determine the base sampler first
|
# Determine the base sampler first
|
||||||
if self.args.curriculum_sampling:
|
if self.args.curriculum_sampling:
|
||||||
base_sampler = SequentialSampler(self.train_dataset)
|
base_sampler = SequentialSampler(train_dataset)
|
||||||
elif use_sample_packing:
|
elif use_sample_packing:
|
||||||
base_sampler = RandomSampler(self.train_dataset)
|
base_sampler = RandomSampler(train_dataset)
|
||||||
else:
|
else:
|
||||||
# Default to parent class implementation for standard random sampling
|
# Default to parent class implementation for standard random sampling
|
||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler(train_dataset)
|
||||||
|
|
||||||
# Apply multipack wrapper if needed
|
# Apply multipack wrapper if needed
|
||||||
if use_sample_packing:
|
if use_sample_packing:
|
||||||
@@ -160,6 +166,10 @@ class AxolotlTrainer(
|
|||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
depends on the passed training args.
|
depends on the passed training args.
|
||||||
"""
|
"""
|
||||||
|
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24
|
||||||
|
if eval_dataset is None or not has_length(eval_dataset):
|
||||||
|
return None
|
||||||
|
|
||||||
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
||||||
use_multipack = (
|
use_multipack = (
|
||||||
self.args.sample_packing and self.args.eval_sample_packing is not False
|
self.args.sample_packing and self.args.eval_sample_packing is not False
|
||||||
|
|||||||
Reference in New Issue
Block a user