Compare commits


3 Commits

Author            SHA1        Message                              Date
Salman Mohammadi  1a09d5e844  some refactoring                     2025-02-19 17:35:35 +00:00
Salman Mohammadi  cf61b4aba7  Merge branch 'main' into grpo_liger  2025-02-19 16:17:42 +00:00
Salman Mohammadi  14d274efe6  WIP liger support                    2025-02-19 15:34:32 +00:00
22 changed files with 291 additions and 202 deletions

View File

@@ -4,10 +4,6 @@ on:
   pull_request:
     paths:
       - 'tests/e2e/multigpu/*.py'
-      - 'requirements.txt'
-      - 'setup.py'
-      - 'pyproject.toml'
-      - '.github/workflows/multi-gpu-e2e.yml'
   workflow_dispatch:
   schedule:
     - cron: '0 0 * * 1,4'  # Runs at 00:00 UTC every monday & thursday

View File

@@ -37,11 +37,15 @@ temp_dir = tempfile.mkdtemp()
 with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
     f.write(dockerfile_contents)
 
-cicd_image = Image.from_dockerfile(
-    pathlib.Path(temp_dir) / "Dockerfile",
-    force_build=True,
-    gpu="A10G",
-).env(df_args)
+cicd_image = (
+    Image.from_dockerfile(
+        pathlib.Path(temp_dir) / "Dockerfile",
+        force_build=True,
+        gpu="A10G",
+    )
+    .env(df_args)
+    .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
+)
 
 app = App("Axolotl CI/CD", secrets=[])

View File

@@ -407,10 +407,7 @@ save_total_limit: # Checkpoints saved at a time
 max_steps:
 
 # bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
-include_tokens_per_second: # Optional[bool]
-# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
-auto_find_batch_size: # Optional[bool]
+include_tokens_per_second:
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

View File

@@ -13,12 +13,12 @@ liger-kernel==0.5.2
 packaging==23.2
 peft==0.14.0
-transformers==4.49.0
+transformers==4.48.3
 tokenizers>=0.21.0
 accelerate==1.3.0
 datasets==3.2.0
 deepspeed==0.16.1
-trl==0.15.1
+trl==0.15.0
 optimum==1.16.2
 hf_transfer

View File

@@ -123,6 +123,8 @@ class ModalCloud(Cloud):
         if env := self.get_env():
             image = image.env(env)
+        image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
+
         return image
 
     def get_secrets(self):

View File

@@ -59,7 +59,6 @@ from axolotl.core.training_args import (
     AxolotlTrainingArguments,
 )
 from axolotl.integrations.base import PluginManager
-from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType, get_extract_fn
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback
 from axolotl.utils import is_comet_available, is_mlflow_available
@@ -747,11 +746,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
                 # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
                 data_collator_kwargs["pad_to_multiple_of"] = 64
-            if self.cfg.sp_ulysses_degree:
-                data_collator_kwargs["sp_extract_fn"] = get_extract_fn(
-                    USPRingAttnType.ZIGZAG,
-                    sp_ulysses_degree=self.cfg.sp_ulysses_degree
-                )
 
         if self.cfg.reward_model:
             data_collator_kwargs["max_length"] = self.cfg.sequence_len

View File

@@ -9,6 +9,7 @@ import logging
 from trl.trainer.grpo_trainer import RewardFunc
 
 from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
+from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
 
 LOG = logging.getLogger("axolotl")
@@ -30,31 +31,21 @@ class GRPOStrategy:
     @classmethod
     def set_training_args_kwargs(cls, cfg):
-        grpo_args_kwargs = {}
-        if cfg.trl and cfg.trl.use_vllm:
-            grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
-        if cfg.trl and cfg.trl.vllm_device:
-            grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
-        else:
-            grpo_args_kwargs["vllm_device"] = "auto"
-        if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
-            grpo_args_kwargs[
-                "vllm_gpu_memory_utilization"
-            ] = cfg.trl.vllm_gpu_memory_utilization
-        if cfg.trl and cfg.trl.vllm_max_model_len:
-            grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
-        if cfg.trl and cfg.trl.num_generations:
-            grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
-        if cfg.trl and cfg.trl.sync_ref_model:
-            grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
-        if cfg.trl and cfg.trl.ref_model_mixup_alpha:
-            grpo_args_kwargs[
-                "ref_model_mixup_alpha"
-            ] = cfg.trl.ref_model_mixup_alpha
-        if cfg.trl and cfg.trl.ref_model_sync_steps:
-            grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
-        grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
-        grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
+        training_kwargs = [
+            "use_vllm",
+            "vllm_device",
+            "vllm_gpu_memory_utilization",
+            "vllm_max_model_len",
+            "vllm_dtype",
+            "use_liger_loss",
+            "num_generations",
+            "log_completions",
+            "sync_ref_model",
+            "ref_model_mixup_alpha",
+            "ref_model_sync_steps",
+            "max_completion_length",
+        ]
+        grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]}
         return grpo_args_kwargs
 
     @classmethod
@@ -71,9 +62,7 @@ class GRPOStrategy:
     def set_trainer_kwargs(cls, cfg):
         trainer_kwargs = {}
         if cfg.trl and cfg.trl.reward_processing_classes:
-            trainer_kwargs[
-                "reward_processing_classes"
-            ] = cfg.trl.reward_processing_classes
+            trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes
         return trainer_kwargs
 
     @classmethod

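Note on the refactor above: the dict comprehension keeps only keys whose configured value is truthy, so an explicit `False`/`0`/`None` setting is dropped and TRL's own default applies (including the old `vllm_device = "auto"` fallback, which now presumably relies on the key being absent). A minimal standalone sketch of the pattern, with a plain dict standing in for `cfg.trl`:

# Hypothetical stand-in for cfg.trl (axolotl's DictDefault also supports item access).
trl_cfg = {"use_vllm": True, "num_generations": 4, "use_liger_loss": False, "vllm_device": None}

training_kwargs = ["use_vllm", "vllm_device", "num_generations", "use_liger_loss"]
grpo_args_kwargs = {k: trl_cfg[k] for k in training_kwargs if trl_cfg[k]}

print(grpo_args_kwargs)  # {'use_vllm': True, 'num_generations': 4} -- falsy values are dropped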
View File

@@ -13,3 +13,4 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
     """
     Axolotl GRPO Config for GRPO training
     """
+    use_liger_loss: bool = False

View File

@@ -1,13 +1,24 @@
 """
 Axolotl GRPO trainer
 """
+from contextlib import contextmanager, nullcontext
+
 from accelerate.utils import is_peft_model
 from accelerate.utils.other import is_compiled_module
+import torch
 from transformers import PreTrainedModel
 from trl import GRPOConfig, GRPOTrainer
 from trl.models import unwrap_model_for_generation
 
 from axolotl.core.trainers.base import SchedulerMixin
+from transformers.utils import is_liger_kernel_available
+
+if is_liger_kernel_available():
+    from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
+
+from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
+from accelerate.utils import broadcast_object_list, gather_object
+from trl.trainer.utils import pad
 
 # mypy: ignore-errors
@@ -20,7 +31,20 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.use_liger_loss = kwargs["args"].use_liger_loss
+        if self.use_liger_loss:
+            if not is_liger_kernel_available():
+                raise ValueError(
+                    "You set `use_liger_loss=True` but the liger kernel is not available. "
+                    "Please install liger-kernel first: `pip install liger-kernel`"
+                )
+            self.grpo_loss_fn = LigerFusedLinearGRPOLoss(
+                beta=self.beta,
+                compiled=is_compiled_module(self.model),
+                use_ref_model=True,
+                num_generations=self.args.num_generations,
+            )
 
         # pylint: disable=access-member-before-definition
         # Enable gradient checkpointing if requested
         if kwargs["args"].gradient_checkpointing:
@@ -29,9 +53,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
             self.model.config.use_cache = False
 
             # Enable gradient checkpointing on the base model for PEFT
-            if is_peft_model(self.model) and hasattr(
-                self.model.base_model, "gradient_checkpointing_enable"
-            ):
+            if is_peft_model(self.model) and hasattr(self.model.base_model, "gradient_checkpointing_enable"):
                 self.model.base_model.gradient_checkpointing_enable()
 
             # Enable gradient checkpointing for non-PEFT models
             elif hasattr(self.model, "gradient_checkpointing_enable"):
@@ -39,15 +61,12 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
                 self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
         # pylint: enable=access-member-before-definition
 
-    def _enable_gradient_checkpointing(
-        self, model: PreTrainedModel, args: GRPOConfig
-    ) -> PreTrainedModel:
+    def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
         """Enables gradient checkpointing for the model."""
         # pylint: disable=unused-argument,redefined-builtin
         gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
         use_reentrant = (
-            "use_reentrant" not in gradient_checkpointing_kwargs
-            or gradient_checkpointing_kwargs["use_reentrant"]
+            "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
         )
 
         if use_reentrant:
@@ -58,9 +77,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
             def make_inputs_require_grad(module, input, output):
                 output.requires_grad_(True)
 
-            model.get_input_embeddings().register_forward_hook(
-                make_inputs_require_grad
-            )
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
         return model
 
     # pylint: enable=unused-argument,redefined-builtin
@@ -72,25 +89,18 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
             gather_deepspeed3_params=self.args.ds3_gather_for_generation,
         ) as unwrapped_model:
             if is_compiled_module(unwrapped_model):
-                unwrapped_model = (
-                    unwrapped_model._orig_mod  # pylint: disable=protected-access
-                )
+                unwrapped_model = unwrapped_model._orig_mod  # pylint: disable=protected-access
             if is_peft_model(unwrapped_model):
                 unwrapped_model.merge_adapter()
                 state_dict = unwrapped_model.state_dict()
-                unwrapped_model.unmerge_adapter()
                 # Remove base_model and base_layer prefixes
                 state_dict = {
-                    k.removeprefix("base_model.model.")
-                    .removeprefix("base_model.model.")
-                    .replace(".base_layer", ""): v
+                    k.removeprefix("base_model.model.").removeprefix("base_model.model.").replace(".base_layer", ""): v
                     for k, v in state_dict.items()
                 }
                 # Remove values with adapter prefix (example: "_lora")
-                state_dict = {
-                    k: v
-                    for k, v in state_dict.items()
-                    if unwrapped_model.prefix not in k
-                }
+                state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
                 # When module to save, remove its prefix and discard the original module
                 state_dict = {
                     k.replace("modules_to_save.default.", ""): v
@@ -99,10 +109,218 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
                 }
         else:
             state_dict = unwrapped_model.state_dict()
         if self.accelerator.is_main_process:
-            llm_model = (
-                self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-            )
-            llm_model.load_weights(state_dict.items())
-        if is_peft_model(unwrapped_model):
-            unwrapped_model.unmerge_adapter()
+            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+            llm_model.load_weights(state_dict.items())
+
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        if self.use_liger_loss:
+            if return_outputs:
+                raise ValueError("The GRPOTrainer does not support returning outputs")
+
+            device = self.accelerator.device
+            prompts = [x["prompt"] for x in inputs]
+            prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
+            prompt_inputs = self.processing_class(
+                prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
+            )
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+
+            if self.max_prompt_length is not None:
+                prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
+                prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
+
+            # Generate completions using either vLLM or regular generation
+            if self.args.use_vllm:
+                # First, have main process load weights if needed
+                if self.state.global_step != self._last_loaded_step:
+                    with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+                        state_dict = unwrapped_model.state_dict()
+                    if self.accelerator.is_main_process:
+                        llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+                        llm_model.load_weights(state_dict.items())
+                    self._last_loaded_step = self.state.global_step
+
+                # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+                all_prompts_text = gather_object(prompts_text)
+                if self.accelerator.is_main_process:
+                    outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
+                    completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
+                else:
+                    completion_ids = [None] * len(all_prompts_text) * self.num_generations
+
+                # Broadcast the completions from the main process to all processes, ensuring each process receives its
+                # corresponding slice.
+                completion_ids = broadcast_object_list(completion_ids, from_process=0)
+                process_slice = slice(
+                    self.accelerator.process_index * len(prompts) * self.num_generations,
+                    (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
+                )
+                completion_ids = completion_ids[process_slice]
+
+                # Pad the completions, and concatenate them with the prompts
+                completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+                completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
+                prompt_inputs_repeated = torch.repeat_interleave(
+                    prompt_inputs["input_ids"], self.num_generations, dim=0
+                )
+                prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
+            else:
+                # Regular generation path
+                with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+                    prompt_completion_ids = unwrapped_model.generate(
+                        **prompt_inputs, generation_config=self.generation_config
+                    )
+
+                prompt_length = prompt_inputs["input_ids"].size(1)
+                completion_ids = prompt_completion_ids[:, prompt_length:]
+
+            # Get the per-token log probabilities for the completions for the model and the reference model
+            def get_per_token_logps(model, input_ids, num_logits_to_keep):
+                # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
+                outputs = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1)
+                hidden_states = outputs.last_hidden_state[:, :-1]
+                logits = outputs.logits  # (B, L, V)
+                logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
+                # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
+                per_token_logps = []
+                for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
+                    log_probs = logits_row.log_softmax(dim=-1)
+                    token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
+                    per_token_logps.append(token_log_prob)
+                return torch.stack(per_token_logps), hidden_states
+
+            num_logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+            per_token_logps, hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
+
+            with torch.inference_mode():
+                if self.ref_model is not None:
+                    ref_per_token_logps, ref_hidden_states = get_per_token_logps(
+                        self.ref_model, prompt_completion_ids, num_logits_to_keep
+                    )
+                else:
+                    with self.accelerator.unwrap_model(model).disable_adapter():
+                        ref_per_token_logps, ref_hidden_states = get_per_token_logps(
+                            model, prompt_completion_ids, num_logits_to_keep
+                        )
+
+            # done in liger
+            # Compute the KL divergence between the model and the reference model
+            # per_token_kl = (
+            #     torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+            # )
+
+            # Mask everything after the first EOS token
+            is_eos = completion_ids == self.processing_class.eos_token_id
+            eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+            eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+            sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+            completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+
+            # Decode the generated completions
+            completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+            if is_conversational(inputs[0]):
+                completions = [[{"role": "assistant", "content": completion}] for completion in completions]
+
+            # Compute the rewards
+            prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
+
+            rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+            for i, (reward_func, reward_processing_class) in enumerate(
+                zip(self.reward_funcs, self.reward_processing_classes)
+            ):
+                if isinstance(reward_func, PreTrainedModel):
+                    if is_conversational(inputs[0]):
+                        messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
+                        texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                    else:
+                        texts = [p + c for p, c in zip(prompts, completions)]
+                    reward_inputs = reward_processing_class(
+                        texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
+                    )
+                    reward_inputs = super()._prepare_inputs(reward_inputs)
+                    with torch.inference_mode():
+                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+                else:
+                    # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+                    reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
+                    for key in reward_kwargs:
+                        for example in inputs:
+                            # Repeat each value in the column for `num_generations` times
+                            reward_kwargs[key].extend([example[key]] * self.num_generations)
+                    output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+                    rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+            # Sum the rewards from all reward functions
+            rewards = rewards_per_func.sum(dim=1)
+
+            # done in liger
+            # # Compute grouped-wise rewards
+            # mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+            # std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+            # done in liger
+            # # Normalize the rewards to compute the advantages
+            # mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+            # std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+            # advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+
+            # done in liger
+            # x - x.detach() allows for preserving gradients from x
+            # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+            # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+            # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+
+            # Log the metrics
+            completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+            self._metrics["completion_length"].append(completion_length)
+
+            reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+            for i, reward_func in enumerate(self.reward_funcs):
+                if isinstance(reward_func, PreTrainedModel):
+                    reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+                else:
+                    reward_func_name = reward_func.__name__
+                self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+            self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
+
+            lm_head = model.get_output_embeddings()
+            if self.ref_model is not None:
+                ref_lm_head = self.ref_model.get_output_embeddings()
+            else:
+                with self.null_ref_context():
+                    ref_lm_head = model.get_output_embeddings()
+            ref_weight = ref_lm_head.weight
+            ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None
+
+            loss, metrics = self.grpo_loss_fn(
+                lm_head,
+                hidden_states,  # this is the hidden states from the model
+                completion_mask,
+                rewards,
+                bias=lm_head.bias if hasattr(lm_head, "bias") else None,
+                ref_input=ref_hidden_states,  # this is the hidden states from the ref model
+                ref_weight=ref_weight,
+                ref_bias=ref_bias,
+            )
+        else:
+            super().compute_loss(model, inputs, return_outputs, num_items_in_batch)
+
+    @contextmanager
+    def null_ref_context(self):
+        """Context manager for handling null reference model (that is, peft adapter manipulation)."""
+        with (
+            self.accelerator.unwrap_model(self.model).disable_adapter()
+            if self.is_peft_model and not self.ref_adapter_name
+            else nullcontext()
+        ):
+            if self.ref_adapter_name:
+                self.model.set_adapter(self.ref_adapter_name)
+            yield
+            if self.ref_adapter_name:
+                self.model.set_adapter(self.model_adapter_name or "default")

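For reference, the blocks marked "done in liger" above are the eager GRPO loss that `LigerFusedLinearGRPOLoss` replaces. A standalone sketch of that math in plain PyTorch, assembled from the commented-out lines (an illustration of the objective, not the fused kernel's implementation):

import torch

def eager_grpo_loss(per_token_logps, ref_per_token_logps, rewards, completion_mask, num_generations, beta):
    # KL penalty between the policy and the reference model, per completion token
    delta = ref_per_token_logps - per_token_logps
    per_token_kl = torch.exp(delta) - delta - 1
    # Group-wise reward normalization: each prompt contributes num_generations completions
    mean_g = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
    std_g = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations, dim=0)
    advantages = (rewards - mean_g) / (std_g + 1e-4)
    # x - x.detach() evaluates to 0 but preserves gradients from x
    per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
    per_token_loss = -(per_token_loss - beta * per_token_kl)
    return ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

The fused variant additionally avoids materializing the full (batch, seq_len, vocab) logits by folding the output projection into the loss, which is why `compute_loss` hands the kernel hidden states plus the `lm_head` weights rather than logits.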
View File

@@ -206,16 +206,6 @@ class AxolotlTrainingMixins:
         },
     )
-    sp_ulysses_degree: Optional[int] = field(
-        default=None,
-        metadata={"help": "Ulysses parallelism for hybrid sequence parallel long context attn"},
-    )
-    sp_ring_degree: Optional[int] = field(
-        default=None,
-        metadata={"help": "Ring attention parallelism for sequence parallel long context attn"},
-    )
 
 
 @dataclass
 class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):

View File

@@ -1,45 +0,0 @@
-from enum import Enum
-from functools import partial
-
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-from yunchang import set_seq_parallel_pg, EXTRACT_FUNC_DICT
-
-from axolotl.utils.distributed import get_world_size, get_rank
-
-
-class USPRingAttnType(Enum):
-    BASIC = "basic"
-    ZIGZAG = "zigzag"
-    STRIPE = "stripe"
-
-
-def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
-    from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward
-
-    fa_forward = build_usp_fa_forward(ring_impl_type)
-    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward
-
-
-def get_extract_fn(ring_impl_type: USPRingAttnType, sp_ulysses_degree: int):
-    fn = EXTRACT_FUNC_DICT["basic"]
-    if ring_impl_type.value in EXTRACT_FUNC_DICT.keys():
-        fn = EXTRACT_FUNC_DICT[ring_impl_type.value]
-    # map bad key upstream
-    elif ring_impl_type == USPRingAttnType.STRIPE:
-        fn = EXTRACT_FUNC_DICT["strip"]
-    world_size = get_world_size()
-    rd = world_size // sp_ulysses_degree
-    return partial(fn, rank=get_rank(), world_size=world_size, rd=rd, ud=sp_ulysses_degree)
-
-
-def set_usp_parallel_group(sp_ulysses_degree):
-    """
-    setup distributed parallel group for USP attention
-    make sure this gets called before building any USP attention modules
-    :param sp_ulysses_degree:
-    :return:
-    """
-    world_size = get_world_size()
-    rank = get_rank()
-    sp_ring_degree = world_size // sp_ulysses_degree
-    set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

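For context on the deleted helpers: USP splits the world size between a Ulysses group and a ring-attention group, so the ring degree is derived as `world_size // sp_ulysses_degree` (as in `set_usp_parallel_group` and `get_extract_fn` above). A toy sketch of that arithmetic, with illustrative values:

# Illustrative values; in the deleted code these came from WORLD_SIZE and the user config.
world_size = 8
sp_ulysses_degree = 2
sp_ring_degree = world_size // sp_ulysses_degree  # 4 ring-attention ranks per Ulysses group
assert sp_ulysses_degree * sp_ring_degree == world_size
print(sp_ring_degree)  # 4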
View File

@@ -1,36 +0,0 @@
-from enum import Enum
-from typing import Optional, Tuple, Callable
-
-import torch
-from yunchang import LongContextAttention
-
-from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
-
-
-def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
-    usp_attn = LongContextAttention(ring_impl_type.value)
-
-    def flash_attention_forward(
-        module: torch.nn.Module,  # pylint: disable=unused-argument
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],  # pylint: disable=unused-argument
-        dropout: float = 0.0,
-        scaling: Optional[float] = None,
-        sliding_window: Optional[int] = None,  # pylint: disable=unused-argument
-        softcap: Optional[float] = None,
-        **kwargs,  # pylint: disable=unused-argument
-    ) -> Tuple[torch.Tensor, None]:
-        attn_output = usp_attn(
-            query,
-            key,
-            value,
-            dropout_p=dropout,
-            softmax_scale=scaling,
-            causal=True,
-            softcap=softcap,
-        )
-
-        return attn_output, None
-
-    return flash_attention_forward

View File

@@ -127,8 +127,6 @@ class ReLoRACallback(TrainerCallback):
         optimizer: torch.optim.Optimizer,
         **_kwargs,
     ):
-        if not optimizer:
-            optimizer = state.optimizer
         if state.global_step > 0 and state.global_step % self.relora_steps == 0:
             checkpoint_folder = os.path.join(
                 args.output_dir,

View File

@@ -272,7 +272,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                     dict(zip(feature_names, row))
                 )
                 for key, val in tokenized_prompt.items():
-                    res[key].append(val)
+                    for i in range(0, len(val), self.sequence_len):
+                        res[key].append(val[i : i + self.sequence_len])
 
         # If there are no examples left, return an empty dictionary
         if not res:

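The new inner loop splits each tokenized field into `sequence_len`-sized chunks, so an over-long example becomes several rows instead of one oversized one. A standalone sketch of the chunking with illustrative values:

sequence_len = 4  # in axolotl this comes from the tokenizing strategy's config
val = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

chunks = [val[i : i + sequence_len] for i in range(0, len(val), sequence_len)]
print(chunks)  # [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]] -- the final chunk may be shorter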
View File

@@ -3,7 +3,7 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
 """
 from dataclasses import dataclass
-from typing import Any, Optional, Union, Callable
+from typing import Any, Optional, Union
 
 import numpy as np
 from transformers import PreTrainedTokenizerBase
@@ -53,7 +53,6 @@ class DataCollatorForSeq2Seq:
     label_pad_token_id: int = -100
     position_pad_token_id: int = 0
     return_tensors: str = "pt"
-    sp_extract_fn: Optional[Callable] = None
 
     def __call__(self, features, return_tensors=None):
         labels = None
@@ -122,10 +121,6 @@ class DataCollatorForSeq2Seq:
         return features
 
-    def seq_parallel_split(self, features):
-        if self.sp_extract_fn:
-            pass
-        return features
-
 
 @dataclass
 class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

View File

@@ -342,7 +342,6 @@ class LoraConfig(BaseModel):
     peft_use_dora: Optional[bool] = None
     peft_use_rslora: Optional[bool] = None
     peft_layer_replication: Optional[List[Tuple[int, int]]] = None
-    peft_init_lora_weights: Optional[Union[bool, str]] = None
 
     qlora_sharded_model_loading: Optional[bool] = Field(
         default=False,
@@ -832,8 +831,6 @@ class AxolotlInputConfig(
     eager_attention: Optional[bool] = None
 
-    sp_ulysses_degree: Optional[int] = None
-
     unsloth_cross_entropy_loss: Optional[bool] = None
     unsloth_lora_mlp: Optional[bool] = None
     unsloth_lora_qkv: Optional[bool] = None

View File

@@ -33,3 +33,4 @@ class TRLConfig(BaseModel):
     sync_ref_model: Optional[bool] = False
     ref_model_mixup_alpha: Optional[float] = 0.9
     ref_model_sync_steps: Optional[int] = 64
+    use_liger_loss: Optional[bool] = False

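A minimal pydantic sketch mirroring (not importing) the tail of the TRLConfig above, to show how the new flag defaults:

from typing import Optional
from pydantic import BaseModel

class TRLConfigSketch(BaseModel):  # hypothetical mirror of the fields shown in the hunk above
    sync_ref_model: Optional[bool] = False
    ref_model_mixup_alpha: Optional[float] = 0.9
    ref_model_sync_steps: Optional[int] = 64
    use_liger_loss: Optional[bool] = False

print(TRLConfigSketch().use_liger_loss)                     # False when unset
print(TRLConfigSketch(use_liger_loss=True).use_liger_loss)  # True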
View File

@@ -172,11 +172,10 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
     )
 
     try:
-        ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
-        min_input_len = np.min(ds_lengths)
-        LOG.info(f"min_input_len: {min_input_len}")
-        max_input_len = np.max(ds_lengths)
-        LOG.info(f"max_input_len: {max_input_len}")
+        min_input_len = np.min(get_dataset_lengths(dataset))
+        LOG.debug(f"min_input_len: {min_input_len}")
+        max_input_len = np.max(get_dataset_lengths(dataset))
+        LOG.debug(f"max_input_len: {max_input_len}")
     except AttributeError:
         pass

View File

@@ -86,12 +86,6 @@ def get_world_size():
     return int(os.getenv("WORLD_SIZE", "1"))
 
 
-def get_rank():
-    if not is_distributed():
-        return 0
-    return dist.get_rank()
-
-
 @contextmanager
 def zero_only():
     """

View File

@@ -1321,8 +1321,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
         if loftq_bits:
             lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
             lora_config_kwargs["init_lora_weights"] = "loftq"
-    if cfg.peft_init_lora_weights:
-        lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
     if cfg.peft_use_dora:
         lora_config_kwargs["use_dora"] = cfg.peft_use_dora
         LOG.info("Initializing LoRA weights using dora. This might take longer.")

View File

@@ -4,17 +4,13 @@ helper util to calculate dataset lengths
 import numpy as np
 
 
-def get_dataset_lengths(dataset, from_arrow=False):
+def get_dataset_lengths(dataset):
     if "length" in dataset.column_names:
         lengths = np.array(dataset["length"])
     elif "position_ids" in dataset.column_names:
         position_ids = dataset["position_ids"]
         lengths = np.array([x[-1] + 1 for x in position_ids])
     else:
-        if from_arrow:
-            input_ids = dataset.data.column("input_ids")
-            lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
-        else:
-            input_ids = dataset["input_ids"]
-            lengths = np.array([len(seq) for seq in input_ids])
+        input_ids = dataset["input_ids"]
+        lengths = np.array([len(seq) for seq in input_ids])
     return lengths

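With the `from_arrow` fast path removed, length computation always goes through plain column access. A quick sketch of the surviving logic against a toy dataset:

import numpy as np
from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]})
# Mirrors the `else` branch of get_dataset_lengths above (no "length"/"position_ids" columns)
lengths = np.array([len(seq) for seq in ds["input_ids"]])
print(lengths)  # [3 2 4]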
View File

@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
                 load_from_cache_file=not cfg.is_preprocess,
                 desc="Add position_id column (PoSE)",
             )
-    elif cfg.sample_packing or cfg.sp_ulysses_degree:
+    elif cfg.sample_packing:
         drop_long_kwargs = {}
         if filter_map_kwargs:
             drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"