seq_len_multiple for packing

This commit is contained in:
Wing Lian
2023-08-02 23:20:19 -04:00
parent 8378335dc9
commit 98c9bc69de
2 changed files with 43 additions and 27 deletions

View File

@@ -128,6 +128,7 @@ class MultipackDistributedDataloader:
batch_size: int = 1, batch_size: int = 1,
sampler: Union[Sampler, DistributedSampler] = None, sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0, packing_efficiency_estimate: float = 1.0,
seq_len_multiple: int = 1,
): ):
# Dataset # Dataset
self.dataset = dataset self.dataset = dataset
@@ -135,9 +136,10 @@ class MultipackDistributedDataloader:
[len(sample["input_ids"]) for sample in self.dataset] [len(sample["input_ids"]) for sample in self.dataset]
) )
assert isinstance(self.lengths, np.ndarray) assert isinstance(self.lengths, np.ndarray)
assert batch_size % seq_len_multiple == 0
self.sampler = sampler self.sampler = sampler
self.batch_size = batch_size self.batch_size = batch_size
self.seq_len_multiple = seq_len_multiple
self.seq_max_length = seq_max_length self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn self.collate_fn = collate_fn
@@ -148,7 +150,7 @@ class MultipackDistributedDataloader:
# statistics # statistics
self.eff_total_used = 0 self.eff_total_used = 0
self.eff_total_slots = 0 self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
def generate_batches(self, set_stats=False): def generate_batches(self, set_stats=False):
if self.sampler: if self.sampler:
@@ -164,7 +166,7 @@ class MultipackDistributedDataloader:
lengths_cumsum=lengths_cumsum, lengths_cumsum=lengths_cumsum,
rank=self.rank, rank=self.rank,
# c=self.batch_max_length, # c=self.batch_max_length,
c=self.seq_max_length, c=self.seq_max_length * self.seq_len_multiple,
n=self.num_replicas, n=self.num_replicas,
) )
@@ -181,18 +183,20 @@ class MultipackDistributedDataloader:
all_batches, _ = self.generate_batches(set_stats=True) all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys() features = self.dataset.features.keys()
len_remaining = self._len_est() len_remaining = self._len_est()
for batches in chunk(all_batches, self.batch_size): for batches in chunk(all_batches, self.batch_size / self.seq_len_multiple):
chunked_data = [] chunked_data = []
attn_mask_cum_idx = 0
for batch in batches: for batch in batches:
concatenated = {} concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch] batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features: for feature in features:
if feature == "attention_mask": if feature == "attention_mask":
arrays = [ arrays = [
(idx + 1) * np.array(item[feature]) (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data) for idx, item in enumerate(batched_data)
if feature in item if feature in item
] ]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays) concatenated[feature] = np.concatenate(arrays)
else: else:
arrays = [ arrays = [
@@ -216,7 +220,7 @@ class MultipackDistributedDataloader:
# } # }
# chunked_data.append(chunk) # chunked_data.append(chunk)
# yield self.collate_fn(chunked_data) # yield self.collate_fn(chunked_data)
yield self.collate_fn([chunked_data]) yield self.collate_fn(chunked_data)
len_remaining -= 1 len_remaining -= 1
if not len_remaining: if not len_remaining:
return return

View File

@@ -53,26 +53,34 @@ def weighted_cross_entropy(
@torch.jit.script @torch.jit.script
def create_weighted_mask(labels: torch.Tensor): def create_weighted_mask(labels: torch.Tensor):
mask = labels != -100 # Check if the tensor is 2D. If not, unsqueeze it to make it 2D
if len(labels.shape) == 1:
labels = labels.unsqueeze(0)
# Create a tensor to track group ids weights = torch.zeros_like(labels).float()
group_ids = torch.zeros_like(labels).int() for i in range(labels.shape[0]):
curr_group_id = 0 mask = labels[i] != -100
for i in range(1, len(labels)): # Create a tensor to track group ids
if mask[i] and not mask[i - 1]: # switch from masked to unmasked label group_ids = torch.zeros_like(labels[i]).int()
curr_group_id += 1 # start new group curr_group_id = 0
group_ids[i] = (
curr_group_id if mask[i] else 0
) # assign group id if unmasked label
# Count only unmasked labels in each group for j in range(1, len(labels[i])):
group_counts = torch.bincount(group_ids[mask]) if mask[j] and not mask[j - 1]: # switch from masked to unmasked label
curr_group_id += 1 # start new group
group_ids[j] = (
curr_group_id if mask[j] else 0
) # assign group id if unmasked label
mask_weights = torch.zeros_like(labels).float() # Count only unmasked labels in each group
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]] group_counts = torch.bincount(group_ids[mask])
return mask_weights mask_weights = torch.zeros_like(labels[i]).float()
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
weights[i] = mask_weights
return weights.squeeze() # squeeze the output to match the input dimension
def trainer_weighted_loss(model_output, labels, shift_labels=True): def trainer_weighted_loss(model_output, labels, shift_labels=True):
@@ -168,6 +176,7 @@ class AxolotlTrainer(Trainer):
collate_fn=self.data_collator, collate_fn=self.data_collator,
sampler=train_sampler, sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
seq_len_multiple=2,
) )
) )
return super().get_train_dataloader() return super().get_train_dataloader()
@@ -187,16 +196,18 @@ class AxolotlTrainer(Trainer):
seq_max_length=self.args.max_seq_length, seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator, collate_fn=self.data_collator,
sampler=eval_sampler, sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
seq_len_multiple=2,
) )
) )
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
def compute_loss(self, model, inputs, return_outputs=False): def compute_loss(self, model, inputs, return_outputs=False):
if self.args.sample_packing: # if self.args.sample_packing:
labels = inputs.pop("labels") # labels = inputs.pop("labels")
outputs = model(**inputs) # outputs = model(**inputs)
loss = trainer_weighted_loss(outputs, labels, shift_labels=True) # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
return (loss, outputs) if return_outputs else loss # return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs) return super().compute_loss(model, inputs, return_outputs=return_outputs)
@@ -262,7 +273,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
* total_num_tokens * total_num_tokens
/ cfg.sample_packing_eff_est / cfg.sample_packing_eff_est
/ 2048 / 2048
/ cfg.batch_size // cfg.batch_size
) )
- 1 - 1
) )
@@ -284,6 +295,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
), ),
sampler=sampler, sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est, packing_efficiency_estimate=cfg.sample_packing_eff_est,
seq_len_multiple=2,
) )
data_loader_len = len(data_loader) data_loader_len = len(data_loader)
LOG.info(f"data_loader_len: {data_loader_len}") LOG.info(f"data_loader_len: {data_loader_len}")