precommit fixes

Author: Dan Saunders
Date: 2025-03-11 14:24:48 +00:00
parent 1d339e4007
commit 698e599bf7
3 changed files with 29 additions and 71 deletions

@@ -413,13 +413,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             if self.args.curriculum_sampling:
                 sampler = SequentialSampler(self.train_dataset)
             else:
-                generator = None
-                if self.args.sequence_parallel_size > 1:
-                    generator = torch.Generator()
-                    generator.manual_seed(self.args.seed)
-                sampler = RandomSampler(
-                    self.train_dataset, generator=generator
-                )
+                sampler = RandomSampler(self.train_dataset)

             return MultipackBatchSampler(
                 sampler,
@@ -434,13 +428,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         if self.args.curriculum_sampling:
             return SequentialSampler(self.train_dataset)
-        sampler = super()._get_train_sampler()
-        if self.args.sequence_parallel_size > 1:
-            generator = torch.Generator()
-            generator.manual_seed(self.args.seed)
-            sampler.generator = generator
-
-        return sampler
+        return super()._get_train_sampler()

     def _get_eval_sampler(
         self, eval_dataset: Dataset
@@ -498,28 +486,30 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             self.accelerator.even_batches = False
-            if self.args.sequence_parallel_size > 1:
-                return DataLoader(train_dataset, **dataloader_params)
-            else:
-                return self.accelerator.prepare_data_loader(
-                    DataLoader(train_dataset, **dataloader_params)
-                )
-        else:
-            if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
-                train_dataset = self._remove_unused_columns(train_dataset, description="training")
-            else:
-                data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
-
-            if not isinstance(train_dataset, torch.utils.data.IterableDataset):
-                dataloader_params["sampler"] = self._get_train_sampler()
-                dataloader_params["drop_last"] = self.args.dataloader_drop_last
-                dataloader_params["worker_init_fn"] = seed_worker
-                dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
-            return self.accelerator.prepare_data_loader(
-                DataLoader(train_dataset, **dataloader_params)
-            )
+            if self.args.sequence_parallel_size > 1:
+                return DataLoader(train_dataset, **dataloader_params)
+            else:
+                return self.accelerator.prepare(
+                    DataLoader(train_dataset, **dataloader_params)
+                )
+
+        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+            train_dataset = self._remove_unused_columns(
+                train_dataset, description="training"
+            )
+        else:
+            data_collator = self._get_collator_with_removed_columns(
+                data_collator, description="training"
+            )
+
+        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_train_sampler()
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["worker_init_fn"] = seed_worker
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
+        if self.args.sequence_parallel_size > 1:
+            return DataLoader(train_dataset, **dataloader_params)
+        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         if self.args.sample_packing and self.args.eval_sample_packing is False:
@@ -842,7 +832,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         self,
         model: nn.Module,
         inputs: dict[str, torch.Tensor | Any],
-        num_items_in_batch: int = None,
+        num_items_in_batch: int | None = None,
     ) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.

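The lines removed from the sampler paths above built a torch.Generator seeded with self.args.seed and handed it to RandomSampler so that every sequence-parallel rank drew the same shuffle order. A minimal standalone sketch of that pattern in plain PyTorch (not Axolotl code; the toy dataset and seed value are made up for illustration):

import torch
from torch.utils.data import RandomSampler, TensorDataset

# Toy dataset standing in for self.train_dataset.
dataset = TensorDataset(torch.arange(8))
seed = 42  # stands in for self.args.seed

# Seeding a dedicated generator makes the shuffle order deterministic,
# so any rank that builds the sampler with the same seed sees the same order.
generator = torch.Generator()
generator.manual_seed(seed)
sampler = RandomSampler(dataset, generator=generator)

print(list(sampler))  # same permutation of 0..7 on every run / every rank

Dropping the generator, as this commit does, means RandomSampler falls back to PyTorch's global RNG state for its shuffle.
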
@@ -218,7 +218,9 @@ class DataCollatorForSeq2Seq:
f"Non-padding tokens: {non_pad_tokens_total}"
)
logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}")
logger.info(f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}")
logger.info(
f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}"
)
batch[key] = batch[key][:, start_idx:end_idx]
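
For context on the start_idx:end_idx slice in the collator hunk: under sequence parallelism each rank keeps only its contiguous chunk of the sequence dimension. A rough self-contained sketch of that kind of slicing (a hypothetical helper, not the DataCollatorForSeq2Seq implementation; it assumes the sequence length divides evenly by the number of ranks):

import torch

def slice_batch_for_rank(
    batch: dict[str, torch.Tensor], rank: int, world_size: int
) -> dict[str, torch.Tensor]:
    # Split the sequence (second) dimension into equal contiguous chunks.
    seq_len = batch["input_ids"].shape[1]
    chunk = seq_len // world_size  # assumes seq_len % world_size == 0
    start_idx, end_idx = rank * chunk, (rank + 1) * chunk
    return {key: value[:, start_idx:end_idx] for key, value in batch.items()}

batch = {
    "input_ids": torch.arange(16).reshape(2, 8),
    "attention_mask": torch.ones(2, 8, dtype=torch.long),
}
# Rank 1 of 2 keeps columns 4..7 of every tensor in the batch.
print(slice_batch_for_rank(batch, rank=1, world_size=2)["input_ids"])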