Fix ORPO multi gpu (#1433)
* don't drop attention_mask for orpo
* handle multi-gpu cases better for orpo
* revert change to not drop the attention_mask from inputs for orpo
This commit is contained in:
@@ -30,6 +30,7 @@ from transformers import (
|
|||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
from axolotl.loraplus import create_loraplus_optimizer
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
@@ -472,6 +473,58 @@ class AxolotlTrainer(Trainer):
|
|||||||
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
|
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
|
|
||||||
|
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
    """Build one batch holding the chosen rows followed by the rejected rows.

    Every tensor is first padded with ``pad_to_length`` to the larger of the
    ``shape[1]`` values of ``inputs["input_ids"]`` and
    ``inputs["rejected_input_ids"]``. The chosen/rejected pairs are then
    concatenated along dim 0 (chosen first) and moved to ``device``.

    Args:
        inputs: mapping with the keys ``input_ids``, ``labels``,
            ``attention_mask``, their ``rejected_``-prefixed counterparts,
            and ``prompt_attention_mask``.
        label_pad_token: fill value used when padding the label tensors.
        pad_token: fill value used when padding the input id tensors.
        device: target device passed to ``Tensor.to`` for every returned
            tensor.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask`` (each the
        chosen/rejected concatenation) and ``prompt_attention_mask`` (padded
        only; it is not duplicated for the rejected half).
    """
    target_len = max(
        inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
    )

    # (key, fill value) pairs; attention masks are always padded with 0.
    fill_values = (
        ("input_ids", pad_token),
        ("rejected_input_ids", pad_token),
        ("labels", label_pad_token),
        ("rejected_labels", label_pad_token),
        ("attention_mask", 0),
        ("rejected_attention_mask", 0),
        ("prompt_attention_mask", 0),
    )
    padded = {
        key: pad_to_length(inputs[key], target_len, fill)
        for key, fill in fill_values
    }

    def _joined(key):
        # Chosen rows first, rejected rows second, so callers can split the
        # result back into halves along dim 0.
        pair = [padded[key], padded["rejected_" + key]]
        return torch.cat(pair, dim=0).to(device=device)

    return {
        "input_ids": _joined("input_ids"),
        "labels": _joined("labels"),
        "attention_mask": _joined("attention_mask"),
        "prompt_attention_mask": padded["prompt_attention_mask"].to(device=device),
    }
|
||||||
|
|
||||||
def orpo_compute_custom_loss(self, logits, labels):
|
def orpo_compute_custom_loss(self, logits, labels):
|
||||||
logits = logits.contiguous()
|
logits = logits.contiguous()
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
@@ -512,45 +565,46 @@ class AxolotlTrainer(Trainer):
|
|||||||
dim=2,
|
dim=2,
|
||||||
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
|
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
|
||||||
).squeeze(2)
|
).squeeze(2)
|
||||||
return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
|
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
|
||||||
dtype=torch.float64
|
|
||||||
) / mask.sum(dim=1).to(dtype=torch.float64)
|
|
||||||
|
|
||||||
def orpo_compute_loss(self, model, inputs, return_outputs=False):
|
def orpo_compute_loss(self, model, inputs, return_outputs=False):
|
||||||
outputs_neg = model(
|
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
|
||||||
|
inputs,
|
||||||
|
label_pad_token=-100,
|
||||||
|
pad_token=self.tokenizer.pad_token_id,
|
||||||
|
device=self.accelerator.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform a single forward pass
|
||||||
|
outputs = model(
|
||||||
**{
|
**{
|
||||||
"input_ids": inputs["rejected_input_ids"],
|
"input_ids": concat_inputs["input_ids"],
|
||||||
"attention_mask": inputs["rejected_attention_mask"],
|
"attention_mask": concat_inputs["attention_mask"],
|
||||||
"labels": inputs["rejected_labels"],
|
"labels": concat_inputs["labels"],
|
||||||
},
|
|
||||||
output_hidden_states=True,
|
|
||||||
)
|
|
||||||
outputs_pos = model(
|
|
||||||
**{
|
|
||||||
"input_ids": inputs["input_ids"],
|
|
||||||
"attention_mask": inputs["attention_mask"],
|
|
||||||
"labels": inputs["labels"],
|
|
||||||
},
|
},
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Split the outputs for positive and negative examples
|
||||||
|
outputs_pos, outputs_neg = outputs.logits.chunk(2)
|
||||||
|
|
||||||
# Calculate NLL loss
|
# Calculate NLL loss
|
||||||
pos_loss = self.orpo_compute_custom_loss(
|
pos_loss = self.orpo_compute_custom_loss(
|
||||||
logits=outputs_pos.logits, labels=inputs["input_ids"]
|
logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate Log Probability
|
# Calculate Log Probability
|
||||||
pos_prob = self.orpo_compute_logps(
|
pos_prob = self.orpo_compute_logps(
|
||||||
prompt_attention_mask=inputs["prompt_attention_mask"],
|
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
|
||||||
chosen_inputs=inputs["input_ids"],
|
chosen_inputs=concat_inputs["input_ids"].chunk(2)[0],
|
||||||
chosen_attention_mask=inputs["attention_mask"],
|
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0],
|
||||||
logits=outputs_pos.logits,
|
logits=outputs_pos,
|
||||||
)
|
)
|
||||||
neg_prob = self.orpo_compute_logps(
|
neg_prob = self.orpo_compute_logps(
|
||||||
prompt_attention_mask=inputs["prompt_attention_mask"],
|
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
|
||||||
chosen_inputs=inputs["rejected_input_ids"],
|
chosen_inputs=concat_inputs["input_ids"].chunk(2)[1],
|
||||||
chosen_attention_mask=inputs["rejected_attention_mask"],
|
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1],
|
||||||
logits=outputs_neg.logits,
|
logits=outputs_neg,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate log odds
|
# Calculate log odds
|
||||||
@@ -1247,6 +1301,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
eval_dataset=self.eval_dataset,
|
eval_dataset=self.eval_dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
||||||
eval_data_collator=self.build_collator(
|
eval_data_collator=self.build_collator(
|
||||||
training_args, is_eval=True, **data_collator_kwargs
|
training_args, is_eval=True, **data_collator_kwargs
|
||||||
|
|||||||
Reference in New Issue
Block a user