@@ -1,5 +1,7 @@
"""Axolotl GRPO trainer"""
# pylint: disable=too-many-lines,duplicate-code
import warnings
from collections import defaultdict
from contextlib import nullcontext
@@ -68,6 +70,7 @@ from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig, get_peft_model
if is_deepspeed_available():
@@ -101,6 +104,9 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
] = (None, None),
peft_config: "PeftConfig | None" = None,
):
RngLoaderMixin.__init__(self)
SchedulerMixin.__init__(self)
# Args
if args is None:
model_name = model if isinstance(model, str) else model.config._name_or_path
@@ -264,7 +270,10 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
model.warnings_issued["estimate_tokens"] = True
# Initialize the metrics
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
self._metrics: dict[str, dict[str, list]] = {
"train": defaultdict(list),
"eval": defaultdict(list),
}
self._total_train_tokens = 0
self.log_completions = args.log_completions
@@ -437,13 +446,13 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
return SequenceParallelRepeatRandomSampler(
dataset=self.train_dataset,
mini_repeat_count=self.num_generations,
world_size=world_size,
rank=rank,
batch_size=effective_batch_size
// self.num_generations
// self.args.sequence_parallel_degree,
repeat_count=self.num_iterations,
sequence_parallel_degree=self.args.sequence_parallel_degree,
world_size=world_size,
rank=rank,
shuffle=True,
seed=self.args.seed,
drop_last=True,
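# Illustrative note (not part of the diff; numbers are assumptions): with
# effective_batch_size=16, num_generations=4 and sequence_parallel_degree=2,
# batch_size works out to 16 // 4 // 2 = 2 unique prompts per sampler batch;
# each prompt is then repeated mini_repeat_count=num_generations times, and the
# whole schedule repeats repeat_count=num_iterations times.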
@@ -514,6 +523,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
# pylint: disable=access-member-before-definition
data_collator = self.data_collator # type: ignore
# Initialize SP group attributes if sequence parallelism is enabled
@@ -532,7 +542,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns(
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
@@ -591,7 +601,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
self.vllm_client.reset_prefix_cache()
def _generate_and_score_completions(
self, inputs: dict[str | torch.Tensor | Any]
self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
@@ -606,6 +616,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
padding_side="left",
add_special_tokens=False,
)
# pylint: disable=protected-access
prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
prompt_ids, prompt_mask = (
@@ -682,7 +693,6 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
completion_ids = pad(
completion_ids, padding_value=self.processing_class.pad_token_id
)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(
@@ -701,6 +711,8 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full(
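# Sketch (assumption, not shown in this hunk): the first-EOS masking that this
# context refers to typically builds eos_idx/completion_mask roughly as below;
# the exact upstream TRL code may differ:
#     eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
#     eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
#     sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
#     completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()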
@@ -719,105 +731,59 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
) # we only need to compute the logits for the completion tokens
with torch.no_grad():
if self.args.sequence_parallel_degree > 1:
# Pad sequence to be divisible by SP degree if needed
total_seq_len = prompt_completion_ids.shape[1]
if total_seq_len % self.local_world_size != 0:
pad_len = self.local_world_size - (
total_seq_len % self.local_world_size
)
pad_token_id = self.processing_class.pad_token_id or 0
# Pad input_ids and attention_mask
padding = torch.full(
(prompt_completion_ids.shape[0], pad_len),
pad_token_id,
dtype=prompt_completion_ids.dtype,
device=prompt_completion_ids.device,
)
prompt_completion_ids = torch.cat(
[prompt_completion_ids, padding], dim=1
)
attn_padding = torch.zeros(
(attention_mask.shape[0], pad_len),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
total_seq_len += pad_len
logits_to_keep += pad_len
# Split the sequence
slice_size = total_seq_len // self.local_world_size
start = self.local_rank * slice_size
end = start + slice_size
# Get our slice
prompt_completion_ids = prompt_completion_ids[:, start:end]
attention_mask = attention_mask[:, start:end]
# Calculate how many completion tokens each rank should process
prompt_len = prompt_ids.size(1)
completion_len = completion_ids.size(
1
) # This is equal to logits_to_keep
# Calculate where our slice starts and ends relative to the completion tokens
if start >= prompt_len:
# Slice starts within the completion section
start_in_completion = start - prompt_len
end_in_completion = min(end - prompt_len, completion_len)
logits_to_keep = end_in_completion - start_in_completion
completion_mask = completion_mask[
:, start_in_completion:end_in_completion
]
elif end <= prompt_len:
# Slice is entirely within the prompt section (no completion tokens)
logits_to_keep = 0
completion_mask = torch.zeros(
(completion_mask.size(0), 0), device=completion_mask.device
)
else:
# Slice contains the boundary between prompt and completion
start_in_completion = 0
end_in_completion = min(end - prompt_len, completion_len)
logits_to_keep = end_in_completion - start_in_completion
completion_mask = completion_mask[
:, start_in_completion:end_in_completion
]
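# Worked example (illustrative, assumed numbers): with prompt_len=6,
# completion_len=10, local_world_size=2 and padded total_seq_len=16,
# slice_size=8. Rank 0 owns tokens [0, 8): it straddles the boundary, so
# start_in_completion=0, end_in_completion=min(8 - 6, 10)=2 and it keeps logits
# for 2 completion tokens. Rank 1 owns [8, 16): start_in_completion=2,
# end_in_completion=10, so it keeps the remaining 8. Together the ranks cover
# all 10 completion tokens exactly once.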
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
# computation here, and use per_token_logps.detach() instead.
if self.num_iterations > 1:
old_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)
if self.args.sequence_parallel_degree > 1:
old_per_token_logps, _ = self._get_per_token_logps_v2(
self.model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
)
else:
old_per_token_logps = super()._get_per_token_logps(
self.model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
)
else:
old_per_token_logps = None
if self.beta == 0.0:
ref_per_token_logps = None
elif self.ref_model is not None:
print(f"{dist.get_rank()}: prompt_completion_ids.shape: {prompt_completion_ids.shape}")
print(f"{dist.get_rank()}: attention_mask.shape: {attention_mask.shape}")
print(f"{dist.get_rank()}: logits_to_keep: {logits_to_keep}")
ref_per_token_logps = self._get_per_token_logps(
self.ref_model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(
self.model,
if self.args.sequence_parallel_degree > 1:
ref_per_token_logps, _ = self._get_per_token_logps_v2(
self.ref_model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
)
else:
ref_per_token_logps = super()._get_per_token_logps(
self.ref_model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
if self.args.sequence_parallel_degree > 1:
ref_per_token_logps, _ = self._get_per_token_logps_v2(
self.model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
)
else:
ref_per_token_logps = super()._get_per_token_logps(
self.model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
)
# Decode the generated completions
completions_text = self.processing_class.batch_decode(
@@ -848,6 +814,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
f"reward {reward_func.config._name_or_path.split('/')[-1]}"
)
else:
# pylint: disable=protected-access
reward_func_name = reward_func.__name__
with profiling_context(self, reward_func_name):
if isinstance(
@@ -870,6 +837,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
padding_side="right",
add_special_tokens=False,
)
# pylint: disable=protected-access
reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
@@ -966,6 +934,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
# pylint: disable=protected-access
reward_func_name = reward_func.__name__
# Only calculate mean for samples where this reward function was applied (non-NaN values)
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
@@ -1016,6 +985,115 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"advantages": advantages,
}
def _get_per_token_logps_v2(
self, model, input_ids, attention_mask, logits_to_keep, completion_mask=None
):
# Pad sequence to be divisible by SP degree if needed
total_seq_len = input_ids.shape[1]
if total_seq_len % self.local_world_size != 0:
pad_len = self.local_world_size - (total_seq_len % self.local_world_size)
pad_token_id = self.processing_class.pad_token_id or 0
# Pad input_ids and attention_mask
padding = torch.full(
(input_ids.shape[0], pad_len),
pad_token_id,
dtype=input_ids.dtype,
device=input_ids.device,
)
input_ids = torch.cat([input_ids, padding], dim=1)
attn_padding = torch.zeros(
(attention_mask.shape[0], pad_len),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
if completion_mask is not None:
completion_mask = torch.cat([completion_mask, attn_padding], dim=1)
total_seq_len += pad_len
logits_to_keep += pad_len
# Split the sequence
slice_size = total_seq_len // self.local_world_size
start = self.local_rank * slice_size
end = start + slice_size
# Get our slice
input_ids_slice = input_ids[:, start:end]
attention_mask_slice = attention_mask[:, start:end]
# Calculate where our slice starts and ends relative to the completion tokens
local_completion_mask = None
prompt_len = input_ids.size(1) - logits_to_keep
if start >= prompt_len:
# Slice starts within the completion section
start_in_completion = start - prompt_len
end_in_completion = min(end - prompt_len, logits_to_keep)
local_logits_to_keep = end_in_completion - start_in_completion
if completion_mask is not None:
local_completion_mask = completion_mask[
:, start_in_completion:end_in_completion
]
elif end <= prompt_len:
# Slice is entirely within the prompt section (no completion tokens)
local_logits_to_keep = 0
if completion_mask is not None:
local_completion_mask = torch.zeros(
(completion_mask.size(0), 0), device=completion_mask.device
)
else:
# Slice contains the boundary between prompt and completion
start_in_completion = 0
end_in_completion = min(end - prompt_len, logits_to_keep)
local_logits_to_keep = end_in_completion - start_in_completion
if completion_mask is not None:
local_completion_mask = completion_mask[
:, start_in_completion:end_in_completion
]
# Get logits with enough context to compute log probs
logits = model(
input_ids=input_ids_slice,
attention_mask=attention_mask_slice,
logits_to_keep=local_logits_to_keep + 1,
).logits
# Only the last rank that contains completion tokens needs to remove the last logit
is_last_rank_with_completions = (
self.local_rank == self.local_world_size - 1 # Last rank overall
or end
>= prompt_len
+ logits_to_keep # Our slice includes the last completion token
)
if is_last_rank_with_completions:
logits = logits[:, :-1]
if local_completion_mask is not None:
local_completion_mask = local_completion_mask[:, :-1]
local_logits_to_keep -= 1
if start >= prompt_len:
# For ranks where slice is all completion tokens,
# we need to offset to match the logits (which predict the next token)
offset = 1 # Skip the first token as it's predicted by the last token of the previous rank
local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep]
else:
# For the rank that contains the prompt-completion boundary,
# we need to take completion tokens only
offset = prompt_len - start # Where completions start in our slice
local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep]
logits = logits[
:, -local_logits_to_keep:
] # Take only logits for completion tokens
logits = logits / self.temperature
per_token_logps = selective_log_softmax(logits, local_input_ids)
return per_token_logps, local_completion_mask
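# Usage note (sketch, inferred from compute_loss below): each rank receives only
# the per-token log-probs for its own slice of completion tokens plus the
# matching rank-local completion_mask, so downstream masked reductions such as
#     loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
# operate on rank-local tokens; a rank whose slice is prompt-only gets back
# empty tensors.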
# pylint: disable=unused-argument
@profiling_decorator
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
@@ -1029,103 +1107,21 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
inputs["completion_ids"],
inputs["completion_mask"],
)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1)
if self.args.sequence_parallel_degree > 1:
# Pad sequence to be divisible by SP degree if needed
total_seq_len = input_ids.shape[1]
if total_seq_len % self.local_world_size != 0:
pad_len = self.local_world_size - (
total_seq_len % self.local_world_size
)
pad_token_id = self.processing_class.pad_token_id or 0
# Pad input_ids and attention_mask
padding = torch.full(
(input_ids.shape[0], pad_len),
pad_token_id,
dtype=input_ids.dtype,
device=input_ids.device,
)
input_ids = torch.cat([input_ids, padding], dim=1)
attn_padding = torch.zeros(
(attention_mask.shape[0], pad_len),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
total_seq_len += pad_len
logits_to_keep += pad_len
# Split the sequence
slice_size = total_seq_len // self.local_world_size
start = self.local_rank * slice_size
end = start + slice_size
# Get our slice
input_ids_slice = input_ids[:, start:end]
attention_mask_slice = attention_mask[:, start:end]
# Calculate how many completion tokens each rank should process
prompt_len = prompt_ids.size(1)
completion_len = completion_ids.size(1) # This is equal to logits_to_keep
# Calculate where our slice starts and ends relative to the completion tokens
if start >= prompt_len:
# Slice starts within the completion section
start_in_completion = start - prompt_len
end_in_completion = min(end - prompt_len, completion_len)
local_logits_to_keep = end_in_completion - start_in_completion
completion_mask = completion_mask[
:, start_in_completion:end_in_completion
]
elif end <= prompt_len:
# Slice is entirely within the prompt section (no completion tokens)
local_logits_to_keep = 0
completion_mask = torch.zeros(
(completion_mask.size(0), 0), device=completion_mask.device
)
else:
# Slice contains the boundary between prompt and completion
start_in_completion = 0
end_in_completion = min(end - prompt_len, completion_len)
local_logits_to_keep = end_in_completion - start_in_completion
completion_mask = completion_mask[
:, start_in_completion:end_in_completion
]
# Run model on our slice
if local_logits_to_keep > 0:
# Get logits with enough context to compute log probs
logits = model(
input_ids=input_ids_slice,
attention_mask=attention_mask_slice,
logits_to_keep=local_logits_to_keep + 1,
).logits
# Remove the last prediction token on the last rank
# if self.local_rank == self.local_world_size - 1:
# logits = logits[:, :-1, :]
logits = logits[:, :-1, :]
# Compute log probabilities on our local slice
local_input_ids = input_ids_slice[:, -local_logits_to_keep:]
logits = logits / self.temperature
per_token_logps = selective_log_softmax(logits, local_input_ids)
else:
# This rank doesn't have any tokens to keep
per_token_logps = torch.zeros(
(input_ids.shape[0], 0),
dtype=torch.float32,
device=input_ids.device,
)
per_token_logps, completion_mask = self._get_per_token_logps_v2(
model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
completion_mask,
)
else:
per_token_logps = super()._get_per_token_logps(
model, input_ids, attention_mask, logits_to_keep
model, prompt_completion_ids, attention_mask, logits_to_keep
)
# Compute the KL divergence between the model and the reference model
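# For reference (assumption based on the upstream TRL GRPO trainer, not code in
# this diff): the per-token KL referred to here is usually the k3 estimator
#     per_token_kl = torch.exp(ref_per_token_logps - per_token_logps)
#                    - (ref_per_token_logps - per_token_logps) - 1
# which is non-negative and estimates the policy-vs-reference KL per token.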
@@ -1155,17 +1151,6 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
if self.beta != 0.0:
per_token_loss = per_token_loss + self.beta * per_token_kl
if dist.get_rank() == 0:
import ipdb
ipdb.set_trace()
dist.barrier()
if dist.get_rank() == 1:
import ipdb
ipdb.set_trace()
dist.barrier()
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
# Log metrics
@@ -1184,155 +1169,3 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
)
return loss
# def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
# # if self.args.sequence_parallel_degree > 1:
# if False:
# # Handle padding to make sequence length divisible by world size
# total_seq_len = input_ids.shape[1]
# if total_seq_len % self.local_world_size != 0:
# # Pad to make divisible
# pad_len = self.local_world_size - (
# total_seq_len % self.local_world_size
# )
# pad_token_id = self.processing_class.pad_token_id or 0
# # Pad input_ids
# padding = torch.full(
# (input_ids.shape[0], pad_len),
# pad_token_id,
# dtype=input_ids.dtype,
# device=input_ids.device,
# )
# input_ids = torch.cat([input_ids, padding], dim=1)
# # Pad attention mask
# if attention_mask is not None:
# attn_padding = torch.zeros(
# (attention_mask.shape[0], pad_len),
# dtype=attention_mask.dtype,
# device=attention_mask.device,
# )
# attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
# total_seq_len += pad_len
# logits_to_keep += pad_len
# # Share logits_to_keep across ranks to ensure consistency
# lt_keep = torch.tensor([logits_to_keep], device=input_ids.device)
# dist.broadcast(lt_keep, src=0, group=self.sp_group)
# logits_to_keep = lt_keep.item()
# # Split the sequence across ranks
# slice_size = total_seq_len // self.local_world_size
# start = self.local_rank * slice_size
# end = start + slice_size
# # Slice for this rank
# input_ids_slice = input_ids[:, start:end]
# attention_mask_slice = (
# attention_mask[:, start:end] if attention_mask is not None else None
# )
# # Calculate how many tokens this rank needs to keep
# tokens_before_slice = self.local_rank * slice_size
# local_logits_to_keep = 0
# if tokens_before_slice < logits_to_keep:
# # This rank has tokens we need to keep
# local_logits_to_keep = min(
# slice_size, logits_to_keep - tokens_before_slice
# )
# # Run the model on our slice
# if local_logits_to_keep > 0:
# logits = model(
# input_ids=input_ids_slice,
# attention_mask=attention_mask_slice,
# logits_to_keep=local_logits_to_keep + 1,
# ).logits
# if self.local_rank == self.local_world_size - 1:
# logits = logits[:, :-1, :]
# # Get the relevant input_ids for computing log probs
# # Ensure this is the correct slice that corresponds to the logits
# relevant_input_ids = input_ids_slice[:, -local_logits_to_keep:]
# else:
# # Create empty logits with correct shape if we don't need to keep any
# vocab_size = model.config.vocab_size
# logits = torch.zeros(
# (input_ids.shape[0], 0, vocab_size),
# dtype=torch.float32,
# device=input_ids.device,
# )
# relevant_input_ids = torch.zeros(
# (input_ids.shape[0], 0),
# dtype=torch.float32,
# device=input_ids.device,
# )
# # Temperature scaling
# logits = logits / self.temperature
# print(f"{dist.get_rank()}: logits.shape: {logits.shape}")
# print(
# f"{dist.get_rank()}: relevant_input_ids.shape: {relevant_input_ids.shape}"
# )
# return selective_log_softmax(logits, relevant_input_ids)
# # All-gather results across SP group with proper shape handling
# # local_shape = torch.tensor([logits.shape[1]], device=logits.device)
# # all_shapes = [
# # torch.zeros_like(local_shape) for _ in range(self.local_world_size)
# # ]
# # dist.all_gather(all_shapes, local_shape, group=self.sp_group)
# # # Create full tensor to hold the complete result
# # full_logits = torch.zeros(
# # (input_ids.shape[0], logits_to_keep, model.config.vocab_size),
# # dtype=torch.float32,
# # device=input_ids.device,
# # )
# # # Calculate positions for each rank's contribution
# # position = 0
# # for i in range(self.local_world_size):
# # shape = all_shapes[i].item()
# # if i < self.local_rank:
# # position += shape
# # # Add our contribution to the full tensor
# # if local_logits_to_keep > 0:
# # # Make sure we're not exceeding bounds
# # end_pos = min(position + local_logits_to_keep, logits_to_keep)
# # copy_size = end_pos - position
# # if dist.get_rank() == 0:
# # import ipdb; ipdb.set_trace()
# # dist.barrier()
# # if dist.get_rank() == 1:
# # import ipdb; ipdb.set_trace()
# # dist.barrier()
# # if copy_size > 0:
# # full_logits[:, position:end_pos, :] = logits[:, :copy_size, :]
# # # Combine results via all-reduce
# # dist.all_reduce(full_logits, op=dist.ReduceOp.SUM, group=self.sp_group)
# # # Remove the last prediction token
# # full_logits = full_logits[:, :-1, :]
# # # Get the relevant input_ids for computing log probs
# # # Ensure this is the correct slice that corresponds to the logits
# # relevant_input_ids = input_ids[:, -logits_to_keep:]
# # # Temperature scaling
# # full_logits = full_logits / self.temperature
# # return selective_log_softmax(full_logits, relevant_input_ids)
# else:
# return super()._get_per_token_logps(
# model, input_ids, attention_mask, logits_to_keep
# )