seq_len_multiple for packing
This commit is contained in:
@@ -128,6 +128,7 @@ class MultipackDistributedDataloader:
|
|||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
sampler: Union[Sampler, DistributedSampler] = None,
|
sampler: Union[Sampler, DistributedSampler] = None,
|
||||||
packing_efficiency_estimate: float = 1.0,
|
packing_efficiency_estimate: float = 1.0,
|
||||||
|
seq_len_multiple: int = 1,
|
||||||
):
|
):
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
@@ -135,9 +136,10 @@ class MultipackDistributedDataloader:
|
|||||||
[len(sample["input_ids"]) for sample in self.dataset]
|
[len(sample["input_ids"]) for sample in self.dataset]
|
||||||
)
|
)
|
||||||
assert isinstance(self.lengths, np.ndarray)
|
assert isinstance(self.lengths, np.ndarray)
|
||||||
|
assert batch_size % seq_len_multiple == 0
|
||||||
self.sampler = sampler
|
self.sampler = sampler
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
self.seq_len_multiple = seq_len_multiple
|
||||||
self.seq_max_length = seq_max_length
|
self.seq_max_length = seq_max_length
|
||||||
self.batch_max_length = batch_size * seq_max_length
|
self.batch_max_length = batch_size * seq_max_length
|
||||||
self.collate_fn = collate_fn
|
self.collate_fn = collate_fn
|
||||||
@@ -148,7 +150,7 @@ class MultipackDistributedDataloader:
|
|||||||
# statistics
|
# statistics
|
||||||
self.eff_total_used = 0
|
self.eff_total_used = 0
|
||||||
self.eff_total_slots = 0
|
self.eff_total_slots = 0
|
||||||
self.packing_efficiency_estimate = packing_efficiency_estimate
|
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||||
|
|
||||||
def generate_batches(self, set_stats=False):
|
def generate_batches(self, set_stats=False):
|
||||||
if self.sampler:
|
if self.sampler:
|
||||||
@@ -164,7 +166,7 @@ class MultipackDistributedDataloader:
|
|||||||
lengths_cumsum=lengths_cumsum,
|
lengths_cumsum=lengths_cumsum,
|
||||||
rank=self.rank,
|
rank=self.rank,
|
||||||
# c=self.batch_max_length,
|
# c=self.batch_max_length,
|
||||||
c=self.seq_max_length,
|
c=self.seq_max_length * self.seq_len_multiple,
|
||||||
n=self.num_replicas,
|
n=self.num_replicas,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -181,18 +183,20 @@ class MultipackDistributedDataloader:
|
|||||||
all_batches, _ = self.generate_batches(set_stats=True)
|
all_batches, _ = self.generate_batches(set_stats=True)
|
||||||
features = self.dataset.features.keys()
|
features = self.dataset.features.keys()
|
||||||
len_remaining = self._len_est()
|
len_remaining = self._len_est()
|
||||||
for batches in chunk(all_batches, self.batch_size):
|
for batches in chunk(all_batches, self.batch_size / self.seq_len_multiple):
|
||||||
chunked_data = []
|
chunked_data = []
|
||||||
|
attn_mask_cum_idx = 0
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
concatenated = {}
|
concatenated = {}
|
||||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||||
for feature in features:
|
for feature in features:
|
||||||
if feature == "attention_mask":
|
if feature == "attention_mask":
|
||||||
arrays = [
|
arrays = [
|
||||||
(idx + 1) * np.array(item[feature])
|
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||||
for idx, item in enumerate(batched_data)
|
for idx, item in enumerate(batched_data)
|
||||||
if feature in item
|
if feature in item
|
||||||
]
|
]
|
||||||
|
attn_mask_cum_idx += len(batched_data)
|
||||||
concatenated[feature] = np.concatenate(arrays)
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
else:
|
else:
|
||||||
arrays = [
|
arrays = [
|
||||||
@@ -216,7 +220,7 @@ class MultipackDistributedDataloader:
|
|||||||
# }
|
# }
|
||||||
# chunked_data.append(chunk)
|
# chunked_data.append(chunk)
|
||||||
# yield self.collate_fn(chunked_data)
|
# yield self.collate_fn(chunked_data)
|
||||||
yield self.collate_fn([chunked_data])
|
yield self.collate_fn(chunked_data)
|
||||||
len_remaining -= 1
|
len_remaining -= 1
|
||||||
if not len_remaining:
|
if not len_remaining:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -53,26 +53,34 @@ def weighted_cross_entropy(
|
|||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def create_weighted_mask(labels: torch.Tensor):
|
def create_weighted_mask(labels: torch.Tensor):
|
||||||
mask = labels != -100
|
# Check if the tensor is 2D. If not, unsqueeze it to make it 2D
|
||||||
|
if len(labels.shape) == 1:
|
||||||
|
labels = labels.unsqueeze(0)
|
||||||
|
|
||||||
# Create a tensor to track group ids
|
weights = torch.zeros_like(labels).float()
|
||||||
group_ids = torch.zeros_like(labels).int()
|
for i in range(labels.shape[0]):
|
||||||
curr_group_id = 0
|
mask = labels[i] != -100
|
||||||
|
|
||||||
for i in range(1, len(labels)):
|
# Create a tensor to track group ids
|
||||||
if mask[i] and not mask[i - 1]: # switch from masked to unmasked label
|
group_ids = torch.zeros_like(labels[i]).int()
|
||||||
curr_group_id += 1 # start new group
|
curr_group_id = 0
|
||||||
group_ids[i] = (
|
|
||||||
curr_group_id if mask[i] else 0
|
|
||||||
) # assign group id if unmasked label
|
|
||||||
|
|
||||||
# Count only unmasked labels in each group
|
for j in range(1, len(labels[i])):
|
||||||
group_counts = torch.bincount(group_ids[mask])
|
if mask[j] and not mask[j - 1]: # switch from masked to unmasked label
|
||||||
|
curr_group_id += 1 # start new group
|
||||||
|
group_ids[j] = (
|
||||||
|
curr_group_id if mask[j] else 0
|
||||||
|
) # assign group id if unmasked label
|
||||||
|
|
||||||
mask_weights = torch.zeros_like(labels).float()
|
# Count only unmasked labels in each group
|
||||||
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
|
group_counts = torch.bincount(group_ids[mask])
|
||||||
|
|
||||||
return mask_weights
|
mask_weights = torch.zeros_like(labels[i]).float()
|
||||||
|
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
|
||||||
|
|
||||||
|
weights[i] = mask_weights
|
||||||
|
|
||||||
|
return weights.squeeze() # squeeze the output to match the input dimension
|
||||||
|
|
||||||
|
|
||||||
def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
||||||
@@ -168,6 +176,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
collate_fn=self.data_collator,
|
collate_fn=self.data_collator,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
|
seq_len_multiple=2,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
@@ -187,16 +196,18 @@ class AxolotlTrainer(Trainer):
|
|||||||
seq_max_length=self.args.max_seq_length,
|
seq_max_length=self.args.max_seq_length,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=self.data_collator,
|
||||||
sampler=eval_sampler,
|
sampler=eval_sampler,
|
||||||
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
|
seq_len_multiple=2,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
if self.args.sample_packing:
|
# if self.args.sample_packing:
|
||||||
labels = inputs.pop("labels")
|
# labels = inputs.pop("labels")
|
||||||
outputs = model(**inputs)
|
# outputs = model(**inputs)
|
||||||
loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||||
return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
|
|
||||||
|
|
||||||
@@ -262,7 +273,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
* total_num_tokens
|
* total_num_tokens
|
||||||
/ cfg.sample_packing_eff_est
|
/ cfg.sample_packing_eff_est
|
||||||
/ 2048
|
/ 2048
|
||||||
/ cfg.batch_size
|
// cfg.batch_size
|
||||||
)
|
)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
@@ -284,6 +295,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
),
|
),
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||||
|
seq_len_multiple=2,
|
||||||
)
|
)
|
||||||
data_loader_len = len(data_loader)
|
data_loader_len = len(data_loader)
|
||||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||||
|
|||||||
Reference in New Issue
Block a user