Compare commits

..

3 Commits

Author SHA1 Message Date
Salman Mohammadi
1a09d5e844 some refactoring 2025-02-19 17:35:35 +00:00
Salman Mohammadi
cf61b4aba7 Merge branch 'main' into grpo_liger 2025-02-19 16:17:42 +00:00
Salman Mohammadi
14d274efe6 WIP liger support 2025-02-19 15:34:32 +00:00
22 changed files with 291 additions and 202 deletions

View File

@@ -4,10 +4,6 @@ on:
pull_request:
paths:
- 'tests/e2e/multigpu/*.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

View File

@@ -37,11 +37,15 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
).env(df_args)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
app = App("Axolotl CI/CD", secrets=[])

View File

@@ -407,10 +407,7 @@ save_total_limit: # Checkpoints saved at a time
max_steps:
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: # Optional[bool]
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]
include_tokens_per_second:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

View File

@@ -13,12 +13,12 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.49.0
transformers==4.48.3
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.15.1
trl==0.15.0
optimum==1.16.2
hf_transfer

View File

@@ -123,6 +123,8 @@ class ModalCloud(Cloud):
if env := self.get_env():
image = image.env(env)
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
return image
def get_secrets(self):

View File

@@ -59,7 +59,6 @@ from axolotl.core.training_args import (
AxolotlTrainingArguments,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType, get_extract_fn
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -747,11 +746,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.sp_ulysses_degree:
data_collator_kwargs["sp_extract_fn"] = get_extract_fn(
USPRingAttnType.ZIGZAG,
sp_ulysses_degree=self.cfg.sp_ulysses_degree
)
if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len

View File

@@ -9,6 +9,7 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
LOG = logging.getLogger("axolotl")
@@ -30,31 +31,21 @@ class GRPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
training_kwargs = [
"use_vllm",
"vllm_device",
"vllm_gpu_memory_utilization",
"vllm_max_model_len",
"vllm_dtype",
"use_liger_loss",
"num_generations",
"log_completions",
"sync_ref_model",
"ref_model_mixup_alpha",
"ref_model_sync_steps",
"max_completion_length",
]
grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]}
return grpo_args_kwargs
@classmethod
@@ -71,9 +62,7 @@ class GRPOStrategy:
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
] = cfg.trl.reward_processing_classes
trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes
return trainer_kwargs
@classmethod

View File

@@ -13,3 +13,4 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""
Axolotl GRPO Config for GRPO training
"""
use_liger_loss: bool = False

View File

@@ -1,13 +1,24 @@
"""
Axolotl GRPO trainer
"""
from contextlib import contextmanager, nullcontext
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
import torch
from transformers import PreTrainedModel
from trl import GRPOConfig, GRPOTrainer
from trl.models import unwrap_model_for_generation
from axolotl.core.trainers.base import SchedulerMixin
from transformers.utils import is_liger_kernel_available
if is_liger_kernel_available():
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from accelerate.utils import broadcast_object_list, gather_object
from trl.trainer.utils import pad
# mypy: ignore-errors
@@ -20,7 +31,20 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.use_liger_loss = kwargs["args"].use_liger_loss
if self.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
self.grpo_loss_fn = LigerFusedLinearGRPOLoss(
beta=self.beta,
compiled=is_compiled_module(self.model),
use_ref_model=True,
num_generations=self.args.num_generations,
)
# pylint: disable=access-member-before-definition
# Enable gradient checkpointing if requested
if kwargs["args"].gradient_checkpointing:
@@ -29,9 +53,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self.model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(self.model) and hasattr(
self.model.base_model, "gradient_checkpointing_enable"
):
if is_peft_model(self.model) and hasattr(self.model.base_model, "gradient_checkpointing_enable"):
self.model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
elif hasattr(self.model, "gradient_checkpointing_enable"):
@@ -39,15 +61,12 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition
def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel:
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# pylint: disable=unused-argument,redefined-builtin
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs
or gradient_checkpointing_kwargs["use_reentrant"]
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)
if use_reentrant:
@@ -58,9 +77,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(
make_inputs_require_grad
)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
return model
# pylint: enable=unused-argument,redefined-builtin
@@ -72,25 +89,18 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = (
unwrapped_model._orig_mod # pylint: disable=protected-access
)
unwrapped_model = unwrapped_model._orig_mod # pylint: disable=protected-access
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", ""): v
k.removeprefix("base_model.model.").removeprefix("base_model.model.").replace(".base_layer", ""): v
for k, v in state_dict.items()
}
# Remove values with adapter prefix (example: "_lora")
state_dict = {
k: v
for k, v in state_dict.items()
if unwrapped_model.prefix not in k
}
state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
# When module to save, remove its prefix and discard the original module
state_dict = {
k.replace("modules_to_save.default.", ""): v
@@ -99,10 +109,218 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if self.use_liger_loss:
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
if self.max_prompt_length is not None:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
else:
completion_ids = [None] * len(all_prompts_text) * self.num_generations
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts) * self.num_generations,
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
)
llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()
completion_ids = completion_ids[process_slice]
# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_inputs_repeated = torch.repeat_interleave(
prompt_inputs["input_ids"], self.num_generations, dim=0
)
prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
**prompt_inputs, generation_config=self.generation_config
)
prompt_length = prompt_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]
# Get the per-token log probabilities for the completions for the model and the reference model
def get_per_token_logps(model, input_ids, num_logits_to_keep):
# We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
outputs = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1)
hidden_states = outputs.last_hidden_state[:, :-1]
logits = outputs.logits # (B, L, V)
logits = logits[
:, :-1, :
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps), hidden_states
num_logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
per_token_logps, hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
self.ref_model, prompt_completion_ids, num_logits_to_keep
)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
model, prompt_completion_ids, num_logits_to_keep
)
# done in liger
# Compute the KL divergence between the model and the reference model
# per_token_kl = (
# torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# )
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# Decode the generated completions
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
# Compute the rewards
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, PreTrainedModel):
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
for key in reward_kwargs:
for example in inputs:
# Repeat each value in the column for `num_generations` times
reward_kwargs[key].extend([example[key]] * self.num_generations)
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)
# done in liger
# # Compute grouped-wise rewards
# mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
# std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# done in liger
# # Normalize the rewards to compute the advantages
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# done in liger
# x - x.detach() allows for preserving gradients from x
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
lm_head = model.get_output_embeddings()
if self.ref_model is not None:
ref_lm_head = self.ref_model.get_output_embeddings()
else:
with self.null_ref_context():
ref_lm_head = model.get_output_embeddings()
ref_weight = ref_lm_head.weight
ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None
loss, metrics = self.grpo_loss_fn(
lm_head,
hidden_states, # this is the hidden states from the model
completion_mask,
rewards,
bias=lm_head.bias if hasattr(lm_head, "bias") else None,
ref_input=ref_hidden_states, # this is the hidden states from the ref model
ref_weight=ref_weight,
ref_bias=ref_bias,
)
else:
super().compute_loss(model, inputs, return_outputs, num_items_in_batch)
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
if self.ref_adapter_name:
self.model.set_adapter(self.model_adapter_name or "default")

View File

@@ -206,16 +206,6 @@ class AxolotlTrainingMixins:
},
)
sp_ulysses_degree: Optional[int] = field(
default=None,
metadata={"help": "Ulysses parallelism for hybrid sequence parallel long context attn"},
)
sp_ring_degree: Optional[int] = field(
default=None,
metadata={"help": "Ring attention parallelism for sequence parallel long context attn"},
)
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):

View File

@@ -1,45 +0,0 @@
from enum import Enum
from functools import partial
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from yunchang import set_seq_parallel_pg, EXTRACT_FUNC_DICT
from axolotl.utils.distributed import get_world_size, get_rank
class USPRingAttnType(Enum):
BASIC = "basic"
ZIGZAG = "zigzag"
STRIPE = "stripe"
def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward
fa_forward = build_usp_fa_forward(ring_impl_type)
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward
def get_extract_fn(ring_impl_type: USPRingAttnType, sp_ulysses_degree: int):
fn = EXTRACT_FUNC_DICT["basic"]
if ring_impl_type.value in EXTRACT_FUNC_DICT.keys():
fn = EXTRACT_FUNC_DICT[ring_impl_type.value]
# map bad key upstream
elif ring_impl_type == USPRingAttnType.STRIPE:
fn = EXTRACT_FUNC_DICT["strip"]
world_size = get_world_size()
rd = world_size // sp_ulysses_degree
return partial(fn, rank=get_rank(), world_size=world_size, rd=rd, ud=sp_ulysses_degree)
def set_usp_parallel_group(sp_ulysses_degree):
"""
setup distributed parallel group for USP attention
make sure this gets called before building any USP attention modules
:param sp_ulysses_degree:
:return:
"""
world_size = get_world_size()
rank = get_rank()
sp_ring_degree = world_size // sp_ulysses_degree
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

View File

@@ -1,36 +0,0 @@
from enum import Enum
from typing import Optional, Tuple, Callable
import torch
from yunchang import LongContextAttention
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
usp_attn = LongContextAttention(ring_impl_type.value)
def flash_attention_forward(
module: torch.nn.Module, # pylint: disable=unused-argument
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor], # pylint: disable=unused-argument
dropout: float = 0.0,
scaling: Optional[float] = None,
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
softcap: Optional[float] = None,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, None]:
attn_output = usp_attn(
query,
key,
value,
dropout_p=dropout,
softmax_scale=scaling,
causal=True,
softcap=softcap,
)
return attn_output, None
return flash_attention_forward

View File

@@ -127,8 +127,6 @@ class ReLoRACallback(TrainerCallback):
optimizer: torch.optim.Optimizer,
**_kwargs,
):
if not optimizer:
optimizer = state.optimizer
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join(
args.output_dir,

View File

@@ -272,7 +272,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
dict(zip(feature_names, row))
)
for key, val in tokenized_prompt.items():
res[key].append(val)
for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
# If there are no examples left, return an empty dictionary
if not res:

View File

@@ -3,7 +3,7 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, Optional, Union, Callable
from typing import Any, Optional, Union
import numpy as np
from transformers import PreTrainedTokenizerBase
@@ -53,7 +53,6 @@ class DataCollatorForSeq2Seq:
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
sp_extract_fn: Optional[Callable] = None
def __call__(self, features, return_tensors=None):
labels = None
@@ -122,10 +121,6 @@ class DataCollatorForSeq2Seq:
return features
def seq_parallel_split(self, features):
if self.sp_extract_fn:
pass
return features
@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

View File

@@ -342,7 +342,6 @@ class LoraConfig(BaseModel):
peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
peft_init_lora_weights: Optional[Union[bool, str]] = None
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
@@ -832,8 +831,6 @@ class AxolotlInputConfig(
eager_attention: Optional[bool] = None
sp_ulysses_degree: Optional[int] = None
unsloth_cross_entropy_loss: Optional[bool] = None
unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None

View File

@@ -33,3 +33,4 @@ class TRLConfig(BaseModel):
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64
use_liger_loss: Optional[bool] = False

View File

@@ -172,11 +172,10 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
)
try:
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
min_input_len = np.min(ds_lengths)
LOG.info(f"min_input_len: {min_input_len}")
max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}")
min_input_len = np.min(get_dataset_lengths(dataset))
LOG.debug(f"min_input_len: {min_input_len}")
max_input_len = np.max(get_dataset_lengths(dataset))
LOG.debug(f"max_input_len: {max_input_len}")
except AttributeError:
pass

View File

@@ -86,12 +86,6 @@ def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
def get_rank():
if not is_distributed():
return 0
return dist.get_rank()
@contextmanager
def zero_only():
"""

View File

@@ -1321,8 +1321,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.")

View File

@@ -4,17 +4,13 @@ helper util to calculate dataset lengths
import numpy as np
def get_dataset_lengths(dataset, from_arrow=False):
def get_dataset_lengths(dataset):
if "length" in dataset.column_names:
lengths = np.array(dataset["length"])
elif "position_ids" in dataset.column_names:
position_ids = dataset["position_ids"]
lengths = np.array([x[-1] + 1 for x in position_ids])
else:
if from_arrow:
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
return lengths

View File

@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing or cfg.sp_ulysses_degree:
elif cfg.sample_packing:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"