Compare commits

..

3 Commits

Author            SHA1        Message                              Date
Salman Mohammadi  1a09d5e844  some refactoring                     2025-02-19 17:35:35 +00:00
Salman Mohammadi  cf61b4aba7  Merge branch 'main' into grpo_liger  2025-02-19 16:17:42 +00:00
Salman Mohammadi  14d274efe6  WIP liger support                    2025-02-19 15:34:32 +00:00
14 changed files with 283 additions and 215 deletions

View File

@@ -407,10 +407,7 @@ save_total_limit: # Checkpoints saved at a time
max_steps:
# bool of whether to include tokens trained per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: # Optional[bool]
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]
include_tokens_per_second:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

View File

@@ -13,12 +13,12 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.49.0
transformers==4.48.3
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.15.1
trl==0.15.0
optimum==1.16.2
hf_transfer

View File

@@ -831,9 +831,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.flex_attention is True:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]

View File

@@ -9,6 +9,7 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
LOG = logging.getLogger("axolotl")
@@ -30,31 +31,21 @@ class GRPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
training_kwargs = [
"use_vllm",
"vllm_device",
"vllm_gpu_memory_utilization",
"vllm_max_model_len",
"vllm_dtype",
"use_liger_loss",
"num_generations",
"log_completions",
"sync_ref_model",
"ref_model_mixup_alpha",
"ref_model_sync_steps",
"max_completion_length",
]
grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]}
return grpo_args_kwargs
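The refactor above replaces the chain of per-key `if` statements with a single dict comprehension over an allow-list. A minimal standalone sketch of the pattern (names here are illustrative, not axolotl's); note that filtering on truthiness silently drops keys explicitly set to `False` or `0`, so it only behaves like the old code when falsy means "unset":

cfg_trl = {"use_vllm": True, "num_generations": 4, "log_completions": False}
training_kwargs = ["use_vllm", "num_generations", "log_completions"]

# Truthiness filter: log_completions=False is dropped, same as if it were unset.
kwargs = {k: cfg_trl[k] for k in training_kwargs if cfg_trl[k]}
assert kwargs == {"use_vllm": True, "num_generations": 4}

# A None-check would keep explicit False/0 values instead.
kwargs = {k: cfg_trl[k] for k in training_kwargs if cfg_trl.get(k) is not None}
assert kwargs == cfg_trl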
@classmethod
@@ -71,9 +62,7 @@ class GRPOStrategy:
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
] = cfg.trl.reward_processing_classes
trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes
return trainer_kwargs
@classmethod

View File

@@ -13,3 +13,4 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""
Axolotl GRPO Config for GRPO training
"""
use_liger_loss: bool = False

View File

@@ -1,13 +1,24 @@
"""
Axolotl GRPO trainer
"""
from contextlib import contextmanager, nullcontext
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
import torch
from transformers import PreTrainedModel
from trl import GRPOConfig, GRPOTrainer
from trl.models import unwrap_model_for_generation
from axolotl.core.trainers.base import SchedulerMixin
from transformers.utils import is_liger_kernel_available
if is_liger_kernel_available():
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from accelerate.utils import broadcast_object_list, gather_object
from trl.trainer.utils import pad
# mypy: ignore-errors
@@ -20,7 +31,20 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.use_liger_loss = kwargs["args"].use_liger_loss
if self.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
self.grpo_loss_fn = LigerFusedLinearGRPOLoss(
beta=self.beta,
compiled=is_compiled_module(self.model),
use_ref_model=True,
num_generations=self.args.num_generations,
)
# pylint: disable=access-member-before-definition
# Enable gradient checkpointing if requested
if kwargs["args"].gradient_checkpointing:
@@ -29,9 +53,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self.model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(self.model) and hasattr(
self.model.base_model, "gradient_checkpointing_enable"
):
if is_peft_model(self.model) and hasattr(self.model.base_model, "gradient_checkpointing_enable"):
self.model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
elif hasattr(self.model, "gradient_checkpointing_enable"):
@@ -39,15 +61,12 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition
def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel:
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# pylint: disable=unused-argument,redefined-builtin
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs
or gradient_checkpointing_kwargs["use_reentrant"]
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)
if use_reentrant:
@@ -58,9 +77,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(
make_inputs_require_grad
)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
return model
# pylint: enable=unused-argument,redefined-builtin
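The forward hook registered above matters because, with reentrant checkpointing and frozen base weights (the PEFT case), no input to the checkpointed segment requires grad, so backward through the segment breaks. A minimal sketch of the mechanism, independent of axolotl:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-ins: a frozen embedding (as under PEFT) feeding a checkpointed layer.
emb = nn.Embedding(10, 4)
emb.weight.requires_grad_(False)
block = nn.Linear(4, 4)

def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

emb.register_forward_hook(make_inputs_require_grad)

x = emb(torch.tensor([1, 2, 3]))       # requires_grad=True thanks to the hook
y = checkpoint(block, x, use_reentrant=True)
y.sum().backward()
assert block.weight.grad is not None   # gradient flows through the checkpointed segment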
@@ -72,25 +89,18 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = (
unwrapped_model._orig_mod # pylint: disable=protected-access
)
unwrapped_model = unwrapped_model._orig_mod # pylint: disable=protected-access
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", ""): v
k.removeprefix("base_model.model.").removeprefix("base_model.model.").replace(".base_layer", ""): v
for k, v in state_dict.items()
}
# Remove values with adapter prefix (example: "_lora")
state_dict = {
k: v
for k, v in state_dict.items()
if unwrapped_model.prefix not in k
}
state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
# When module to save, remove its prefix and discard the original module
state_dict = {
    k.replace("modules_to_save.default.", ""): v
    for k, v in state_dict.items()
    if "original_module" not in k
}
@@ -99,10 +109,218 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
    llm_model = (
        self.llm.llm_engine.model_executor.driver_worker.model_runner.model
    )
    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
    llm_model.load_weights(state_dict.items())
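The prefix-stripping above converts PEFT-wrapped parameter names back to base-model names so vLLM can load the merged weights. A toy illustration of the mapping (key names hypothetical; `unwrapped_model.prefix` is assumed to be peft's "lora_" for LoRA models):

state_dict = {
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": "W",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": "A",
}
prefix = "lora_"
cleaned = {
    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
    for k, v in state_dict.items()
    if prefix not in k  # drop adapter-only weights such as lora_A/lora_B
}
assert cleaned == {"model.layers.0.self_attn.q_proj.weight": "W"}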
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if self.use_liger_loss:
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
if self.max_prompt_length is not None:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
else:
completion_ids = [None] * len(all_prompts_text) * self.num_generations
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts) * self.num_generations,
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
)
completion_ids = completion_ids[process_slice]
# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_inputs_repeated = torch.repeat_interleave(
prompt_inputs["input_ids"], self.num_generations, dim=0
)
prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
**prompt_inputs, generation_config=self.generation_config
)
prompt_length = prompt_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]
# Get the per-token log probabilities for the completions for the model and the reference model
def get_per_token_logps(model, input_ids, num_logits_to_keep):
# We add 1 to `num_logits_to_keep` because the last logit of the sequence is later excluded
outputs = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1)
hidden_states = outputs.last_hidden_state[:, :-1]
logits = outputs.logits # (B, L, V)
logits = logits[
:, :-1, :
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps), hidden_states
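The gather trick above pulls out, for each position, the log-probability of the token that actually appears next. A tiny self-contained sketch of one row:

import torch

logits = torch.randn(1, 4, 10)                # (batch, seq, vocab)
input_ids = torch.tensor([[2, 7, 1, 3]])
num_logits_to_keep = 3                        # score only the last 3 tokens

row_logits = logits[0, :-1]                   # drop the last position: it predicts a token we don't have
log_probs = row_logits.log_softmax(dim=-1)    # (3, vocab)
targets = input_ids[0, -num_logits_to_keep:]  # tokens 7, 1, 3
token_log_prob = torch.gather(log_probs, dim=1, index=targets.unsqueeze(1)).squeeze(1)
print(token_log_prob.shape)                   # torch.Size([3])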
num_logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
per_token_logps, hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
self.ref_model, prompt_completion_ids, num_logits_to_keep
)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
model, prompt_completion_ids, num_logits_to_keep
)
# done in liger
# Compute the KL divergence between the model and the reference model
# per_token_kl = (
# torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# )
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
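The masking above keeps every token up to and including the first EOS and zeroes everything after it; rows with no EOS stay fully unmasked. A worked example with a toy EOS id:

import torch

eos_token_id = 2
completion_ids = torch.tensor([[5, 2, 9, 9],    # EOS at index 1
                               [7, 8, 9, 4]])   # no EOS
is_eos = completion_ids == eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
print(completion_mask)  # tensor([[1, 1, 0, 0], [1, 1, 1, 1]])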
# Decode the generated completions
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
# Compute the rewards
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, PreTrainedModel):
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
for key in reward_kwargs:
for example in inputs:
# Repeat each value in the column for `num_generations` times
reward_kwargs[key].extend([example[key]] * self.num_generations)
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)
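For the non-model branch, a reward function is any callable taking `prompts`, `completions`, and the remaining dataset columns as keyword lists, returning one float per completion. A minimal sketch matching the call above (the function itself is hypothetical):

def length_penalty_reward(prompts, completions, **reward_kwargs):
    """Toy reward: prefer shorter completions, one float per (prompt, completion) pair."""
    return [-float(len(c)) for c in completions]

rewards = length_penalty_reward(prompts=["p"] * 4, completions=["ok", "okay", "fine", "longer one"])
print(rewards)  # [-2.0, -4.0, -4.0, -10.0]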
# done in liger
# # Compute grouped-wise rewards
# mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
# std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# done in liger
# # Normalize the rewards to compute the advantages
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# done in liger
# x - x.detach() allows for preserving gradients from x
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
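Taken together, these commented-out lines spell out the objective that the fused liger loss is expected to compute in one pass over the lm_head. With r_i the summed reward for completion i, G = num_generations, m the completion mask, B the number of completions in the batch, and sg the stop-gradient implied by the x - x.detach() trick:

\hat{A}_i = \frac{r_i - \operatorname{mean}(r_{1:G})}{\operatorname{std}(r_{1:G}) + 10^{-4}}, \qquad
d_{i,t} = \log\pi_{\mathrm{ref}}(o_{i,t}) - \log\pi_\theta(o_{i,t}), \qquad
\mathbb{D}_{\mathrm{KL}}^{(i,t)} = e^{d_{i,t}} - d_{i,t} - 1

\mathcal{L} = -\frac{1}{B} \sum_i \frac{\sum_t m_{i,t} \left( e^{\log\pi_\theta(o_{i,t}) - \operatorname{sg}\left[\log\pi_\theta(o_{i,t})\right]} \hat{A}_i - \beta\, \mathbb{D}_{\mathrm{KL}}^{(i,t)} \right)}{\sum_t m_{i,t}}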
# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
lm_head = model.get_output_embeddings()
if self.ref_model is not None:
ref_lm_head = self.ref_model.get_output_embeddings()
else:
with self.null_ref_context():
ref_lm_head = model.get_output_embeddings()
ref_weight = ref_lm_head.weight
ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None
loss, metrics = self.grpo_loss_fn(
lm_head,
hidden_states, # this is the hidden states from the model
completion_mask,
rewards,
bias=lm_head.bias if hasattr(lm_head, "bias") else None,
ref_input=ref_hidden_states, # this is the hidden states from the ref model
ref_weight=ref_weight,
ref_bias=ref_bias,
)
return loss
else:
return super().compute_loss(model, inputs, return_outputs, num_items_in_batch)
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
if self.ref_adapter_name:
self.model.set_adapter(self.model_adapter_name or "default")

View File

@@ -127,8 +127,6 @@ class ReLoRACallback(TrainerCallback):
optimizer: torch.optim.Optimizer,
**_kwargs,
):
if not optimizer:
optimizer = state.optimizer
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join(
args.output_dir,

View File

@@ -95,103 +95,6 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def get_packed_mask_from_pos_ids(position_ids):
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
device = position_ids.device
results = []
for i, row in enumerate(position_ids):
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
for i, seq_len in enumerate(seq_lengths):
start_id = start_indices[i]
doc_mask[start_id : start_id + seq_len] = (
(i+1) * doc_mask[start_id : start_id + seq_len]
)
if padding_length:
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
results.append(doc_mask)
return torch.stack(results)
def get_seqlens_from_pos_ids(position_ids):
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
max_seq_len = position_ids.shape[1]
device = position_ids.device
results = []
totalseqlens = []
for row in position_ids:
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
if padding_length:
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[len(row) - torch.sum(seq_lengths)],
dtype=torch.int32,
device=device,
),
]
)
results.append(seq_lengths)
totalseqlens.append(len(adjusted_row))
return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
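Concretely, the padding and boundary detection in the removed helpers work like this on a toy packed row (position ids reset to 0 at each document boundary):

import torch

row = torch.tensor([0, 1, 2, 0, 1, 0, 0])  # two packed docs (len 3 and 2) plus 2 pad tokens

padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
adjusted_row = row[:-padding_length] if padding_length else row
seq_starts = torch.cat([torch.tensor([True]), adjusted_row[1:] == 0])
start_indices = torch.nonzero(seq_starts).squeeze(1)
print(padding_length)          # 2
print(start_indices.tolist())  # [0, 3] -> document lengths 3 and 2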
def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
@@ -273,10 +176,7 @@ def mask_2d_to_4d(
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
if len(mask.size()) == 4:
return mask
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)

View File

@@ -272,7 +272,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
dict(zip(feature_names, row))
)
for key, val in tokenized_prompt.items():
res[key].append(val)
for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
# If there are no examples left, return an empty dictionary
if not res:
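The change above splits each tokenized field into `sequence_len`-sized chunks instead of appending it whole, so an over-long example becomes several training rows rather than one row that exceeds the context. The slicing pattern on a toy list:

val = list(range(7))
sequence_len = 3
chunks = [val[i : i + sequence_len] for i in range(0, len(val), sequence_len)]
print(chunks)  # [[0, 1, 2], [3, 4, 5], [6]]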

View File

@@ -342,7 +342,6 @@ class LoraConfig(BaseModel):
peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
peft_init_lora_weights: Optional[Union[bool, str]] = None
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
@@ -823,7 +822,6 @@ class AxolotlInputConfig(
xformers_attention: Optional[bool] = None
sdp_attention: Optional[bool] = None
s2_attention: Optional[bool] = None
flex_attention: Optional[bool] = None
flash_attention: Optional[bool] = None
flash_attn_cross_entropy: Optional[bool] = None
flash_attn_rms_norm: Optional[bool] = None
@@ -1790,26 +1788,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_flex_torch_version(cls, data):
if (data.get("flex_attention") is not None) and (
data.get("flex_attention") is True
):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError(
"Flex attention is not supported on torch version < 2.5.1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_auto(cls, data):

View File

@@ -33,3 +33,4 @@ class TRLConfig(BaseModel):
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64
use_liger_loss: Optional[bool] = False

View File

@@ -172,11 +172,10 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
)
try:
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
min_input_len = np.min(ds_lengths)
LOG.info(f"min_input_len: {min_input_len}")
max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}")
min_input_len = np.min(get_dataset_lengths(dataset))
LOG.debug(f"min_input_len: {min_input_len}")
max_input_len = np.max(get_dataset_lengths(dataset))
LOG.debug(f"max_input_len: {max_input_len}")
except AttributeError:
pass

View File

@@ -403,7 +403,7 @@ class ModelLoader:
if (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.flash_attention
and self.cfg.sample_packing
):
if "auto_map" in self.model_config:
@@ -707,13 +707,7 @@ class ModelLoader:
"""
sample packing uses custom FA2 patch
"""
if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
elif self.cfg.flash_attention:
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
@@ -1119,7 +1113,7 @@ class ModelLoader:
should_convert = (
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
)
@@ -1327,8 +1321,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.")

View File

@@ -4,17 +4,13 @@ helper util to calculate dataset lengths
import numpy as np
def get_dataset_lengths(dataset, from_arrow=False):
def get_dataset_lengths(dataset):
if "length" in dataset.column_names:
lengths = np.array(dataset["length"])
elif "position_ids" in dataset.column_names:
position_ids = dataset["position_ids"]
lengths = np.array([x[-1] + 1 for x in position_ids])
else:
if from_arrow:
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
return lengths
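A quick sketch of the lookups the simplified helper performs, using the function above on toy `datasets.Dataset` objects:

from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]})
print(get_dataset_lengths(ds))  # [3 2], computed from input_ids

ds = Dataset.from_dict({"input_ids": [[1, 2, 3]], "position_ids": [[0, 1, 2]]})
print(get_dataset_lengths(ds))  # [3], last position id + 1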