Compare commits

..

9 Commits

Author SHA1 Message Date
Wing Lian
317761406e add support for NCA 2024-05-06 17:01:14 -04:00
Wing Lian
6a9ac4ad27 consistency w sppo -> sppo_hard 2024-05-06 16:58:58 -04:00
Wing Lian
027f7d54f0 update for sppo 2024-05-06 16:55:46 -04:00
Wing Lian
0554105baa add mistral instruct strategy and fix dpo_loss input 2024-05-06 16:55:18 -04:00
Wing Lian
f58fcd09ec use DPOConfig 2024-05-06 16:55:16 -04:00
Wing Lian
60fecac367 bump trl 2024-05-06 16:54:03 -04:00
Wing Lian
b301068098 remove override 2024-05-06 16:54:02 -04:00
Wing Lian
df645906eb invert check 2024-05-06 16:54:02 -04:00
Wing Lian
7fea5822f0 add support for SPPO 2024-05-06 16:54:02 -04:00
11 changed files with 120 additions and 78 deletions

View File

@@ -138,7 +138,7 @@ test_datasets:
data_files: data_files:
- /workspace/data/eval.jsonl - /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto_pair' # use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard', 'nca_pair'
rl: rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing # Saves the desired chat template to the tokenizer_config.json for easier inferencing

View File

@@ -39,6 +39,6 @@ s3fs
gcsfs gcsfs
# adlfs # adlfs
trl==0.8.5 trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
zstandard==0.22.0 zstandard==0.22.0
fastcore fastcore

View File

@@ -25,8 +25,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
load_in_8bit=False, load_in_8bit=False,
load_in_4bit=False, load_in_4bit=False,
flash_attention=False, flash_attention=False,
deepspeed=None,
fsdp=None,
**kwargs, **kwargs,
) )

View File

@@ -30,7 +30,7 @@ from transformers import (
) )
from transformers.trainer_utils import seed_worker from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer, ORPOConfig, ORPOTrainer from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer from axolotl.loraplus import create_loraplus_optimizer
@@ -1526,6 +1526,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl == "orpo": if self.cfg.rl == "orpo":
training_args_cls = ORPOConfig training_args_cls = ORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
training_args_cls = DPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args = training_args_cls( training_args = training_args_cls(
per_device_train_batch_size=self.cfg.micro_batch_size, per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1550,8 +1553,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["loss_type"] = "ipo" dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing: if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
elif self.cfg.rl == "kto_pair": elif self.cfg.rl in ["kto_pair", "sppo_hard", "nca_pair"]:
dpo_trainer_kwargs["loss_type"] = "kto_pair" dpo_trainer_kwargs["loss_type"] = self.cfg.rl
if self.eval_dataset: if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config: if self.cfg.adapter and self.peft_config:
@@ -1560,7 +1563,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[ dpo_trainer_kwargs[
"precompute_ref_log_probs" "precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs ] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]: if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
trainer_cls = AxolotlDPOTrainer trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1 dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
trainer_cls_args = [self.model, self.model_ref] trainer_cls_args = [self.model, self.model_ref]

View File

@@ -0,0 +1,30 @@
"""
DPO strategies for mistral instruct
"""
def prompt_pairs(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
    """Build a transform wrapping DPO prompts in Mistral-instruct tags.

    Returns a function that mutates a sample in place: the prompt is
    wrapped in ``[INST]...[/INST]`` delimiters, while the chosen and
    rejected completions are passed through (coerced to str).
    """

    def transform_fn(sample):
        # Wrap the raw prompt in Mistral-instruct delimiters.
        sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
        # Pass completions through unchanged (str coercion only).
        for key in ("chosen", "rejected"):
            sample[key] = f"{sample[key]}"
        return sample

    return transform_fn
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/dpo-mix-7k conversations
    """

    def transform_fn(sample):
        # The first turn of the chosen conversation is the shared user
        # prompt; the second turn of each conversation is the reply.
        user_msg = sample["chosen"][0]["content"]
        chosen_reply = sample["chosen"][1]["content"]
        rejected_reply = sample["rejected"][1]["content"]
        sample["prompt"] = f"[INST] {user_msg} [/INST]"
        # Terminate each reply with the Mistral EOS token.
        sample["chosen"] = f"{chosen_reply}</s>"
        sample["rejected"] = f"{rejected_reply}</s>"
        return sample

    return transform_fn

View File

@@ -1,7 +1,7 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging import logging
from typing import Any, Dict, Optional, Type from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -39,40 +39,76 @@ def register_chatml_template(system_message=None):
) )
def build_loader( def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"], conversation = (
prompter_cls: Type["ShareGPTPrompterV2"], ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
default_conversation: Optional[str] = None, )
): field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
conversation = ( roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
ds_cfg["conversation"] strategy = SimpleShareGPTPromptTokenizingStrategy(
if ds_cfg and "conversation" in ds_cfg ShareGPTPrompterV2(
else default_conversation conversation=conversation,
) role_key_model=field_model,
field_human = ( role_key_human=field_human,
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None roles=roles,
) ),
field_model = ( tokenizer,
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None cfg.train_on_inputs,
) cfg.sequence_len,
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None )
strategy = tokenization_strategy_cls( if ds_cfg and "strict" in ds_cfg:
prompter_cls( strategy.strict = ds_cfg["strict"]
conversation=conversation, return strategy
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
strategy.strict = ds_cfg["strict"]
return strategy
return _load
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = UltrachatShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else "chatml_glaive"
)
return GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation=conversation),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -122,9 +158,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
return turns return turns
class SimpleRoleShareGPTPromptTokenizingStrategy( class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
SimpleShareGPTPromptTokenizingStrategy
):
""" """
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
""" """
@@ -175,16 +209,3 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
conversation = merge_consecutive_messages(conversation) conversation = merge_consecutive_messages(conversation)
return conversation return conversation
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_ultrachat = build_loader(
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
)
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_glaive = build_loader(
GlaiveShareGPTPromptTokenizingStrategy,
ShareGPTPrompterV2,
default_conversation="chatml_glaive",
)

View File

@@ -348,10 +348,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
) )
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
if ( LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
role != "assistant"
): # back to back assistant calls may be okay for tool calls
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])

View File

@@ -212,10 +212,6 @@ def train(
if cfg.flash_optimum and BetterTransformer: if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id: if not cfg.hub_model_id:

View File

@@ -133,6 +133,8 @@ class RLType(str, Enum):
ipo = "ipo" # pylint: disable=invalid-name ipo = "ipo" # pylint: disable=invalid-name
kto_pair = "kto_pair" # pylint: disable=invalid-name kto_pair = "kto_pair" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name
sppo_hard = "sppo_hard" # pylint: disable=invalid-name
nca_pair = "nca_pair" # pylint: disable=invalid-name
class ChatTemplate(str, Enum): class ChatTemplate(str, Enum):
@@ -574,6 +576,7 @@ class AxolotlInputConfig(
neftune_noise_alpha: Optional[float] = None neftune_noise_alpha: Optional[float] = None
orpo_alpha: Optional[float] = None orpo_alpha: Optional[float] = None
dpo_beta: Optional[float] = None
max_memory: Optional[ max_memory: Optional[
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]] Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]

View File

@@ -1,5 +1,4 @@
"""Module for models and model loading""" """Module for models and model loading"""
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
import logging import logging
@@ -505,9 +504,6 @@ def load_model(
bnb_config = { bnb_config = {
"load_in_8bit": True, "load_in_8bit": True,
} }
# Exclude mamba blocks from int8 quantization for jamba
if cfg.model_config_type == "jamba":
bnb_config["llm_int8_skip_modules"] = ["mamba"]
model_kwargs["quantization_config"] = BitsAndBytesConfig( model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config, **bnb_config,
) )
@@ -793,7 +789,11 @@ def load_model(
if not reference_model or cfg.lora_model_dir: if not reference_model or cfg.lora_model_dir:
# if we're not loading the reference model, then we're loading the model for training # if we're not loading the reference model, then we're loading the model for training
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora: if (
cfg.adapter
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]
and not cfg.merge_lora
):
_, lora_config = load_lora(model, cfg, inference=False, config_only=True) _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
else: else:
model, lora_config = load_adapter(model, cfg, cfg.adapter) model, lora_config = load_adapter(model, cfg, cfg.adapter)

View File

@@ -197,12 +197,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask") eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type == "olmo":
LOG.info("dropping position_ids column")
train_dataset = train_dataset.remove_columns("position_ids")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("position_ids")
if cfg.model_config_type == "falcon": if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists") LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names: if "token_type_ids" in train_dataset.column_names:
@@ -444,7 +438,7 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]: if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard", "nca_pair"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1] trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2] trainer_builder.peft_config = model[2]