Compare commits

...

35 Commits

Author SHA1 Message Date
Wing Lian
c9d842ef2e test not deleting pythonpath for custom code bundling 2025-02-06 13:40:30 -05:00
Wing Lian
ecea44c902 fix num_processes in passing to accelerate 2025-02-06 13:39:46 -05:00
Wing Lian
4f9c57e95d check for src axolotl in PYTHONPATH before removing it 2025-02-06 13:26:23 -05:00
Wing Lian
3d38bc82b8 include vllm in build 2025-02-06 11:09:42 -05:00
Wing Lian
756a8332d6 set default on trl config 2025-02-05 22:17:10 -05:00
Wing Lian
aded9c500d refactor cfg.grpo_* to use cfg.trl.* 2025-02-05 20:41:14 -05:00
Wing Lian
3659d812f7 use cfg.max_completion_length, not sequence_len 2025-02-05 13:20:17 -05:00
Salman Mohammadi
bdb0f97082 adding 'reward_processing_classes' 2025-02-05 18:18:42 +00:00
Salman Mohammadi
65b6519447 adding 'reward_processing_classes' 2025-02-05 18:13:05 +00:00
Wing Lian
a1958b09de separately include max_completion_len 2025-02-05 13:01:52 -05:00
Salman Mohammadi
b8f258817e adding reward fn verification 2025-02-05 13:30:02 +00:00
Wing Lian
753146b458 max_length moved to reward config 2025-02-04 11:06:26 -05:00
Wing Lian
d683c50113 fix config cls 2025-02-04 11:06:26 -05:00
Wing Lian
234cd8311e fix failure case in prompter loading 2025-02-04 11:06:26 -05:00
Wing Lian
f9893e3842 fix dpo config and add use_logits_to_keep 2025-02-04 11:06:26 -05:00
Wing Lian
ac1ebc58a8 add support for num_generations 2025-02-04 11:06:25 -05:00
Wing Lian
56f3b9f20f bump pydantic to support vllm 2025-02-04 11:06:25 -05:00
Wing Lian
2c1376d8c4 don't shrink embeddings unless told to 2025-02-04 11:06:25 -05:00
Wing Lian
3c7517fd55 add support for passing map kwargs to dataset map in rl 2025-02-04 11:06:25 -05:00
Wing Lian
1e94d7ef65 more fixes to get grpo working 2025-02-04 11:06:25 -05:00
Wing Lian
cfc7fe0df2 remove unusable args kwargs 2025-02-04 11:06:25 -05:00
Wing Lian
3c4fe478cf be nice with self.cfg.dataset_processes 2025-02-04 11:06:25 -05:00
Wing Lian
c810599c66 order matters 2025-02-04 11:06:24 -05:00
Wing Lian
300ffc2cb6 make it a dataclass 2025-02-04 11:06:24 -05:00
Wing Lian
b1c4711145 load the class from strat 2025-02-04 11:06:24 -05:00
Wing Lian
d155849e2c use correct builder 2025-02-04 11:06:24 -05:00
Wing Lian
626db6cb84 collator for grpo and prompt loader 2025-02-04 11:06:24 -05:00
Wing Lian
79159b4871 support custom module prompt strategy for rl 2025-02-04 11:06:24 -05:00
Wing Lian
704ddd6ff1 honor skip prepare for rl 2025-02-04 11:06:24 -05:00
Wing Lian
54b0d3d0e8 passthrough dataset parser for dpo/grpo 2025-02-04 11:06:23 -05:00
Wing Lian
59ad21f2de refactor a bit for better grpo support 2025-02-04 11:06:23 -05:00
Wing Lian
57264b6491 respect dotenv for cli 2025-02-04 11:06:23 -05:00
Wing Lian
d495e41ba1 refactor dpo trainer into own module 2025-02-04 11:06:23 -05:00
Wing Lian
6067fe6c28 upgrade trl to 0.14.0 2025-02-04 11:06:23 -05:00
NanoCode012
a620d481e2 fix: drop long seq even if not sample packing (#2211)
* fix: drop long seq even if not sample packing

* fix: logging import

* fix: cfg passed being none

* fix: try to fix logging

* fix: refactor call to not use accelerate log

* fix: try to fix circular import issue

* fix: don't drop when skip prepare

* chore: remove duplicate line

* fix: update warning to mention that sequences will be trimmed

* fix: do not drop seq if input_ids don't exist

* fix: increase RM unittest sequence length to reduce trim warnings

* fix: solve conflicts

* fix: default min_seq_len in case of None
2025-02-04 09:43:35 -05:00
28 changed files with 548 additions and 229 deletions

View File

@@ -32,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
-pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
+pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
-pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
+pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -13,12 +13,12 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
-transformers==4.48.1
+transformers==4.48.2
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
-trl==0.13.0
+trl==0.14.0
optimum==1.16.2
hf_transfer
@@ -26,7 +26,7 @@ sentencepiece
gradio==3.50.2
modal==0.70.5
-pydantic==2.6.3
+pydantic==2.10.6
addict
fire
PyYAML>=6.0

View File

@@ -153,5 +153,8 @@ setup(
"ray": [
"ray[train]",
],
"vllm": [
"vllm>=0.7.1",
],
},
)

View File

@@ -22,8 +22,11 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
# modal workaround so it doesn't use the automounted axolotl
new_env = copy.deepcopy(os.environ)
if "PYTHONPATH" in new_env:
del new_env["PYTHONPATH"]
# if "PYTHONPATH" in new_env:
# python_path = Path(new_env["PYTHONPATH"].split(":")[0])
# if python_path.joinpath("src", "axolotl").exists():
# # we don't want to use the automounted axolotl or unexpected behavior happens
# del new_env["PYTHONPATH"]
# Propagate errors from subprocess.
if exit_code := subprocess.call( # nosec B603

View File

@@ -12,6 +12,7 @@ from typing import Optional
import click
import yaml
from dotenv import load_dotenv
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
@@ -208,7 +209,7 @@ def train(
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append("--num_processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
@@ -381,4 +382,5 @@ def main():
if __name__ == "__main__":
load_dotenv()
main()
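The CLI entrypoint above now calls load_dotenv() before main(), so settings can come from a local .env file. A minimal sketch of that behavior, assuming python-dotenv is installed (the variable name is illustrative):

from dotenv import load_dotenv
import os

load_dotenv()  # reads KEY=value pairs from ./.env into os.environ
token = os.environ.get("HF_TOKEN")  # now visible to the CLI and the launched trainer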

View File

@@ -39,7 +39,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
AxolotlCPOTrainer,
-AxolotlDPOTrainer,
AxolotlKTOTrainer,
AxolotlMambaTrainer,
AxolotlORPOTrainer,
@@ -48,9 +47,11 @@ from axolotl.core.trainers.base import (
AxolotlTrainer,
ReLoRATrainer,
)
+from axolotl.core.trainers.dpo import DPOStrategy
+from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
+from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import (
AxolotlCPOConfig,
-AxolotlDPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
AxolotlPRMConfig,
@@ -652,7 +653,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# pylint: disable=duplicate-code
if self.cfg.optimizer in [
@@ -965,10 +966,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
-if self.cfg.rl_beta:
-training_args_kwargs["beta"] = self.cfg.rl_beta
+if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
+training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
@@ -977,6 +979,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
@@ -1001,11 +1004,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0
)
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "grpo":
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
@@ -1016,9 +1023,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
+if self.cfg.dpo_use_logits_to_keep is not None:
+training_args_kwargs[
+"use_logits_to_keep"
+] = self.cfg.dpo_use_logits_to_keep
+for blocklist_key in blocklist_args_kwargs:
+if blocklist_key in training_args_kwargs:
+del training_args_kwargs[blocklist_key]
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
-output_dir=self.cfg.output_dir,
+self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
@@ -1047,8 +1062,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model]
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
@@ -1068,7 +1087,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
-if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
+if self.cfg.datasets is not None and (
+trainer_cls is DPOStrategy.get_trainer_class()
+):
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
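For illustration, the new blocklist plumbing boils down to removing kwargs that the selected strategy's config class cannot accept (GRPO blocklists dataset_num_proc above); a toy restatement with illustrative values:

training_args_kwargs = {"beta": 0.04, "dataset_num_proc": 8}
blocklist_args_kwargs = ["dataset_num_proc"]  # as returned by GRPOStrategy.get_blocklist_args_kwargs()
for blocklist_key in blocklist_args_kwargs:
    # drop anything the training args class would reject
    training_args_kwargs.pop(blocklist_key, None)
assert "dataset_num_proc" not in training_args_kwargs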

View File

@@ -5,30 +5,21 @@ module for customized trainers
from __future__ import annotations
# pylint: disable=too-many-lines
import gc
import logging
import os
from collections import defaultdict
from functools import wraps
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Dict, Literal, Optional
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
-from trl import (
-CPOTrainer,
-DPOTrainer,
-KTOTrainer,
-ORPOTrainer,
-PRMTrainer,
-RewardTrainer,
-)
+from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -847,107 +838,6 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler
-class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
-"""
-Extend the base DPOTrainer for axolotl helpers
-"""
-tag_names = ["axolotl", "dpo"]
-def __init__(self, *args, dataset_tags=None, **kwargs):
-super().__init__(*args, **kwargs)
-self.dataset_tags = dataset_tags
-self.optimizer = None
-self.model_accepts_loss_kwargs = False
-def create_optimizer(self):
-if self.args.loraplus_lr_ratio is None:
-return super().create_optimizer()
-opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-if self.optimizer is None: # pylint: disable=access-member-before-definition
-optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
-self.args,
-opt_model,
-)
-loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
-if loraplus_lr_ratio:
-print("Using lora+")
-loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
-self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
-opt_model,
-optimizer_cls,
-loraplus_lr_ratio=loraplus_lr_ratio,
-loraplus_lr_embedding=loraplus_lr_embedding,
-**optimizer_kwargs,
-)
-if is_sagemaker_mp_enabled():
-self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
-self.optimizer
-)
-return self.optimizer
-@wraps(DPOTrainer.push_to_hub)
-def push_to_hub(self, *args, **kwargs) -> str:
-"""
-Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
-model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
-"""
-kwargs = _sanitize_kwargs_for_ds_tagging(
-dataset_tags=self.dataset_tags, kwargs=kwargs
-)
-kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
-return super().push_to_hub(*args, **kwargs)
-@staticmethod
-def tokenize_row(
-features,
-processing_class,
-max_prompt_length,
-max_completion_length,
-add_special_tokens,
-) -> Dict:
-res = DPOTrainer.tokenize_row(
-features,
-processing_class,
-max_prompt_length,
-max_completion_length,
-add_special_tokens,
-)
-# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
-if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
-for key in res.keys():
-res[key] = res[key][1:]
-if processing_class.bos_token and processing_class.bos_token_id is not None:
-# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
-if res["chosen_input_ids"][0] == processing_class.bos_token_id:
-res["chosen_input_ids"] = res["chosen_input_ids"][1:]
-res["chosen_labels"] = res["chosen_labels"][1:]
-res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
-if res["rejected_input_ids"][0] == processing_class.bos_token_id:
-res["rejected_input_ids"] = res["rejected_input_ids"][1:]
-res["rejected_labels"] = res["rejected_labels"][1:]
-res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
-return res
-def training_step(
-self,
-model: nn.Module,
-inputs: Dict[str, Union[torch.Tensor, Any]],
-num_items_in_batch=None,
-) -> torch.Tensor:
-loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
-gc.collect()
-torch.cuda.empty_cache()
-return loss
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers

View File

@@ -0,0 +1,33 @@
"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
class DPOStrategy:
"""
Strategy for DPO training
"""
@classmethod
def get_trainer_class(cls):
return AxolotlDPOTrainer
@classmethod
def get_training_args_class(cls):
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
return AxolotlDPOConfig
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
return training_args_kwargs
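A hedged sketch of how the builder consumes this strategy (cfg values are illustrative, and instantiating the config may need more fields in practice):

from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"rl": "ipo", "sequence_len": 2048, "use_wandb": False})
args_cls = DPOStrategy.get_training_args_class()         # AxolotlDPOConfig
args_kwargs = DPOStrategy.set_training_args_kwargs(cfg)  # {"loss_type": "ipo", "max_length": 2048, ...}
training_args = args_cls(output_dir="./out", **args_kwargs)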

View File

@@ -0,0 +1,15 @@
"""
Axolotl specific DPO args
"""
from dataclasses import dataclass
from trl import DPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""

View File

@@ -0,0 +1,125 @@
"""
DPO trainer for axolotl
"""
import gc
from functools import wraps
from typing import Any, Dict, Union
import torch
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.base import (
SchedulerMixin,
_sanitize_kwargs_for_ds_tagging,
_sanitize_kwargs_for_tagging,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
# pylint: disable=duplicate-code
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
# pylint: disable=duplicate-code
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@staticmethod
def tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
if processing_class.bos_token and processing_class.bos_token_id is not None:
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss

View File

@@ -0,0 +1,113 @@
"""
GRPO Specific Strategy for training
"""
import importlib
import inspect
import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
LOG = logging.getLogger("axolotl")
class GRPOStrategy:
"""
Strategy for GRPO training
"""
@classmethod
def get_trainer_class(cls):
return AxolotlGRPOTrainer
@classmethod
def get_training_args_class(cls):
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
return AxolotlGRPOConfig
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
return grpo_args_kwargs
@classmethod
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
for reward_func_fqn in cfg.trl.reward_funcs:
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
trainer_kwargs["reward_funcs"] = reward_funcs
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
] = cfg.trl.reward_processing_classes
return trainer_kwargs
@classmethod
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
# No data collation is needed in GRPO, handled by trl's trainer __init__
return None
@classmethod
def get_blocklist_args_kwargs(cls):
return ["dataset_num_proc"]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
"""
Returns the reward function from the given fully qualified name, or the path to the reward function model.
Args:
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
or a HF hub path to the reward model.
Raises:
ValueError: If the reward function does not accept at least two arguments.
Returns:
RewardFunc: A callable that accepts prompts and completions and returns rewards,
or a path to a reward model.
"""
try:
# use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_fqn.split(".")[-1]
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
reward_func = getattr(reward_func_module, reward_func_module_name)
if not len(inspect.signature(reward_func).parameters) >= 2:
raise ValueError(
"Reward function must accept at least two arguments: prompts: list and completions: list"
)
return reward_func
except ModuleNotFoundError:
# the user has passed a string (ideally indicating the path of a reward model);
# note: the import failed above, so only reward_func_fqn is bound here
LOG.info(
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
)
return reward_func_fqn
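For reference, a loadable reward function only needs to accept prompts and completions (the two-argument check above) and return one score per completion; a minimal sketch with a hypothetical module name:

# my_rewards.py, referenced in the YAML as trl.reward_funcs: ["my_rewards.brevity"]
def brevity(prompts, completions, **kwargs):
    """Return one float reward per completion; shorter answers score higher here."""
    return [1.0 / (1.0 + len(completion)) for completion in completions]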

View File

@@ -0,0 +1,15 @@
"""
Axolotl Specific Training Args
"""
from dataclasses import dataclass
from trl import GRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""
Axolotl GRPO Config for GRPO training
"""

View File

@@ -0,0 +1,14 @@
"""
Axolotl GRPO trainer
"""
from trl import GRPOTrainer
from axolotl.core.trainers.base import SchedulerMixin
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"]

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
-from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
+from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass
@@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
-@dataclass
-class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
-"""
-DPO config for DPO training
-"""
@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
"""

View File

@@ -13,8 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
if len(strategy.split(".")) == 1:
strategy = strategy + ".default"
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", module_base)
if len(strategy.split(".")) > 1:
try:
importlib.import_module(
strategy.split(".")[-2],
".".join(strategy.split(".")[:-2]),
)
module_base = ".".join(strategy.split(".")[:-2])
strategy = strategy.split(".")[-2]
except ModuleNotFoundError:
strategy = "." + ".".join(strategy.split(".")[:-1])
else:
strategy = "." + ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(strategy, module_base)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
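In effect, the refactored load() resolves strategy strings two ways; a simplified sketch of the split logic (module names are hypothetical, and the real code falls back on ModuleNotFoundError):

strategy = "my_pkg.rewards.custom"
load_fn = strategy.split(".")[-1]             # "custom", the function to call
module_name = strategy.split(".")[-2]         # "rewards", the module to import
package = ".".join(strategy.split(".")[:-2])  # "my_pkg", reachable via PYTHONPATH for bundled custom code
# whereas a bare "chatml" resolves to function "default" in ".chatml" under module_base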

View File

@@ -47,7 +47,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
)
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
@@ -70,7 +70,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
)
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][

View File

@@ -0,0 +1,14 @@
"""
DPO prompt strategies passthrough/zero-processing strategy
"""
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(
sample, tokenizer=None
): # pylint: disable=possibly-unused-variable,unused-argument
return sample
return transform_fn
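This no-op transform is for datasets already in trl's expected schema; a sketch of a DPO-style sample it would forward untouched (field values illustrative):

sample = {
    "prompt": "What is 2 + 2?",
    "chosen": "4",
    "rejected": "5",
}
assert default(cfg=None)(sample) == sample  # the transform returns the sample unchanged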

View File

@@ -24,6 +24,8 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
from .trl import TrlConfig
LOG = logging.getLogger("axolotl.utils.config.models.input")
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
@@ -33,6 +35,7 @@ class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
grpo = "grpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
@@ -663,14 +666,20 @@ class AxolotlInputConfig(
auto_resume_from_checkpoints: Optional[bool] = None
resize_token_embeddings_to_32x: Optional[bool] = None
mean_resizing_embeddings: Optional[bool] = False
# optionally shrink the embeddings when the tokenizer vocab size is smaller
shrink_embeddings: Optional[bool] = None
rl: Optional[RLType] = None
trl: Optional[TrlConfig] = Field(
default_factory=lambda: TrlConfig(), # pylint: disable=unnecessary-lambda
)
reward_model: Optional[bool] = None
process_reward_model: Optional[bool] = None
num_labels: Optional[int] = None
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
dpo_use_logits_to_keep: Optional[bool] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore

View File

@@ -0,0 +1,32 @@
"""
Configuration args for TRL, including GRPO-specific settings
"""
from typing import List, Optional
from pydantic import BaseModel, Field
class TrlConfig(BaseModel):
"""
Input args for TRL.
"""
beta: Optional[float] = None
max_completion_length: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the completion for RL training"
},
)
# GRPO specific args
use_vllm: Optional[bool] = False
vllm_device: Optional[str] = "auto"
vllm_gpu_memory_utilization: Optional[float] = 0.9
vllm_max_model_len: Optional[int] = None
vllm_dtype: Optional[str] = "auto"
reward_funcs: Optional[List[str]] = None
num_generations: Optional[int] = None
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64
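For illustration, the new config block can be exercised directly; the import path below is an assumption based on the relative import shown earlier, and the values are illustrative:

from axolotl.utils.config.models.input.trl import TrlConfig  # assumed path

trl_cfg = TrlConfig(beta=0.04, max_completion_length=256, use_vllm=True, num_generations=4)
assert trl_cfg.vllm_device == "auto"  # unset fields keep their declared defaults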

View File

@@ -57,7 +57,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
dataset.save_to_disk(str(prepared_ds_path))
-def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
+def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
sig = inspect.signature(ds_transform_fn)
if "tokenizer" in sig.parameters:
if not tokenizer:
@@ -70,6 +70,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
data_set = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
**map_kwargs,
)
return data_set
@@ -150,36 +151,45 @@ def load_prepare_preference_datasets(cfg):
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
+map_kwargs = {}
+if isinstance(ds_transform_fn, tuple):
+ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset(
-cfg, data_set, ds_transform_fn, tokenizer
+cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
elif _cfg.rl == "kto":
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
+map_kwargs = {}
+if isinstance(ds_transform_fn, tuple):
+ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset(
-cfg, data_set, ds_transform_fn, tokenizer
+cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
else:
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen" and "rejected" already preprocessed
split_datasets[i] = data_set
-drop_long = partial(
-drop_long_rl_seq,
-rl=_cfg.rl,
-tokenizer=tokenizer,
-sequence_len=cfg.sequence_len,
-)
+if not cfg.skip_prepare_dataset:
+drop_long = partial(
+drop_long_rl_seq,
+rl=_cfg.rl,
+tokenizer=tokenizer,
+sequence_len=cfg.sequence_len,
+)
-prior_len = len(split_datasets[i])
-split_datasets[i] = split_datasets[i].filter(
-drop_long,
-num_proc=cfg.dataset_processes,
-load_from_cache_file=not cfg.is_preprocess,
-desc="Dropping Long Sequences",
-)
-dropped = prior_len - len(split_datasets[i])
-if dropped:
-LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
+prior_len = len(split_datasets[i])
+split_datasets[i] = split_datasets[i].filter(
+drop_long,
+num_proc=cfg.dataset_processes,
+load_from_cache_file=not cfg.is_preprocess,
+desc="Dropping Long Sequences",
+)
+dropped = prior_len - len(split_datasets[i])
+if dropped:
+LOG.warning(
+f"Dropped {dropped} long samples from dataset index {i}"
+)
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)

View File

@@ -46,6 +46,7 @@ from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
drop_long_seq_in_dataset,
md5,
retry_on_request_exceptions,
)
@@ -56,7 +57,7 @@ from axolotl.utils.trainer import (
process_datasets_for_packing,
)
LOG = logging.getLogger("axolotl")
LOG = logging.getLogger(__name__)
@retry_on_request_exceptions(max_retries=3, delay=5)
@@ -339,8 +340,11 @@ def load_tokenized_prepared_datasets(
else:
LOG.debug("NOT shuffling merged datasets")
-if cfg.sample_packing and not cfg.skip_prepare_dataset:
-dataset, _ = process_datasets_for_packing(cfg, dataset, None)
+if not cfg.skip_prepare_dataset:
+dataset = drop_long_seq_in_dataset(dataset, cfg)
+if cfg.sample_packing:
+dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")

View File

@@ -1,4 +1,5 @@
"""data handling helpers"""
import functools
import hashlib
import logging
@@ -6,10 +7,15 @@ import time
from enum import Enum
import huggingface_hub
import numpy as np
import requests
-from datasets import Dataset
+from datasets import Dataset, IterableDataset
-LOG = logging.getLogger("axolotl")
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.samplers.utils import get_dataset_lengths
+from axolotl.utils.trainer import drop_long_seq
+LOG = logging.getLogger(__name__)
class RetryStrategy(Enum):
@@ -150,3 +156,53 @@ def deduplicate_and_log_datasets(
)
return train_dataset, eval_dataset, dataset
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
)
return dataset
drop_long = functools.partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len,
)
try:
min_input_len = np.min(get_dataset_lengths(dataset))
LOG.debug(f"min_input_len: {min_input_len}")
max_input_len = np.max(get_dataset_lengths(dataset))
LOG.debug(f"max_input_len: {max_input_len}")
except AttributeError:
pass
try:
prior_len = len(dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"
dataset = dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
return dataset
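A sketch of the new helper in use, assuming axolotl and datasets are installed (values illustrative):

from datasets import Dataset
from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault

ds = Dataset.from_list([{"input_ids": [1] * 8}, {"input_ids": [1] * 4096}])
cfg = DictDefault({"sequence_len": 2048, "dataset_processes": 1})
ds = drop_long_seq_in_dataset(ds, cfg)  # warns and drops the 4096-token sample
assert len(ds) == 1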

View File

@@ -1053,9 +1053,12 @@ class ModelLoader:
if self.cfg.resize_token_embeddings_to_32x
else len(self.tokenizer)
)
-if (
-hasattr(self.model, "get_input_embeddings")
-and self.model.get_input_embeddings().num_embeddings != embeddings_len
+if hasattr(self.model, "get_input_embeddings") and (
+self.model.get_input_embeddings().num_embeddings < embeddings_len
+or (
+self.model.get_input_embeddings().num_embeddings > embeddings_len
+and self.cfg.shrink_embeddings
+)
):
resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None:

View File

@@ -13,5 +13,4 @@ def get_dataset_lengths(dataset):
else:
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
-return lengths
return lengths

View File

@@ -1,4 +1,5 @@
"""Module containing the Trainer class and related functions"""
import json
import math
import os
@@ -210,6 +211,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
Works for both single-example (list[int]) and batched (list[list[int]]) inputs.
"""
min_sequence_len = min_sequence_len or 2
input_ids = sample["input_ids"]
# Edge case: if input_ids is empty
@@ -232,20 +235,6 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-drop_long = partial(
-drop_long_seq,
-sequence_len=cfg.sequence_len,
-min_sequence_len=cfg.min_sample_len or 2,
-)
-try:
-min_input_len = np.min(get_dataset_lengths(train_dataset))
-LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-max_input_len = np.max(get_dataset_lengths(train_dataset))
-LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
-except AttributeError:
-pass
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
@@ -259,46 +248,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
-filter_map_kwargs = {}
-if not isinstance(train_dataset, IterableDataset):
-filter_map_kwargs["num_proc"] = cfg.dataset_processes
-filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
-try:
-prior_len = len(train_dataset)
-except TypeError:
-# handle iterable datasets case
-prior_len = None
-drop_long_kwargs = {}
-if filter_map_kwargs:
-drop_long_kwargs["desc"] = "Dropping Long Sequences"
-train_dataset = train_dataset.filter(
-drop_long,
-batched=True,
-**filter_map_kwargs,
-**drop_long_kwargs,
-)
-if prior_len:
-dropped = prior_len - len(train_dataset)
-if dropped:
-LOG.warning(f"Dropped {dropped} long samples from train dataset")
-if eval_dataset:
-try:
-prior_len = len(eval_dataset)
-except TypeError:
-# handle iterable datasets case
-prior_len = None
-eval_dataset = eval_dataset.filter(
-drop_long,
-**filter_map_kwargs,
-**drop_long_kwargs,
-)
-if prior_len:
-dropped = prior_len - len(eval_dataset)
-if dropped:
-LOG.warning(f"Dropped {dropped} long samples from eval dataset")
def drop_no_trainable_tokens(sample):
"""
Drop samples if all labels are -100 (i.e., zero trainable tokens).
@@ -325,6 +274,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
except TypeError:
# handle iterable datasets case
prior_len = None
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
@@ -622,7 +576,7 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
if cfg.rl in ("dpo", "grpo", "ipo", "orpo", "kto", "simpo"):
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]

View File

@@ -33,7 +33,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
"num_labels": 1,
"chat_template": "alpaca",
"reward_model": True,
"sequence_len": 1024,
"sequence_len": 2048,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,