more fixes for dataloader integration

This commit is contained in:
Wing Lian
2023-07-18 10:50:40 -04:00
parent 762f1b08db
commit 41d4992029
2 changed files with 14 additions and 14 deletions

View File

@@ -231,7 +231,7 @@ def train(
cfg.pretraining_dataset,
tokenizer,
max_tokens=cfg.sequence_len,
seed=cfg.seed,
seed=cfg.seed or 42,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")

View File

@@ -91,12 +91,14 @@ class AxolotlTrainer(Trainer):
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
return MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
)
)
return super().get_train_dataloader()
@@ -157,7 +159,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
train_dataset,
num_replicas=1,
rank=0,
seed=cfg.seed,
seed=cfg.seed or 42,
)
data_loader = MultipackDistributedDataloader(
train_dataset,
@@ -170,12 +172,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
),
sampler=sampler,
)
data_loader_len = len(data_loader)
LOG.info(f"data_loader_len: {data_loader_len}")
total_num_steps = int(
math.ceil(
len(data_loader)
* cfg.micro_batch_size
* cfg.num_epochs
/ cfg.batch_size
data_loader_len * cfg.micro_batch_size * cfg.num_epochs / cfg.batch_size
)
)
else:
@@ -262,8 +263,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps
* cfg.num_epochs, # this is helpful in case we don't actually know total # of steps
max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
if cfg.eval_batch_size is not None