fixes and go back to distributed sampler since batch sampler won't work
This commit is contained in:
@@ -427,7 +427,7 @@ def load_prepare_datasets(
|
|||||||
+ "|"
|
+ "|"
|
||||||
+ "train"
|
+ "train"
|
||||||
+ "|"
|
+ "|"
|
||||||
+ cfg.seed
|
+ str(cfg.seed or 42)
|
||||||
)
|
)
|
||||||
to_hash_test = (
|
to_hash_test = (
|
||||||
dataset._fingerprint # pylint: disable=protected-access
|
dataset._fingerprint # pylint: disable=protected-access
|
||||||
@@ -436,7 +436,7 @@ def load_prepare_datasets(
|
|||||||
+ "|"
|
+ "|"
|
||||||
+ "test"
|
+ "test"
|
||||||
+ "|"
|
+ "|"
|
||||||
+ cfg.seed
|
+ str(cfg.seed or 42)
|
||||||
)
|
)
|
||||||
train_fingerprint = hashlib.md5(
|
train_fingerprint = hashlib.md5(
|
||||||
to_hash_train.encode(), usedforsecurity=False
|
to_hash_train.encode(), usedforsecurity=False
|
||||||
@@ -449,7 +449,7 @@ def load_prepare_datasets(
|
|||||||
dataset = dataset.train_test_split(
|
dataset = dataset.train_test_split(
|
||||||
test_size=cfg.val_set_size,
|
test_size=cfg.val_set_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
seed=cfg.seed,
|
seed=cfg.seed or 42,
|
||||||
train_new_fingerprint=train_fingerprint,
|
train_new_fingerprint=train_fingerprint,
|
||||||
test_new_fingerprint=test_fingerprint,
|
test_new_fingerprint=test_fingerprint,
|
||||||
)
|
)
|
||||||
@@ -458,7 +458,7 @@ def load_prepare_datasets(
|
|||||||
dataset = dataset.train_test_split(
|
dataset = dataset.train_test_split(
|
||||||
test_size=cfg.val_set_size,
|
test_size=cfg.val_set_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
seed=cfg.seed,
|
seed=cfg.seed or 42,
|
||||||
train_new_fingerprint=train_fingerprint,
|
train_new_fingerprint=train_fingerprint,
|
||||||
test_new_fingerprint=test_fingerprint,
|
test_new_fingerprint=test_fingerprint,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import transformers
|
|||||||
from datasets import Dataset, set_caching_enabled
|
from datasets import Dataset, set_caching_enabled
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
||||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
@@ -159,23 +159,20 @@ class AxolotlTrainer(Trainer):
|
|||||||
return super().create_scheduler(num_training_steps, optimizer)
|
return super().create_scheduler(num_training_steps, optimizer)
|
||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
# def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
# if self.args.world_size > 1 and self.args.sample_packing:
|
if self.args.world_size > 1 and self.args.sample_packing:
|
||||||
# return DistributedSampler(
|
return DistributedSampler(
|
||||||
# self.train_dataset,
|
self.train_dataset,
|
||||||
# num_replicas=self.args.world_size,
|
num_replicas=self.args.world_size,
|
||||||
# rank=self.args.process_index,
|
rank=self.args.process_index,
|
||||||
# seed=self.args.seed,
|
seed=self.args.seed,
|
||||||
# )
|
)
|
||||||
# return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||||
if self.args.sample_packing:
|
if self.args.sample_packing:
|
||||||
train_sampler = self._get_train_sampler()
|
train_sampler = self._get_train_sampler()
|
||||||
# If set to True, the dataloader prepared is only iterated through on the
|
return self.accelerator.prepare(
|
||||||
# main process and then the batches are split and broadcast to each process
|
|
||||||
self.accelerator.dispatch_batches = True
|
|
||||||
return self.accelerator.prepare_data_loader(
|
|
||||||
MultipackDistributedDataloader(
|
MultipackDistributedDataloader(
|
||||||
self.train_dataset,
|
self.train_dataset,
|
||||||
batch_size=self._train_batch_size,
|
batch_size=self._train_batch_size,
|
||||||
@@ -197,10 +194,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
)
|
)
|
||||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||||
# If set to True, the datalaoder prepared is only iterated through on the
|
return self.accelerator.prepare(
|
||||||
# main process and then the batches are split and broadcast to each process
|
|
||||||
self.accelerator.dispatch_batches = True
|
|
||||||
return self.accelerator.prepare_data_loader(
|
|
||||||
MultipackDistributedDataloader(
|
MultipackDistributedDataloader(
|
||||||
eval_dataset,
|
eval_dataset,
|
||||||
batch_size=self.args.eval_batch_size,
|
batch_size=self.args.eval_batch_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user