progress
This commit is contained in:
@@ -617,9 +617,6 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
||||
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
||||
|
||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
|
||||
# Generate completions using either vLLM or regular generation
|
||||
if self.args.use_vllm:
|
||||
# First, have main process load weights if needed
|
||||
@@ -627,6 +624,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
self._move_model_to_vllm()
|
||||
self._last_loaded_step = self.state.global_step
|
||||
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
if self.accelerator.is_main_process:
|
||||
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
|
||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||
@@ -648,7 +646,9 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
guided_decoding_regex=self.guided_decoding_regex,
|
||||
)
|
||||
else:
|
||||
completion_ids = [None] * len(all_prompts_text)
|
||||
completion_ids = [None] * (
|
||||
len(all_prompts_text) // self.args.sequence_parallel_degree
|
||||
)
|
||||
|
||||
# Broadcast the completions from the main process to all processes
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
@@ -719,6 +719,75 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
) # we only need to compute the logits for the completion tokens
|
||||
|
||||
with torch.no_grad():
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Pad sequence to be divisible by SP degree if needed
|
||||
total_seq_len = prompt_completion_ids.shape[1]
|
||||
if total_seq_len % self.local_world_size != 0:
|
||||
pad_len = self.local_world_size - (
|
||||
total_seq_len % self.local_world_size
|
||||
)
|
||||
pad_token_id = self.processing_class.pad_token_id or 0
|
||||
|
||||
# Pad input_ids and attention_mask
|
||||
padding = torch.full(
|
||||
(prompt_completion_ids.shape[0], pad_len),
|
||||
pad_token_id,
|
||||
dtype=prompt_completion_ids.dtype,
|
||||
device=prompt_completion_ids.device,
|
||||
)
|
||||
prompt_completion_ids = torch.cat(
|
||||
[prompt_completion_ids, padding], dim=1
|
||||
)
|
||||
|
||||
attn_padding = torch.zeros(
|
||||
(attention_mask.shape[0], pad_len),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
|
||||
|
||||
total_seq_len += pad_len
|
||||
logits_to_keep += pad_len
|
||||
|
||||
# Split the sequence
|
||||
slice_size = total_seq_len // self.local_world_size
|
||||
start = self.local_rank * slice_size
|
||||
end = start + slice_size
|
||||
|
||||
# Get our slice
|
||||
prompt_completion_ids = prompt_completion_ids[:, start:end]
|
||||
attention_mask = attention_mask[:, start:end]
|
||||
|
||||
# Calculate how many completion tokens each rank should process
|
||||
prompt_len = prompt_ids.size(1)
|
||||
completion_len = completion_ids.size(
|
||||
1
|
||||
) # This is equal to logits_to_keep
|
||||
|
||||
# Calculate where our slice starts and ends relative to the completion tokens
|
||||
if start >= prompt_len:
|
||||
# Slice starts within the completion section
|
||||
start_in_completion = start - prompt_len
|
||||
end_in_completion = min(end - prompt_len, completion_len)
|
||||
logits_to_keep = end_in_completion - start_in_completion
|
||||
completion_mask = completion_mask[
|
||||
:, start_in_completion:end_in_completion
|
||||
]
|
||||
elif end <= prompt_len:
|
||||
# Slice is entirely within the prompt section (no completion tokens)
|
||||
logits_to_keep = 0
|
||||
completion_mask = torch.zeros(
|
||||
(completion_mask.size(0), 0), device=completion_mask.device
|
||||
)
|
||||
else:
|
||||
# Slice contains the boundary between prompt and completion
|
||||
start_in_completion = 0
|
||||
end_in_completion = min(end - prompt_len, completion_len)
|
||||
logits_to_keep = end_in_completion - start_in_completion
|
||||
completion_mask = completion_mask[
|
||||
:, start_in_completion:end_in_completion
|
||||
]
|
||||
|
||||
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
|
||||
# computation here, and use per_token_logps.detach() instead.
|
||||
if self.num_iterations > 1:
|
||||
@@ -731,6 +800,10 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
if self.beta == 0.0:
|
||||
ref_per_token_logps = None
|
||||
elif self.ref_model is not None:
|
||||
print(f"{dist.get_rank()}: prompt_completion_ids.shape: {prompt_completion_ids.shape}")
|
||||
print(f"{dist.get_rank()}: attention_mask.shape: {attention_mask.shape}")
|
||||
print(f"{dist.get_rank()}: logits_to_keep: {logits_to_keep}")
|
||||
|
||||
ref_per_token_logps = self._get_per_token_logps(
|
||||
self.ref_model,
|
||||
prompt_completion_ids,
|
||||
@@ -943,132 +1016,323 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
"advantages": advantages,
|
||||
}
|
||||
|
||||
# Get the per-token log probabilities for the completions for the model and the reference model
|
||||
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
print(f"{self.local_rank}: input_ids.shape: {input_ids.shape}")
|
||||
print(f"{self.local_rank}: input_ids[0, :20]: {input_ids[0, :20]}")
|
||||
print(f"{self.local_rank}: input_ids[0, -20:]: {input_ids[0, -20:]}")
|
||||
@profiling_decorator
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
if return_outputs:
|
||||
raise ValueError("The GRPOTrainer does not support returning outputs")
|
||||
|
||||
# Pad sequence if needed
|
||||
# Unpack inputs
|
||||
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
||||
completion_ids, completion_mask = (
|
||||
inputs["completion_ids"],
|
||||
inputs["completion_mask"],
|
||||
)
|
||||
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
||||
logits_to_keep = completion_ids.size(1)
|
||||
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Pad sequence to be divisible by SP degree if needed
|
||||
total_seq_len = input_ids.shape[1]
|
||||
remainder = total_seq_len % self.local_world_size
|
||||
if remainder != 0:
|
||||
to_pad = self.local_world_size - remainder
|
||||
if total_seq_len % self.local_world_size != 0:
|
||||
pad_len = self.local_world_size - (
|
||||
total_seq_len % self.local_world_size
|
||||
)
|
||||
pad_token_id = self.processing_class.pad_token_id or 0
|
||||
|
||||
# Pad input_ids and attention_mask
|
||||
padding = torch.full(
|
||||
(input_ids.shape[0], to_pad),
|
||||
(input_ids.shape[0], pad_len),
|
||||
pad_token_id,
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
input_ids = torch.cat([input_ids, padding], dim=1)
|
||||
|
||||
# Also pad attention mask if it exists
|
||||
if attention_mask is not None:
|
||||
attn_padding = torch.zeros(
|
||||
(attention_mask.shape[0], to_pad),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
|
||||
attn_padding = torch.zeros(
|
||||
(attention_mask.shape[0], pad_len),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
|
||||
|
||||
# Update total_seq_len after padding
|
||||
total_seq_len += to_pad
|
||||
total_seq_len += pad_len
|
||||
logits_to_keep += pad_len
|
||||
|
||||
# Get local (start, end) for sequence parallelism slicing
|
||||
# Split the sequence
|
||||
slice_size = total_seq_len // self.local_world_size
|
||||
start = self.local_rank * slice_size
|
||||
end = start + slice_size
|
||||
|
||||
# Slice data for sequence parallel processing
|
||||
input_ids = input_ids[:, start:end]
|
||||
attention_mask = attention_mask[:, start:end]
|
||||
# Get our slice
|
||||
input_ids_slice = input_ids[:, start:end]
|
||||
attention_mask_slice = attention_mask[:, start:end]
|
||||
|
||||
# Calculate if this rank contains any tokens we need to keep
|
||||
tokens_before_our_slice = self.local_rank * slice_size
|
||||
print(f"{self.local_rank}: slice_size: {slice_size}")
|
||||
print(
|
||||
f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}"
|
||||
)
|
||||
if tokens_before_our_slice < logits_to_keep:
|
||||
# How many tokens from our slice are needed
|
||||
tokens_needed_from_slice = logits_to_keep - tokens_before_our_slice
|
||||
logits_to_keep = min(slice_size, tokens_needed_from_slice)
|
||||
else:
|
||||
# This rank doesn't contain any tokens we need to keep
|
||||
logits_to_keep = 0
|
||||
# Calculate how many completion tokens each rank should process
|
||||
prompt_len = prompt_ids.size(1)
|
||||
completion_len = completion_ids.size(1) # This is equal to logits_to_keep
|
||||
|
||||
print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}")
|
||||
|
||||
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
logits_to_keep=logits_to_keep + 1,
|
||||
).logits
|
||||
logits = logits[
|
||||
:, :-1, :
|
||||
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
||||
|
||||
print(f"{self.local_rank}: logits.shape: {logits.shape}")
|
||||
|
||||
# First, let all ranks know the shape of each rank's tensor
|
||||
local_shape = torch.tensor(
|
||||
[logits.shape[0], logits.shape[1], logits.shape[2]],
|
||||
device=logits.device,
|
||||
)
|
||||
all_shapes = [
|
||||
torch.zeros_like(local_shape) for _ in range(self.local_world_size)
|
||||
]
|
||||
dist.all_gather(all_shapes, local_shape, group=self.sp_group)
|
||||
|
||||
# Use a list-based approach to collect logits of different sizes
|
||||
if self.local_rank == 0:
|
||||
# Root process allocates space for receiving
|
||||
gathered_logits = []
|
||||
for shape in all_shapes:
|
||||
b, s, v = shape.tolist()
|
||||
gathered_logits.append(
|
||||
torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device)
|
||||
)
|
||||
else:
|
||||
gathered_logits = None
|
||||
|
||||
# Gather to rank 0
|
||||
dist.gather(logits, gathered_logits, dst=0, group=self.sp_group)
|
||||
|
||||
# On rank 0, concatenate and distribute the result
|
||||
if self.local_rank == 0:
|
||||
concatenated_logits = torch.cat(gathered_logits, dim=1)
|
||||
# Trim to keep only what we need
|
||||
if concatenated_logits.shape[1] > logits_to_keep:
|
||||
concatenated_logits = concatenated_logits[:, -logits_to_keep:, :]
|
||||
else:
|
||||
concatenated_logits = torch.zeros(
|
||||
(logits.shape[0], logits_to_keep, logits.shape[2]),
|
||||
dtype=logits.dtype,
|
||||
device=logits.device,
|
||||
# Calculate where our slice starts and ends relative to the completion tokens
|
||||
if start >= prompt_len:
|
||||
# Slice starts within the completion section
|
||||
start_in_completion = start - prompt_len
|
||||
end_in_completion = min(end - prompt_len, completion_len)
|
||||
local_logits_to_keep = end_in_completion - start_in_completion
|
||||
completion_mask = completion_mask[
|
||||
:, start_in_completion:end_in_completion
|
||||
]
|
||||
elif end <= prompt_len:
|
||||
# Slice is entirely within the prompt section (no completion tokens)
|
||||
local_logits_to_keep = 0
|
||||
completion_mask = torch.zeros(
|
||||
(completion_mask.size(0), 0), device=completion_mask.device
|
||||
)
|
||||
else:
|
||||
# Slice contains the boundary between prompt and completion
|
||||
start_in_completion = 0
|
||||
end_in_completion = min(end - prompt_len, completion_len)
|
||||
local_logits_to_keep = end_in_completion - start_in_completion
|
||||
completion_mask = completion_mask[
|
||||
:, start_in_completion:end_in_completion
|
||||
]
|
||||
|
||||
# Broadcast the result back to all ranks
|
||||
dist.broadcast(concatenated_logits, src=0, group=self.sp_group)
|
||||
logits = concatenated_logits
|
||||
# Run model on our slice
|
||||
if local_logits_to_keep > 0:
|
||||
# Get logits with enough context to compute log probs
|
||||
logits = model(
|
||||
input_ids=input_ids_slice,
|
||||
attention_mask=attention_mask_slice,
|
||||
logits_to_keep=local_logits_to_keep + 1,
|
||||
).logits
|
||||
|
||||
input_ids = input_ids[:, -logits_to_keep:]
|
||||
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
|
||||
# See https://github.com/huggingface/trl/issues/2770
|
||||
logits = logits[:, -logits_to_keep:]
|
||||
# Divide logits by sampling temperature.
|
||||
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
|
||||
logits = logits / self.temperature
|
||||
# Remove the last prediction token on the last rank
|
||||
# if self.local_rank == self.local_world_size - 1:
|
||||
# logits = logits[:, :-1, :]
|
||||
logits = logits[:, :-1, :]
|
||||
|
||||
dist.barrier()
|
||||
|
||||
return selective_log_softmax(
|
||||
logits, input_ids
|
||||
) # compute logprobs for the input tokens
|
||||
# Compute log probabilities on our local slice
|
||||
local_input_ids = input_ids_slice[:, -local_logits_to_keep:]
|
||||
logits = logits / self.temperature
|
||||
per_token_logps = selective_log_softmax(logits, local_input_ids)
|
||||
else:
|
||||
# This rank doesn't have any tokens to keep
|
||||
per_token_logps = torch.zeros(
|
||||
(input_ids.shape[0], 0),
|
||||
dtype=torch.float32,
|
||||
device=input_ids.device,
|
||||
)
|
||||
else:
|
||||
super()._get_per_token_logps(
|
||||
per_token_logps = super()._get_per_token_logps(
|
||||
model, input_ids, attention_mask, logits_to_keep
|
||||
)
|
||||
|
||||
# Compute the KL divergence between the model and the reference model
|
||||
if self.beta != 0.0:
|
||||
ref_per_token_logps = inputs["ref_per_token_logps"]
|
||||
per_token_kl = (
|
||||
torch.exp(ref_per_token_logps - per_token_logps)
|
||||
- (ref_per_token_logps - per_token_logps)
|
||||
- 1
|
||||
)
|
||||
|
||||
# Compute the loss
|
||||
advantages = inputs["advantages"]
|
||||
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
|
||||
# and use per_token_logps.detach() instead.
|
||||
old_per_token_logps = (
|
||||
inputs["old_per_token_logps"]
|
||||
if self.num_iterations > 1
|
||||
else per_token_logps.detach()
|
||||
)
|
||||
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
|
||||
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
|
||||
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
|
||||
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
|
||||
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
|
||||
|
||||
if self.beta != 0.0:
|
||||
per_token_loss = per_token_loss + self.beta * per_token_kl
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
dist.barrier()
|
||||
if dist.get_rank() == 1:
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
dist.barrier()
|
||||
|
||||
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
|
||||
|
||||
# Log metrics
|
||||
mode = "eval" if self.control.should_evaluate else "train"
|
||||
|
||||
if self.beta != 0.0:
|
||||
mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
|
||||
self._metrics[mode]["kl"].append(
|
||||
self.accelerator.gather_for_metrics(mean_kl).mean().item()
|
||||
)
|
||||
|
||||
is_clipped = (per_token_loss1 < per_token_loss2).float()
|
||||
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
|
||||
self._metrics[mode]["clip_ratio"].append(
|
||||
self.accelerator.gather_for_metrics(clip_ratio).mean().item()
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
# def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
||||
# # if self.args.sequence_parallel_degree > 1:
|
||||
# if False:
|
||||
# # Handle padding to make sequence length divisible by world size
|
||||
# total_seq_len = input_ids.shape[1]
|
||||
# if total_seq_len % self.local_world_size != 0:
|
||||
# # Pad to make divisible
|
||||
# pad_len = self.local_world_size - (
|
||||
# total_seq_len % self.local_world_size
|
||||
# )
|
||||
# pad_token_id = self.processing_class.pad_token_id or 0
|
||||
|
||||
# # Pad input_ids
|
||||
# padding = torch.full(
|
||||
# (input_ids.shape[0], pad_len),
|
||||
# pad_token_id,
|
||||
# dtype=input_ids.dtype,
|
||||
# device=input_ids.device,
|
||||
# )
|
||||
# input_ids = torch.cat([input_ids, padding], dim=1)
|
||||
|
||||
# # Pad attention mask
|
||||
# if attention_mask is not None:
|
||||
# attn_padding = torch.zeros(
|
||||
# (attention_mask.shape[0], pad_len),
|
||||
# dtype=attention_mask.dtype,
|
||||
# device=attention_mask.device,
|
||||
# )
|
||||
# attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
|
||||
|
||||
# total_seq_len += pad_len
|
||||
# logits_to_keep += pad_len
|
||||
|
||||
# # Share logits_to_keep across ranks to ensure consistency
|
||||
# lt_keep = torch.tensor([logits_to_keep], device=input_ids.device)
|
||||
# dist.broadcast(lt_keep, src=0, group=self.sp_group)
|
||||
# logits_to_keep = lt_keep.item()
|
||||
|
||||
# # Split the sequence across ranks
|
||||
# slice_size = total_seq_len // self.local_world_size
|
||||
# start = self.local_rank * slice_size
|
||||
# end = start + slice_size
|
||||
|
||||
# # Slice for this rank
|
||||
# input_ids_slice = input_ids[:, start:end]
|
||||
# attention_mask_slice = (
|
||||
# attention_mask[:, start:end] if attention_mask is not None else None
|
||||
# )
|
||||
|
||||
# # Calculate how many tokens this rank needs to keep
|
||||
# tokens_before_slice = self.local_rank * slice_size
|
||||
# local_logits_to_keep = 0
|
||||
|
||||
# if tokens_before_slice < logits_to_keep:
|
||||
# # This rank has tokens we need to keep
|
||||
# local_logits_to_keep = min(
|
||||
# slice_size, logits_to_keep - tokens_before_slice
|
||||
# )
|
||||
|
||||
# # Run the model on our slice
|
||||
# if local_logits_to_keep > 0:
|
||||
# logits = model(
|
||||
# input_ids=input_ids_slice,
|
||||
# attention_mask=attention_mask_slice,
|
||||
# logits_to_keep=local_logits_to_keep + 1,
|
||||
# ).logits
|
||||
# if self.local_rank == self.local_world_size - 1:
|
||||
# logits = logits[:, :-1, :]
|
||||
|
||||
# # Get the relevant input_ids for computing log probs
|
||||
# # Ensure this is the correct slice that corresponds to the logits
|
||||
# relevant_input_ids = input_ids_slice[:, -local_logits_to_keep:]
|
||||
# else:
|
||||
# # Create empty logits with correct shape if we don't need to keep any
|
||||
# vocab_size = model.config.vocab_size
|
||||
# logits = torch.zeros(
|
||||
# (input_ids.shape[0], 0, vocab_size),
|
||||
# dtype=torch.float32,
|
||||
# device=input_ids.device,
|
||||
# )
|
||||
# relevant_input_ids = torch.zeros(
|
||||
# (input_ids.shape[0], 0),
|
||||
# dtype=torch.float32,
|
||||
# device=input_ids.device,
|
||||
# )
|
||||
|
||||
# # Temperature scaling
|
||||
# logits = logits / self.temperature
|
||||
|
||||
# print(f"{dist.get_rank()}: logits.shape: {logits.shape}")
|
||||
# print(
|
||||
# f"{dist.get_rank()}: relevant_input_ids.shape: {relevant_input_ids.shape}"
|
||||
# )
|
||||
|
||||
# return selective_log_softmax(logits, relevant_input_ids)
|
||||
|
||||
# # All-gather results across SP group with proper shape handling
|
||||
# # local_shape = torch.tensor([logits.shape[1]], device=logits.device)
|
||||
# # all_shapes = [
|
||||
# # torch.zeros_like(local_shape) for _ in range(self.local_world_size)
|
||||
# # ]
|
||||
# # dist.all_gather(all_shapes, local_shape, group=self.sp_group)
|
||||
|
||||
# # # Create full tensor to hold the complete result
|
||||
# # full_logits = torch.zeros(
|
||||
# # (input_ids.shape[0], logits_to_keep, model.config.vocab_size),
|
||||
# # dtype=torch.float32,
|
||||
# # device=input_ids.device,
|
||||
# # )
|
||||
|
||||
# # # Calculate positions for each rank's contribution
|
||||
# # position = 0
|
||||
# # for i in range(self.local_world_size):
|
||||
# # shape = all_shapes[i].item()
|
||||
# # if i < self.local_rank:
|
||||
# # position += shape
|
||||
|
||||
# # # Add our contribution to the full tensor
|
||||
# # if local_logits_to_keep > 0:
|
||||
# # # Make sure we're not exceeding bounds
|
||||
# # end_pos = min(position + local_logits_to_keep, logits_to_keep)
|
||||
# # copy_size = end_pos - position
|
||||
|
||||
# # if dist.get_rank() == 0:
|
||||
# # import ipdb; ipdb.set_trace()
|
||||
# # dist.barrier()
|
||||
# # if dist.get_rank() == 1:
|
||||
# # import ipdb; ipdb.set_trace()
|
||||
# # dist.barrier()
|
||||
|
||||
# # if copy_size > 0:
|
||||
# # full_logits[:, position:end_pos, :] = logits[:, :copy_size, :]
|
||||
|
||||
# # # Combine results via all-reduce
|
||||
# # dist.all_reduce(full_logits, op=dist.ReduceOp.SUM, group=self.sp_group)
|
||||
|
||||
# # # Remove the last prediction token
|
||||
# # full_logits = full_logits[:, :-1, :]
|
||||
|
||||
# # # Get the relevant input_ids for computing log probs
|
||||
# # # Ensure this is the correct slice that corresponds to the logits
|
||||
# # relevant_input_ids = input_ids[:, -logits_to_keep:]
|
||||
|
||||
# # # Temperature scaling
|
||||
# # full_logits = full_logits / self.temperature
|
||||
|
||||
# # return selective_log_softmax(full_logits, relevant_input_ids)
|
||||
# else:
|
||||
# return super()._get_per_token_logps(
|
||||
# model, input_ids, attention_mask, logits_to_keep
|
||||
# )
|
||||
|
||||
Reference in New Issue
Block a user