Compare commits
9 Commits
olmo-no-po
...
nca-pair
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
317761406e | ||
|
|
6a9ac4ad27 | ||
|
|
027f7d54f0 | ||
|
|
0554105baa | ||
|
|
f58fcd09ec | ||
|
|
60fecac367 | ||
|
|
b301068098 | ||
|
|
df645906eb | ||
|
|
7fea5822f0 |
@@ -138,7 +138,7 @@ test_datasets:
|
||||
data_files:
|
||||
- /workspace/data/eval.jsonl
|
||||
|
||||
# use RL training: 'dpo', 'ipo', 'kto_pair'
|
||||
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard', 'nca_pair'
|
||||
rl:
|
||||
|
||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||
|
||||
@@ -39,6 +39,6 @@ s3fs
|
||||
gcsfs
|
||||
# adlfs
|
||||
|
||||
trl==0.8.5
|
||||
trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
@@ -25,8 +25,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
flash_attention=False,
|
||||
deepspeed=None,
|
||||
fsdp=None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
|
||||
from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
|
||||
from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.loraplus import create_loraplus_optimizer
|
||||
@@ -1526,6 +1526,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rl == "orpo":
|
||||
training_args_cls = ORPOConfig
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
|
||||
training_args_cls = DPOConfig
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
training_args = training_args_cls(
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
@@ -1550,8 +1553,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
||||
if self.cfg.dpo_label_smoothing:
|
||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||
elif self.cfg.rl == "kto_pair":
|
||||
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
||||
elif self.cfg.rl in ["kto_pair", "sppo_hard", "nca_pair"]:
|
||||
dpo_trainer_kwargs["loss_type"] = self.cfg.rl
|
||||
if self.eval_dataset:
|
||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||
if self.cfg.adapter and self.peft_config:
|
||||
@@ -1560,7 +1563,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs[
|
||||
"precompute_ref_log_probs"
|
||||
] = self.cfg.precompute_ref_log_probs
|
||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
|
||||
30
src/axolotl/prompt_strategies/dpo/mistral.py
Normal file
30
src/axolotl/prompt_strategies/dpo/mistral.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
DPO strategies for mistral instruct
|
||||
"""
|
||||
|
||||
|
||||
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
|
||||
sample["chosen"] = f"{sample['chosen']}"
|
||||
sample["rejected"] = f"{sample['rejected']}"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
|
||||
def argilla_chat(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
for argilla/dpo-mix-7k conversations
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
sample["prompt"] = f"[INST] {sample['chosen'][0]['content']} [/INST]"
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}</s>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}</s>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||
|
||||
@@ -39,40 +39,76 @@ def register_chatml_template(system_message=None):
|
||||
)
|
||||
|
||||
|
||||
def build_loader(
|
||||
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
||||
prompter_cls: Type["ShareGPTPrompterV2"],
|
||||
default_conversation: Optional[str] = None,
|
||||
):
|
||||
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"]
|
||||
if ds_cfg and "conversation" in ds_cfg
|
||||
else default_conversation
|
||||
)
|
||||
field_human = (
|
||||
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
)
|
||||
field_model = (
|
||||
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
)
|
||||
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
||||
strategy = tokenization_strategy_cls(
|
||||
prompter_cls(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
role_key_human=field_human,
|
||||
roles=roles,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
return strategy
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
role_key_human=field_human,
|
||||
roles=roles,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg:
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
return strategy
|
||||
|
||||
return _load
|
||||
|
||||
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
strategy = UltrachatShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg:
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
return strategy
|
||||
|
||||
|
||||
def load_role(tokenizer, cfg):
|
||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_guanaco(tokenizer, cfg):
|
||||
return GuanacoShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"]
|
||||
if ds_cfg and "conversation" in ds_cfg
|
||||
else "chatml_glaive"
|
||||
)
|
||||
return GlaiveShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(conversation=conversation),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
@@ -122,9 +158,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
return turns
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||
SimpleShareGPTPromptTokenizingStrategy
|
||||
):
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||
"""
|
||||
@@ -175,16 +209,3 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
|
||||
conversation = merge_consecutive_messages(conversation)
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_ultrachat = build_loader(
|
||||
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
|
||||
)
|
||||
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_glaive = build_loader(
|
||||
GlaiveShareGPTPromptTokenizingStrategy,
|
||||
ShareGPTPrompterV2,
|
||||
default_conversation="chatml_glaive",
|
||||
)
|
||||
|
||||
@@ -348,10 +348,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
)
|
||||
|
||||
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
||||
if (
|
||||
role != "assistant"
|
||||
): # back to back assistant calls may be okay for tool calls
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
|
||||
@@ -212,10 +212,6 @@ def train(
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
trainer.model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
|
||||
@@ -133,6 +133,8 @@ class RLType(str, Enum):
|
||||
ipo = "ipo" # pylint: disable=invalid-name
|
||||
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
||||
orpo = "orpo" # pylint: disable=invalid-name
|
||||
sppo_hard = "sppo_hard" # pylint: disable=invalid-name
|
||||
nca_pair = "nca_pair" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
@@ -574,6 +576,7 @@ class AxolotlInputConfig(
|
||||
neftune_noise_alpha: Optional[float] = None
|
||||
|
||||
orpo_alpha: Optional[float] = None
|
||||
dpo_beta: Optional[float] = None
|
||||
|
||||
max_memory: Optional[
|
||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Module for models and model loading"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
@@ -505,9 +504,6 @@ def load_model(
|
||||
bnb_config = {
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
# Exclude mamba blocks from int8 quantization for jamba
|
||||
if cfg.model_config_type == "jamba":
|
||||
bnb_config["llm_int8_skip_modules"] = ["mamba"]
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
)
|
||||
@@ -793,7 +789,11 @@ def load_model(
|
||||
if not reference_model or cfg.lora_model_dir:
|
||||
# if we're not loading the reference model, then we're loading the model for training
|
||||
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
||||
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
|
||||
if (
|
||||
cfg.adapter
|
||||
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]
|
||||
and not cfg.merge_lora
|
||||
):
|
||||
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
||||
else:
|
||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||
|
||||
@@ -197,12 +197,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||
|
||||
if cfg.model_config_type == "olmo":
|
||||
LOG.info("dropping position_ids column")
|
||||
train_dataset = train_dataset.remove_columns("position_ids")
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.remove_columns("position_ids")
|
||||
|
||||
if cfg.model_config_type == "falcon":
|
||||
LOG.info("dropping token_type_ids column if it exists")
|
||||
if "token_type_ids" in train_dataset.column_names:
|
||||
@@ -444,7 +438,7 @@ def prepare_optim_env(cfg):
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
|
||||
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard", "nca_pair"]:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||
trainer_builder.model_ref = model[1]
|
||||
trainer_builder.peft_config = model[2]
|
||||
|
||||
Reference in New Issue
Block a user