Compare commits
34 Commits
wait-distr
...
grpo-path
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9d842ef2e | ||
|
|
ecea44c902 | ||
|
|
4f9c57e95d | ||
|
|
3d38bc82b8 | ||
|
|
756a8332d6 | ||
|
|
aded9c500d | ||
|
|
3659d812f7 | ||
|
|
bdb0f97082 | ||
|
|
65b6519447 | ||
|
|
a1958b09de | ||
|
|
b8f258817e | ||
|
|
753146b458 | ||
|
|
d683c50113 | ||
|
|
234cd8311e | ||
|
|
f9893e3842 | ||
|
|
ac1ebc58a8 | ||
|
|
56f3b9f20f | ||
|
|
2c1376d8c4 | ||
|
|
3c7517fd55 | ||
|
|
1e94d7ef65 | ||
|
|
cfc7fe0df2 | ||
|
|
3c4fe478cf | ||
|
|
c810599c66 | ||
|
|
300ffc2cb6 | ||
|
|
b1c4711145 | ||
|
|
d155849e2c | ||
|
|
626db6cb84 | ||
|
|
79159b4871 | ||
|
|
704ddd6ff1 | ||
|
|
54b0d3d0e8 | ||
|
|
59ad21f2de | ||
|
|
57264b6491 | ||
|
|
d495e41ba1 | ||
|
|
6067fe6c28 |
@@ -32,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -13,12 +13,12 @@ liger-kernel==0.5.2
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.48.1
|
transformers==4.48.2
|
||||||
tokenizers>=0.21.0
|
tokenizers>=0.21.0
|
||||||
accelerate==1.3.0
|
accelerate==1.3.0
|
||||||
datasets==3.2.0
|
datasets==3.2.0
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.1
|
||||||
trl==0.13.0
|
trl==0.14.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
@@ -26,7 +26,7 @@ sentencepiece
|
|||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
|
|
||||||
modal==0.70.5
|
modal==0.70.5
|
||||||
pydantic==2.6.3
|
pydantic==2.10.6
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -153,5 +153,8 @@ setup(
|
|||||||
"ray": [
|
"ray": [
|
||||||
"ray[train]",
|
"ray[train]",
|
||||||
],
|
],
|
||||||
|
"vllm": [
|
||||||
|
"vllm>=0.7.1",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -22,8 +22,11 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
|
|||||||
|
|
||||||
# modal workaround so it doesn't use the automounted axolotl
|
# modal workaround so it doesn't use the automounted axolotl
|
||||||
new_env = copy.deepcopy(os.environ)
|
new_env = copy.deepcopy(os.environ)
|
||||||
if "PYTHONPATH" in new_env:
|
# if "PYTHONPATH" in new_env:
|
||||||
del new_env["PYTHONPATH"]
|
# python_path = Path(new_env["PYTHONPATH"].split(":")[0])
|
||||||
|
# if python_path.joinpath("src", "axolotl").exists():
|
||||||
|
# # we don't want to use the automounted axolotl or unexpected behavior happens
|
||||||
|
# del new_env["PYTHONPATH"]
|
||||||
|
|
||||||
# Propagate errors from subprocess.
|
# Propagate errors from subprocess.
|
||||||
if exit_code := subprocess.call( # nosec B603
|
if exit_code := subprocess.call( # nosec B603
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
import yaml
|
import yaml
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
import axolotl
|
import axolotl
|
||||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||||
@@ -208,7 +209,7 @@ def train(
|
|||||||
accelerate_args.append(str(main_process_port))
|
accelerate_args.append(str(main_process_port))
|
||||||
if "num_processes" in kwargs:
|
if "num_processes" in kwargs:
|
||||||
num_processes = kwargs.pop("num_processes", None)
|
num_processes = kwargs.pop("num_processes", None)
|
||||||
accelerate_args.append("--num-processes")
|
accelerate_args.append("--num_processes")
|
||||||
accelerate_args.append(str(num_processes))
|
accelerate_args.append(str(num_processes))
|
||||||
|
|
||||||
base_cmd = ["accelerate", "launch"]
|
base_cmd = ["accelerate", "launch"]
|
||||||
@@ -381,4 +382,5 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
|
|||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.base import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlDPOTrainer,
|
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
AxolotlORPOTrainer,
|
AxolotlORPOTrainer,
|
||||||
@@ -48,9 +47,11 @@ from axolotl.core.trainers.base import (
|
|||||||
AxolotlTrainer,
|
AxolotlTrainer,
|
||||||
ReLoRATrainer,
|
ReLoRATrainer,
|
||||||
)
|
)
|
||||||
|
from axolotl.core.trainers.dpo import DPOStrategy
|
||||||
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
from axolotl.core.training_args import (
|
from axolotl.core.training_args import (
|
||||||
AxolotlCPOConfig,
|
AxolotlCPOConfig,
|
||||||
AxolotlDPOConfig,
|
|
||||||
AxolotlKTOConfig,
|
AxolotlKTOConfig,
|
||||||
AxolotlORPOConfig,
|
AxolotlORPOConfig,
|
||||||
AxolotlPRMConfig,
|
AxolotlPRMConfig,
|
||||||
@@ -652,7 +653,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
if self.cfg.optimizer in [
|
if self.cfg.optimizer in [
|
||||||
@@ -965,10 +966,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
if self.cfg.dataset_processes:
|
||||||
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
if self.cfg.rl_beta:
|
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
|
||||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
|
||||||
if self.cfg.orpo_alpha:
|
if self.cfg.orpo_alpha:
|
||||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
@@ -977,6 +979,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl == "simpo":
|
if self.cfg.rl == "simpo":
|
||||||
training_args_cls = AxolotlCPOConfig
|
training_args_cls = AxolotlCPOConfig
|
||||||
training_args_kwargs["loss_type"] = "simpo"
|
training_args_kwargs["loss_type"] = "simpo"
|
||||||
@@ -1001,11 +1004,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kto_undesirable_weight or 1.0
|
self.cfg.kto_undesirable_weight or 1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
if self.cfg.max_prompt_len:
|
if self.cfg.max_prompt_len:
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
|
elif self.cfg.rl == "grpo":
|
||||||
|
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||||
|
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||||
|
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = AxolotlDPOConfig
|
||||||
if self.cfg.rl == "ipo":
|
if self.cfg.rl == "ipo":
|
||||||
@@ -1016,9 +1023,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
|
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||||
|
training_args_kwargs[
|
||||||
|
"use_logits_to_keep"
|
||||||
|
] = self.cfg.dpo_use_logits_to_keep
|
||||||
|
|
||||||
|
for blocklist_key in blocklist_args_kwargs:
|
||||||
|
if blocklist_key in training_args_kwargs:
|
||||||
|
del training_args_kwargs[blocklist_key]
|
||||||
|
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
output_dir=self.cfg.output_dir,
|
self.cfg.output_dir,
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
max_steps=self.cfg.max_steps or total_num_steps,
|
max_steps=self.cfg.max_steps or total_num_steps,
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||||
@@ -1047,8 +1062,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl == "grpo":
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||||
|
elif self.cfg.rl in ["dpo", "ipo"]:
|
||||||
|
trainer_cls = DPOStrategy.get_trainer_class()
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
@@ -1068,7 +1087,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
|
||||||
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
if self.cfg.datasets is not None and (
|
||||||
|
trainer_cls is DPOStrategy.get_trainer_class()
|
||||||
|
):
|
||||||
dpo_trainer_kwargs["dataset_tags"] = [
|
dpo_trainer_kwargs["dataset_tags"] = [
|
||||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -5,30 +5,21 @@ module for customized trainers
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
import gc
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Literal, Optional, Union
|
from typing import Dict, Literal, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import (
|
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
CPOTrainer,
|
|
||||||
DPOTrainer,
|
|
||||||
KTOTrainer,
|
|
||||||
ORPOTrainer,
|
|
||||||
PRMTrainer,
|
|
||||||
RewardTrainer,
|
|
||||||
)
|
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
@@ -847,107 +838,6 @@ class ReLoRATrainer(AxolotlTrainer):
|
|||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base DPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
|
||||||
|
|
||||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.dataset_tags = dataset_tags
|
|
||||||
self.optimizer = None
|
|
||||||
self.model_accepts_loss_kwargs = False
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
|
||||||
if self.args.loraplus_lr_ratio is None:
|
|
||||||
return super().create_optimizer()
|
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
|
||||||
self.args,
|
|
||||||
opt_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
|
||||||
if loraplus_lr_ratio:
|
|
||||||
print("Using lora+")
|
|
||||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.optimizer
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
@wraps(DPOTrainer.push_to_hub)
|
|
||||||
def push_to_hub(self, *args, **kwargs) -> str:
|
|
||||||
"""
|
|
||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
|
||||||
"""
|
|
||||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
|
||||||
)
|
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def tokenize_row(
|
|
||||||
features,
|
|
||||||
processing_class,
|
|
||||||
max_prompt_length,
|
|
||||||
max_completion_length,
|
|
||||||
add_special_tokens,
|
|
||||||
) -> Dict:
|
|
||||||
res = DPOTrainer.tokenize_row(
|
|
||||||
features,
|
|
||||||
processing_class,
|
|
||||||
max_prompt_length,
|
|
||||||
max_completion_length,
|
|
||||||
add_special_tokens,
|
|
||||||
)
|
|
||||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
|
||||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
|
||||||
for key in res.keys():
|
|
||||||
res[key] = res[key][1:]
|
|
||||||
|
|
||||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
|
||||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
|
||||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
|
||||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
|
||||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
|
||||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
|
||||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
|
||||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
|
||||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
|
||||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def training_step(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
|
||||||
num_items_in_batch=None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
|
|||||||
33
src/axolotl/core/trainers/dpo/__init__.py
Normal file
33
src/axolotl/core/trainers/dpo/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""
|
||||||
|
DPO Specific Strategy for training
|
||||||
|
"""
|
||||||
|
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class DPOStrategy:
|
||||||
|
"""
|
||||||
|
Strategy for DPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_trainer_class(cls):
|
||||||
|
return AxolotlDPOTrainer
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_training_args_class(cls):
|
||||||
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
|
|
||||||
|
return AxolotlDPOConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_training_args_kwargs(cls, cfg):
|
||||||
|
training_args_kwargs = {}
|
||||||
|
if cfg.rl == "ipo":
|
||||||
|
training_args_kwargs["loss_type"] = "ipo"
|
||||||
|
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||||
|
training_args_kwargs["max_completion_length"] = None
|
||||||
|
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||||
|
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
||||||
|
if cfg.dpo_use_weighting is not None:
|
||||||
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
|
return training_args_kwargs
|
||||||
15
src/axolotl/core/trainers/dpo/args.py
Normal file
15
src/axolotl/core/trainers/dpo/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Axolotl specific DPO args
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from trl import DPOConfig
|
||||||
|
|
||||||
|
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||||
|
"""
|
||||||
|
DPO config for DPO training
|
||||||
|
"""
|
||||||
125
src/axolotl/core/trainers/dpo/trainer.py
Normal file
125
src/axolotl/core/trainers/dpo/trainer.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""
|
||||||
|
DPO trainer for axolotl
|
||||||
|
"""
|
||||||
|
import gc
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
|
from torch import nn
|
||||||
|
from transformers import Trainer
|
||||||
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
|
from trl import DPOTrainer
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import (
|
||||||
|
SchedulerMixin,
|
||||||
|
_sanitize_kwargs_for_ds_tagging,
|
||||||
|
_sanitize_kwargs_for_tagging,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base DPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
|
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.dataset_tags = dataset_tags
|
||||||
|
self.optimizer = None
|
||||||
|
self.model_accepts_loss_kwargs = False
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
if self.args.loraplus_lr_ratio is None:
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args,
|
||||||
|
opt_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
if loraplus_lr_ratio:
|
||||||
|
print("Using lora+")
|
||||||
|
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
|
optimizer_cls,
|
||||||
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
|
**optimizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.optimizer
|
||||||
|
|
||||||
|
@wraps(DPOTrainer.push_to_hub)
|
||||||
|
def push_to_hub(self, *args, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
|
"""
|
||||||
|
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||||
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
|
)
|
||||||
|
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tokenize_row(
|
||||||
|
features,
|
||||||
|
processing_class,
|
||||||
|
max_prompt_length,
|
||||||
|
max_completion_length,
|
||||||
|
add_special_tokens,
|
||||||
|
) -> Dict:
|
||||||
|
res = DPOTrainer.tokenize_row(
|
||||||
|
features,
|
||||||
|
processing_class,
|
||||||
|
max_prompt_length,
|
||||||
|
max_completion_length,
|
||||||
|
add_special_tokens,
|
||||||
|
)
|
||||||
|
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||||
|
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||||
|
for key in res.keys():
|
||||||
|
res[key] = res[key][1:]
|
||||||
|
|
||||||
|
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||||
|
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||||
|
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||||
|
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||||
|
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||||
|
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||||
|
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||||
|
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||||
|
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||||
|
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def training_step(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||||
|
num_items_in_batch=None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return loss
|
||||||
113
src/axolotl/core/trainers/grpo/__init__.py
Normal file
113
src/axolotl/core/trainers/grpo/__init__.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""
|
||||||
|
GRPO Specific Strategy for training
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
class GRPOStrategy:
|
||||||
|
"""
|
||||||
|
Strategy for GRPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_trainer_class(cls):
|
||||||
|
return AxolotlGRPOTrainer
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_training_args_class(cls):
|
||||||
|
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||||
|
|
||||||
|
return AxolotlGRPOConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_training_args_kwargs(cls, cfg):
|
||||||
|
grpo_args_kwargs = {}
|
||||||
|
if cfg.trl and cfg.trl.use_vllm:
|
||||||
|
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
||||||
|
if cfg.trl and cfg.trl.vllm_device:
|
||||||
|
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
||||||
|
else:
|
||||||
|
grpo_args_kwargs["vllm_device"] = "auto"
|
||||||
|
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
||||||
|
grpo_args_kwargs[
|
||||||
|
"vllm_gpu_memory_utilization"
|
||||||
|
] = cfg.trl.vllm_gpu_memory_utilization
|
||||||
|
if cfg.trl and cfg.trl.vllm_max_model_len:
|
||||||
|
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
||||||
|
if cfg.trl and cfg.trl.num_generations:
|
||||||
|
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
||||||
|
if cfg.trl and cfg.trl.sync_ref_model:
|
||||||
|
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
||||||
|
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
||||||
|
grpo_args_kwargs[
|
||||||
|
"ref_model_mixup_alpha"
|
||||||
|
] = cfg.trl.ref_model_mixup_alpha
|
||||||
|
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
||||||
|
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
||||||
|
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
||||||
|
return grpo_args_kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_trainer_kwargs(cls, cfg):
|
||||||
|
trainer_kwargs = {}
|
||||||
|
if cfg.trl and cfg.trl.reward_funcs:
|
||||||
|
reward_funcs = []
|
||||||
|
for reward_func_fqn in cfg.trl.reward_funcs:
|
||||||
|
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
||||||
|
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||||
|
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||||
|
trainer_kwargs[
|
||||||
|
"reward_processing_classes"
|
||||||
|
] = cfg.trl.reward_processing_classes
|
||||||
|
return trainer_kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||||
|
# No data collation is needed in GRPO, handled by trl's trainer __init__
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_blocklist_args_kwargs(cls):
|
||||||
|
return ["dataset_num_proc"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
||||||
|
"""
|
||||||
|
Returns the reward function from the given fully qualified name, or the path to the reward function model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
|
||||||
|
or a HF hub path to the reward model.
|
||||||
|
Raises:
|
||||||
|
ValueError: If the reward function does not accept at least two arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RewardFunc: A callable that accepts prompts and completions and returns rewards,
|
||||||
|
or a path to a reward model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# use importlib to dynamically load the reward function from the module
|
||||||
|
reward_func_module_name = reward_func_fqn.split(".")[-1]
|
||||||
|
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
|
||||||
|
reward_func = getattr(reward_func_module, reward_func_module_name)
|
||||||
|
if not len(inspect.signature(reward_func).parameters) >= 2:
|
||||||
|
raise ValueError(
|
||||||
|
"Reward function must accept at least two arguments: prompts: list and completions: list"
|
||||||
|
)
|
||||||
|
return reward_func
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# the user has passed a string (ideally indicating the path of a reward model)
|
||||||
|
LOG.info(
|
||||||
|
f"Reward function {reward_func} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
||||||
|
)
|
||||||
|
return reward_func
|
||||||
15
src/axolotl/core/trainers/grpo/args.py
Normal file
15
src/axolotl/core/trainers/grpo/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Axolotl Specific Training Args
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from trl import GRPOConfig
|
||||||
|
|
||||||
|
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||||
|
"""
|
||||||
|
Axolotl GRPO Config for GRPO training
|
||||||
|
"""
|
||||||
14
src/axolotl/core/trainers/grpo/trainer.py
Normal file
14
src/axolotl/core/trainers/grpo/trainer.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Axolotl GRPO trainer
|
||||||
|
"""
|
||||||
|
from trl import GRPOTrainer
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base GRPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
|
||||||
"""
|
|
||||||
DPO config for DPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -13,8 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
|
|||||||
if len(strategy.split(".")) == 1:
|
if len(strategy.split(".")) == 1:
|
||||||
strategy = strategy + ".default"
|
strategy = strategy + ".default"
|
||||||
load_fn = strategy.split(".")[-1]
|
load_fn = strategy.split(".")[-1]
|
||||||
strategy = ".".join(strategy.split(".")[:-1])
|
if len(strategy.split(".")) > 1:
|
||||||
mod = importlib.import_module(f".{strategy}", module_base)
|
try:
|
||||||
|
importlib.import_module(
|
||||||
|
strategy.split(".")[-2],
|
||||||
|
".".join(strategy.split(".")[:-2]),
|
||||||
|
)
|
||||||
|
module_base = ".".join(strategy.split(".")[:-2])
|
||||||
|
strategy = strategy.split(".")[-2]
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||||
|
else:
|
||||||
|
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||||
|
mod = importlib.import_module(strategy, module_base)
|
||||||
func = getattr(mod, load_fn)
|
func = getattr(mod, load_fn)
|
||||||
return func(cfg, **kwargs)
|
return func(cfg, **kwargs)
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
|||||||
14
src/axolotl/prompt_strategies/dpo/passthrough.py
Normal file
14
src/axolotl/prompt_strategies/dpo/passthrough.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
DPO prompt strategies passthrough/zero-processing strategy
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def default(
|
||||||
|
cfg, dataset_idx=0, **kwargs
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
def transform_fn(
|
||||||
|
sample, tokenizer=None
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
@@ -24,6 +24,8 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
|
|
||||||
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
||||||
|
|
||||||
|
from .trl import TrlConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||||
|
|
||||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||||
@@ -33,6 +35,7 @@ class RLType(str, Enum):
|
|||||||
"""RL trainer type configuration subset"""
|
"""RL trainer type configuration subset"""
|
||||||
|
|
||||||
dpo = "dpo" # pylint: disable=invalid-name
|
dpo = "dpo" # pylint: disable=invalid-name
|
||||||
|
grpo = "grpo" # pylint: disable=invalid-name
|
||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
kto = "kto" # pylint: disable=invalid-name
|
kto = "kto" # pylint: disable=invalid-name
|
||||||
@@ -663,14 +666,20 @@ class AxolotlInputConfig(
|
|||||||
auto_resume_from_checkpoints: Optional[bool] = None
|
auto_resume_from_checkpoints: Optional[bool] = None
|
||||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||||
mean_resizing_embeddings: Optional[bool] = False
|
mean_resizing_embeddings: Optional[bool] = False
|
||||||
|
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||||
|
shrink_embeddings: Optional[bool] = None
|
||||||
|
|
||||||
rl: Optional[RLType] = None
|
rl: Optional[RLType] = None
|
||||||
|
trl: Optional[TrlConfig] = Field(
|
||||||
|
default_factory=lambda: TrlConfig(), # pylint: disable=unnecessary-lambda
|
||||||
|
)
|
||||||
reward_model: Optional[bool] = None
|
reward_model: Optional[bool] = None
|
||||||
process_reward_model: Optional[bool] = None
|
process_reward_model: Optional[bool] = None
|
||||||
num_labels: Optional[int] = None
|
num_labels: Optional[int] = None
|
||||||
dpo_use_weighting: Optional[
|
dpo_use_weighting: Optional[
|
||||||
bool
|
bool
|
||||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||||
|
dpo_use_logits_to_keep: Optional[bool] = None
|
||||||
|
|
||||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||||
|
|||||||
32
src/axolotl/utils/config/models/input/v0_4_1/trl.py
Normal file
32
src/axolotl/utils/config/models/input/v0_4_1/trl.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""
|
||||||
|
GRPO specific configuration args
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class TrlConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Input args for TRL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
beta: Optional[float] = None
|
||||||
|
max_completion_length: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Maximum length of the completion for RL training"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# GRPO specific args
|
||||||
|
use_vllm: Optional[bool] = False
|
||||||
|
vllm_device: Optional[str] = "auto"
|
||||||
|
vllm_gpu_memory_utilization: Optional[float] = 0.9
|
||||||
|
vllm_max_model_len: Optional[int] = None
|
||||||
|
vllm_dtype: Optional[str] = "auto"
|
||||||
|
reward_funcs: Optional[List[str]] = None
|
||||||
|
num_generations: Optional[int] = None
|
||||||
|
sync_ref_model: Optional[bool] = False
|
||||||
|
ref_model_mixup_alpha: Optional[float] = 0.9
|
||||||
|
ref_model_sync_steps: Optional[int] = 64
|
||||||
@@ -57,7 +57,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
|||||||
dataset.save_to_disk(str(prepared_ds_path))
|
dataset.save_to_disk(str(prepared_ds_path))
|
||||||
|
|
||||||
|
|
||||||
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||||
sig = inspect.signature(ds_transform_fn)
|
sig = inspect.signature(ds_transform_fn)
|
||||||
if "tokenizer" in sig.parameters:
|
if "tokenizer" in sig.parameters:
|
||||||
if not tokenizer:
|
if not tokenizer:
|
||||||
@@ -70,6 +70,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
|||||||
data_set = data_set.map(
|
data_set = data_set.map(
|
||||||
ds_transform_fn,
|
ds_transform_fn,
|
||||||
desc="Mapping RL Dataset",
|
desc="Mapping RL Dataset",
|
||||||
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_set
|
return data_set
|
||||||
@@ -150,36 +151,45 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
else:
|
else:
|
||||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||||
|
|
||||||
|
map_kwargs = {}
|
||||||
|
if isinstance(ds_transform_fn, tuple):
|
||||||
|
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||||
split_datasets[i] = map_dataset(
|
split_datasets[i] = map_dataset(
|
||||||
cfg, data_set, ds_transform_fn, tokenizer
|
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||||
)
|
)
|
||||||
elif _cfg.rl == "kto":
|
elif _cfg.rl == "kto":
|
||||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||||
|
map_kwargs = {}
|
||||||
|
if isinstance(ds_transform_fn, tuple):
|
||||||
|
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||||
split_datasets[i] = map_dataset(
|
split_datasets[i] = map_dataset(
|
||||||
cfg, data_set, ds_transform_fn, tokenizer
|
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||||
# "prompt", "chosen" and "rejected" already preprocessed
|
# "prompt", "chosen" and "rejected" already preprocessed
|
||||||
split_datasets[i] = data_set
|
split_datasets[i] = data_set
|
||||||
|
|
||||||
drop_long = partial(
|
if not cfg.skip_prepare_dataset:
|
||||||
drop_long_rl_seq,
|
drop_long = partial(
|
||||||
rl=_cfg.rl,
|
drop_long_rl_seq,
|
||||||
tokenizer=tokenizer,
|
rl=_cfg.rl,
|
||||||
sequence_len=cfg.sequence_len,
|
tokenizer=tokenizer,
|
||||||
)
|
sequence_len=cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
prior_len = len(split_datasets[i])
|
prior_len = len(split_datasets[i])
|
||||||
split_datasets[i] = split_datasets[i].filter(
|
split_datasets[i] = split_datasets[i].filter(
|
||||||
drop_long,
|
drop_long,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Dropping Long Sequences",
|
desc="Dropping Long Sequences",
|
||||||
)
|
)
|
||||||
dropped = prior_len - len(split_datasets[i])
|
dropped = prior_len - len(split_datasets[i])
|
||||||
if dropped:
|
if dropped:
|
||||||
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
|
LOG.warning(
|
||||||
|
f"Dropped {dropped} long samples from dataset index {i}"
|
||||||
|
)
|
||||||
|
|
||||||
combined_datasets = concatenate_datasets(split_datasets)
|
combined_datasets = concatenate_datasets(split_datasets)
|
||||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
||||||
|
|||||||
@@ -1053,9 +1053,12 @@ class ModelLoader:
|
|||||||
if self.cfg.resize_token_embeddings_to_32x
|
if self.cfg.resize_token_embeddings_to_32x
|
||||||
else len(self.tokenizer)
|
else len(self.tokenizer)
|
||||||
)
|
)
|
||||||
if (
|
if hasattr(self.model, "get_input_embeddings") and (
|
||||||
hasattr(self.model, "get_input_embeddings")
|
self.model.get_input_embeddings().num_embeddings < embeddings_len
|
||||||
and self.model.get_input_embeddings().num_embeddings != embeddings_len
|
or (
|
||||||
|
self.model.get_input_embeddings().num_embeddings > embeddings_len
|
||||||
|
and self.cfg.shrink_embeddings
|
||||||
|
)
|
||||||
):
|
):
|
||||||
resize_kwargs = {}
|
resize_kwargs = {}
|
||||||
if self.cfg.mean_resizing_embeddings is not None:
|
if self.cfg.mean_resizing_embeddings is not None:
|
||||||
|
|||||||
@@ -576,7 +576,7 @@ def prepare_opinionated_env(cfg):
|
|||||||
def setup_trainer(
|
def setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||||
):
|
):
|
||||||
if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
|
if cfg.rl in ("dpo", "grpo", "ipo", "orpo", "kto", "simpo"):
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
|
|||||||
0
tests/e2e/multigpu/test_grpo.py
Normal file
0
tests/e2e/multigpu/test_grpo.py
Normal file
Reference in New Issue
Block a user