Compare commits

..

5 Commits

Author SHA1 Message Date
Wing Lian
7c5aa4791f drop position_ids for olmo model 2024-05-09 00:25:15 -04:00
Wing Lian
796a085b2f make sure to save the lora adapter at the end of RL/dpo training (#1573) 2024-05-08 10:39:33 -04:00
Wing Lian
cb78a36374 improve tool handling roles (#1587) 2024-05-07 11:30:40 -04:00
NanoCode012
8b9c15b17f feat: exclude mamba blocks for jamba (#1578) 2024-05-07 22:52:57 +09:00
Chirag Jain
9e1480e9ca Pass deepspeed and fsdp as None explicitly when merging adapters to allow custom device_map (#1575) 2024-05-07 22:47:55 +09:00
11 changed files with 76 additions and 119 deletions

View File

@@ -138,7 +138,7 @@ test_datasets:
data_files: data_files:
- /workspace/data/eval.jsonl - /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard' # use RL training: 'dpo', 'ipo', 'kto_pair'
rl: rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing # Saves the desired chat template to the tokenizer_config.json for easier inferencing

View File

@@ -39,6 +39,6 @@ s3fs
gcsfs gcsfs
# adlfs # adlfs
trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c trl==0.8.5
zstandard==0.22.0 zstandard==0.22.0
fastcore fastcore

View File

@@ -25,6 +25,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
load_in_8bit=False, load_in_8bit=False,
load_in_4bit=False, load_in_4bit=False,
flash_attention=False, flash_attention=False,
deepspeed=None,
fsdp=None,
**kwargs, **kwargs,
) )

View File

@@ -30,7 +30,7 @@ from transformers import (
) )
from transformers.trainer_utils import seed_worker from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer from trl import DPOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer from axolotl.loraplus import create_loraplus_optimizer
@@ -1526,9 +1526,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl == "orpo": if self.cfg.rl == "orpo":
training_args_cls = ORPOConfig training_args_cls = ORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
training_args_cls = DPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args = training_args_cls( training_args = training_args_cls(
per_device_train_batch_size=self.cfg.micro_batch_size, per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1555,8 +1552,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
elif self.cfg.rl == "kto_pair": elif self.cfg.rl == "kto_pair":
dpo_trainer_kwargs["loss_type"] = "kto_pair" dpo_trainer_kwargs["loss_type"] = "kto_pair"
elif self.cfg.rl == "sppo_hard":
dpo_trainer_kwargs["loss_type"] = "sppo_hard"
if self.eval_dataset: if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config: if self.cfg.adapter and self.peft_config:
@@ -1565,7 +1560,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[ dpo_trainer_kwargs[
"precompute_ref_log_probs" "precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs ] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]: if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
trainer_cls = AxolotlDPOTrainer trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1 dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
trainer_cls_args = [self.model, self.model_ref] trainer_cls_args = [self.model, self.model_ref]

View File

@@ -1,30 +0,0 @@
"""
DPO strategies for mistral instruct
"""
def prompt_pairs(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build a row transform for prompt/chosen/rejected preference pairs.

    The returned callable rewrites ``sample`` in place and returns it: the
    prompt is wrapped in ``[INST]...[/INST]`` with no padding spaces, while
    ``chosen`` and ``rejected`` are only rendered to text. ``cfg`` is not read.
    """

    def transform_fn(sample):
        sample["prompt"] = "[INST]" + format(sample["prompt"]) + "[/INST]"
        # the completions get no extra markers here, only string formatting
        for column in ("chosen", "rejected"):
            sample[column] = format(sample[column])
        return sample

    return transform_fn
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/dpo-mix-7k conversations

    The returned callable rewrites ``sample`` in place and returns it. The
    prompt is taken from the first message of ``chosen``; the second message of
    ``chosen`` and of ``rejected`` become the completions, each with ``</s>``
    appended. ``cfg`` and ``kwargs`` are not read.
    """

    def transform_fn(sample):
        # hold on to the message list: "chosen" is overwritten further down
        chosen_messages = sample["chosen"]
        user_text = chosen_messages[0]["content"]
        sample["prompt"] = "[INST] " + format(user_text) + " [/INST]"

        chosen_text = chosen_messages[1]["content"]
        sample["chosen"] = format(chosen_text) + "</s>"

        rejected_text = sample["rejected"][1]["content"]
        sample["rejected"] = format(rejected_text) + "</s>"
        return sample

    return transform_fn

View File

@@ -1,7 +1,7 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Type
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -39,76 +39,40 @@ def register_chatml_template(system_message=None):
) )
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def build_loader(
conversation = ( tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None prompter_cls: Type["ShareGPTPrompterV2"],
) default_conversation: Optional[str] = None,
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None ):
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None conversation = (
strategy = SimpleShareGPTPromptTokenizingStrategy( ds_cfg["conversation"]
ShareGPTPrompterV2( if ds_cfg and "conversation" in ds_cfg
conversation=conversation, else default_conversation
role_key_model=field_model, )
role_key_human=field_human, field_human = (
roles=roles, ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
), )
tokenizer, field_model = (
cfg.train_on_inputs, ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
cfg.sequence_len, )
) roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
if ds_cfg and "strict" in ds_cfg: strategy = tokenization_strategy_cls(
strategy.strict = ds_cfg["strict"] prompter_cls(
return strategy conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
strategy.strict = ds_cfg["strict"]
return strategy
return _load
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build an UltrachatShareGPTPromptTokenizingStrategy for a dataset.

    ``ds_cfg`` may supply ``conversation`` (passed to the prompter, ``None``
    when absent) and ``strict`` (copied onto the strategy when present).
    """
    conversation = None
    if ds_cfg and "conversation" in ds_cfg:
        conversation = ds_cfg["conversation"]

    prompter = ShareGPTPrompterV2(conversation=conversation)
    strategy = UltrachatShareGPTPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )

    # leave the strategy's own strict setting alone unless the dataset sets one
    if not ds_cfg or "strict" not in ds_cfg:
        return strategy
    strategy.strict = ds_cfg["strict"]
    return strategy
def load_role(tokenizer, cfg):
    """
    Build a SimpleRoleShareGPTPromptTokenizingStrategy with a default prompter.

    Takes ``train_on_inputs`` and ``sequence_len`` from ``cfg``.
    """
    prompter = ShareGPTPrompterV2()
    strategy = SimpleRoleShareGPTPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    return strategy
def load_guanaco(tokenizer, cfg):
    """
    Build a GuanacoShareGPTPromptTokenizingStrategy with a default prompter.

    Takes ``train_on_inputs`` and ``sequence_len`` from ``cfg``.
    """
    prompter = ShareGPTPrompterV2()
    strategy = GuanacoShareGPTPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    return strategy
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build a GlaiveShareGPTPromptTokenizingStrategy for a dataset.

    The conversation template is ``ds_cfg["conversation"]`` when that key is
    present, otherwise ``"chatml_glaive"``.
    """
    conversation = "chatml_glaive"
    if ds_cfg and "conversation" in ds_cfg:
        conversation = ds_cfg["conversation"]

    prompter = ShareGPTPrompterV2(conversation=conversation)
    return GlaiveShareGPTPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -158,7 +122,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
return turns return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleRoleShareGPTPromptTokenizingStrategy(
SimpleShareGPTPromptTokenizingStrategy
):
""" """
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
""" """
@@ -209,3 +175,16 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
conversation = merge_consecutive_messages(conversation) conversation = merge_consecutive_messages(conversation)
return conversation return conversation
# Module-level loader entry points. Each name is bound to the callable that
# build_loader returns for the given tokenizing-strategy / prompter class pair.
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_ultrachat = build_loader(
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
)
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
# glaive is the only loader that passes a default conversation to build_loader
load_glaive = build_loader(
GlaiveShareGPTPromptTokenizingStrategy,
ShareGPTPrompterV2,
default_conversation="chatml_glaive",
)

View File

@@ -348,7 +348,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
) )
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") if (
role != "assistant"
): # back to back assistant calls may be okay for tool calls
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])

View File

@@ -212,6 +212,10 @@ def train(
if cfg.flash_optimum and BetterTransformer: if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id: if not cfg.hub_model_id:

View File

@@ -133,7 +133,6 @@ class RLType(str, Enum):
ipo = "ipo" # pylint: disable=invalid-name ipo = "ipo" # pylint: disable=invalid-name
kto_pair = "kto_pair" # pylint: disable=invalid-name kto_pair = "kto_pair" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name
sppo_hard = "sppo_hard" # pylint: disable=invalid-name
class ChatTemplate(str, Enum): class ChatTemplate(str, Enum):
@@ -575,7 +574,6 @@ class AxolotlInputConfig(
neftune_noise_alpha: Optional[float] = None neftune_noise_alpha: Optional[float] = None
orpo_alpha: Optional[float] = None orpo_alpha: Optional[float] = None
dpo_beta: Optional[float] = None
max_memory: Optional[ max_memory: Optional[
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]] Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]

View File

@@ -1,4 +1,5 @@
"""Module for models and model loading""" """Module for models and model loading"""
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
import logging import logging
@@ -504,6 +505,9 @@ def load_model(
bnb_config = { bnb_config = {
"load_in_8bit": True, "load_in_8bit": True,
} }
# Exclude mamba blocks from int8 quantization for jamba
if cfg.model_config_type == "jamba":
bnb_config["llm_int8_skip_modules"] = ["mamba"]
model_kwargs["quantization_config"] = BitsAndBytesConfig( model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config, **bnb_config,
) )
@@ -789,11 +793,7 @@ def load_model(
if not reference_model or cfg.lora_model_dir: if not reference_model or cfg.lora_model_dir:
# if we're not loading the reference model, then we're loading the model for training # if we're not loading the reference model, then we're loading the model for training
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
if ( if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
cfg.adapter
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]
and not cfg.merge_lora
):
_, lora_config = load_lora(model, cfg, inference=False, config_only=True) _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
else: else:
model, lora_config = load_adapter(model, cfg, cfg.adapter) model, lora_config = load_adapter(model, cfg, cfg.adapter)

View File

@@ -197,6 +197,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask") eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type == "olmo":
LOG.info("dropping position_ids column")
train_dataset = train_dataset.remove_columns("position_ids")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("position_ids")
if cfg.model_config_type == "falcon": if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists") LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names: if "token_type_ids" in train_dataset.column_names:
@@ -438,7 +444,7 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard"]: if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1] trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2] trainer_builder.peft_config = model[2]