Compare commits

...

5 Commits

Author SHA1 Message Date
Wing Lian
7c5aa4791f drop position_ids for olmo model 2024-05-09 00:25:15 -04:00
Wing Lian
796a085b2f make sure to save the lora adapter at the end of RL/dpo training (#1573) 2024-05-08 10:39:33 -04:00
Wing Lian
cb78a36374 improve tool handling roles (#1587) 2024-05-07 11:30:40 -04:00
NanoCode012
8b9c15b17f feat: exclude mamba blocks for jamba (#1578) 2024-05-07 22:52:57 +09:00
Chirag Jain
9e1480e9ca Pass deepspeed and fsdp as None explicitly when merging adapters to allow custom device_map (#1575) 2024-05-07 22:47:55 +09:00
6 changed files with 70 additions and 72 deletions

View File

@@ -25,6 +25,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
load_in_8bit=False, load_in_8bit=False,
load_in_4bit=False, load_in_4bit=False,
flash_attention=False, flash_attention=False,
deepspeed=None,
fsdp=None,
**kwargs, **kwargs,
) )

View File

@@ -1,7 +1,7 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Type
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -39,76 +39,40 @@ def register_chatml_template(system_message=None):
) )
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def build_loader(
conversation = ( tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None prompter_cls: Type["ShareGPTPrompterV2"],
) default_conversation: Optional[str] = None,
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None ):
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None conversation = (
strategy = SimpleShareGPTPromptTokenizingStrategy( ds_cfg["conversation"]
ShareGPTPrompterV2( if ds_cfg and "conversation" in ds_cfg
conversation=conversation, else default_conversation
role_key_model=field_model, )
role_key_human=field_human, field_human = (
roles=roles, ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
), )
tokenizer, field_model = (
cfg.train_on_inputs, ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
cfg.sequence_len, )
) roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
if ds_cfg and "strict" in ds_cfg: strategy = tokenization_strategy_cls(
strategy.strict = ds_cfg["strict"] prompter_cls(
return strategy conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
strategy.strict = ds_cfg["strict"]
return strategy
return _load
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = UltrachatShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else "chatml_glaive"
)
return GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation=conversation),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -158,7 +122,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
return turns return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleRoleShareGPTPromptTokenizingStrategy(
SimpleShareGPTPromptTokenizingStrategy
):
""" """
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
""" """
@@ -209,3 +175,16 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
conversation = merge_consecutive_messages(conversation) conversation = merge_consecutive_messages(conversation)
return conversation return conversation
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_ultrachat = build_loader(
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
)
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_glaive = build_loader(
GlaiveShareGPTPromptTokenizingStrategy,
ShareGPTPrompterV2,
default_conversation="chatml_glaive",
)

View File

@@ -348,7 +348,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
) )
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") if (
role != "assistant"
): # back to back assistant calls may be okay for tool calls
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])

View File

@@ -212,6 +212,10 @@ def train(
if cfg.flash_optimum and BetterTransformer: if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id: if not cfg.hub_model_id:

View File

@@ -1,4 +1,5 @@
"""Module for models and model loading""" """Module for models and model loading"""
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
import logging import logging
@@ -504,6 +505,9 @@ def load_model(
bnb_config = { bnb_config = {
"load_in_8bit": True, "load_in_8bit": True,
} }
# Exclude mamba blocks from int8 quantization for jamba
if cfg.model_config_type == "jamba":
bnb_config["llm_int8_skip_modules"] = ["mamba"]
model_kwargs["quantization_config"] = BitsAndBytesConfig( model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config, **bnb_config,
) )

View File

@@ -197,6 +197,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask") eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type == "olmo":
LOG.info("dropping position_ids column")
train_dataset = train_dataset.remove_columns("position_ids")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("position_ids")
if cfg.model_config_type == "falcon": if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists") LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names: if "token_type_ids" in train_dataset.column_names: