more fixes for dataloader integration
This commit is contained in:
@@ -231,7 +231,7 @@ def train(
|
||||
cfg.pretraining_dataset,
|
||||
tokenizer,
|
||||
max_tokens=cfg.sequence_len,
|
||||
- seed=cfg.seed,
|
||||
+ seed=cfg.seed or 42,
|
||||
)
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
|
||||
@@ -91,12 +91,14 @@ class AxolotlTrainer(Trainer):
|
||||
if self.args.sample_packing:
|
||||
train_sampler = self._get_train_sampler()
|
||||
|
||||
- return MultipackDistributedDataloader(
|
||||
- self.train_dataset,
|
||||
- batch_size=self._train_batch_size,
|
||||
- seq_max_length=self.args.max_seq_length,
|
||||
- collate_fn=self.data_collator,
|
||||
- sampler=train_sampler,
|
||||
+ return self.accelerator.prepare(
|
||||
+ MultipackDistributedDataloader(
|
||||
+ self.train_dataset,
|
||||
+ batch_size=self._train_batch_size,
|
||||
+ seq_max_length=self.args.max_seq_length,
|
||||
+ collate_fn=self.data_collator,
|
||||
+ sampler=train_sampler,
|
||||
+ )
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
@@ -157,7 +159,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
train_dataset,
|
||||
num_replicas=1,
|
||||
rank=0,
|
||||
- seed=cfg.seed,
|
||||
+ seed=cfg.seed or 42,
|
||||
)
|
||||
data_loader = MultipackDistributedDataloader(
|
||||
train_dataset,
|
||||
@@ -170,12 +172,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
),
|
||||
sampler=sampler,
|
||||
)
|
||||
+ data_loader_len = len(data_loader)
|
||||
+ LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
total_num_steps = int(
|
||||
math.ceil(
|
||||
- len(data_loader)
|
||||
- * cfg.micro_batch_size
|
||||
- * cfg.num_epochs
|
||||
- / cfg.batch_size
|
||||
+ data_loader_len * cfg.micro_batch_size * cfg.num_epochs / cfg.batch_size
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -262,8 +263,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||
|
||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
- max_steps=total_num_steps
|
||||
- * cfg.num_epochs, # this is helpful in case we don't actually know total # of steps
|
||||
+ max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
||||
per_device_train_batch_size=cfg.micro_batch_size,
|
||||
per_device_eval_batch_size=cfg.eval_batch_size
|
||||
if cfg.eval_batch_size is not None
|
||||
|
||||
Reference in New Issue
Block a user