support for gemma2 w sample packing (#1718)
@@ -1091,6 +1091,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
         else:
             warmup_steps = min(int(0.03 * total_num_steps), 100)
+        if warmup_steps == 1:
+            warmup_steps = 2
 
         logging_steps = (
             self.cfg.logging_steps

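The two added lines bump a one-step warmup to two steps, presumably because a single warmup step can misbehave with some LR schedulers. A minimal sketch of the resulting logic as a standalone function (the function name and arguments are illustrative, not the project's API):

def compute_warmup_steps(total_num_steps, warmup_steps=None, warmup_ratio=None):
    # Explicit steps win, then ratio, then a 3% default capped at 100.
    if warmup_steps is not None:
        steps = warmup_steps
    elif warmup_ratio is not None:
        steps = max(int(warmup_ratio * total_num_steps), 0)
    else:
        steps = min(int(0.03 * total_num_steps), 100)
    if steps == 1:
        # Mirror of the hunk above: never hand the scheduler exactly one warmup step.
        steps = 2
    return steps
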
@@ -112,7 +112,7 @@ def replace_llama_attn_with_flash_attn(
             CrossEntropyLoss, inplace_backward=True
         )
     except ImportError:
-        LOG.info(
+        LOG.warning(
             "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
         )
 

@@ -130,7 +130,7 @@ def replace_llama_attn_with_flash_attn(
         LOG.info("patching with flash_attn.ops.rms_norm")
         transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
     except ImportError:
-        LOG.info(
+        LOG.warning(
             "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
         )
 

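Both hunks above make the same change: the fallback path of an optional fused-kernel import now logs at warning rather than info, so the degraded path is visible at default log levels. The underlying try/except pattern, sketched with assumed surrounding names:

import logging

LOG = logging.getLogger(__name__)

try:
    # Optional fused kernel from the flash-attention extras; only importable
    # when the extra package has been installed.
    from flash_attn.losses.cross_entropy import CrossEntropyLoss  # noqa: F401
except ImportError:
    LOG.warning("optimized flash-attention CrossEntropyLoss not found; falling back to the stock loss")
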
@@ -826,7 +826,6 @@ def llama_model_forward(
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
-                padding_mask=padding_mask,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
             )

@@ -145,7 +145,7 @@ def flashattn_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
     )

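This tracks the rotary-embedding API change in newer transformers, where the embedding module is called with position_ids rather than seq_len. For reference, a simplified standalone version of the RoPE cos/sin computation (illustrative math, not the library code):

import torch

def rope_cos_sin(position_ids, head_dim, base=10000.0):
    # inv_freq[i] = base^(-2i/head_dim); angle = position * inv_freq
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = position_ids[..., None].float() * inv_freq  # (..., seq_len, head_dim/2)
    emb = torch.cat((angles, angles), dim=-1)
    return emb.cos(), emb.sin()
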
@@ -422,6 +422,9 @@ def mistral_model_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    cache_position: Optional[  # pylint: disable=unused-argument
+        torch.LongTensor
+    ] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = (
         output_attentions

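Like the padding_mask removal earlier, this keeps the patched forward call-compatible with the targeted transformers release: newer versions pass cache_position, so the signature accepts it and simply ignores it. The general shape of that compatibility trick, sketched:

from typing import Optional

import torch

def patched_forward(hidden_states, cache_position: Optional[torch.LongTensor] = None, **kwargs):
    # cache_position is accepted only so newer callers don't raise a TypeError;
    # this patched implementation does not use it.
    del cache_position
    return hidden_states
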
@@ -16,6 +16,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "falcon",
     "phi",
     "gemma",
+    "gemma2",
     "gemmoe",
     "starcoder2",
     "deepseek_v2",

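Adding "gemma2" here is what lets sample packing engage for the new architecture; callers presumably gate the patch on list membership, roughly:

model_type = "gemma2"  # e.g. read from model.config.model_type

if model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
    patch_for_multipack(model_type)
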
@@ -49,6 +50,10 @@ def patch_for_multipack(model_type, model_name=None):
         transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
+    elif model_type == "gemma2":
+        transformers.models.gemma2.modeling_gemma2._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
     elif model_type == "starcoder2":
         transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data

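The swapped-in get_unpad_data is the heart of multipack support: the packed attention mask carries a distinct positive id per sample (1, 2, 3, ...) instead of a plain 0/1 mask, so per-sample lengths can be recovered and handed to flash-attention as cumulative sequence lengths. A simplified sketch of what such a helper computes (the real implementation is vectorized; names assumed from the hunk above):

import torch
import torch.nn.functional as F

def get_unpad_data(attention_mask):
    # attention_mask: (batch, seq_len), 0 = padding, 1..N = packed sample id
    seqlens = []
    for row in attention_mask:
        for sample_id in range(1, int(row.max().item()) + 1):
            seqlens.append(int((row == sample_id).sum().item()))
    seqlens_in_batch = torch.tensor(seqlens, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten()).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(seqlens_in_batch.max().item())
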
@@ -23,6 +23,7 @@ class ChatTemplatePrompter(Prompter):
         message_field_role: str = "from",
         message_field_content: str = "value",
         roles: Optional[Dict[str, List[str]]] = None,
+        drop_system_message: bool = False,
     ):
         if roles:
             self.roles = {s: t for t, sources in roles.items() for s in sources}

@@ -39,6 +40,7 @@ class ChatTemplatePrompter(Prompter):
         self.tokenizer = tokenizer
         self.chat_template = chat_template
         self.max_length = max_length
+        self.drop_system_message = drop_system_message
 
     def build_prompt(self, conversation, add_generation_prompt=False):
         turns = [

@@ -49,6 +51,9 @@ class ChatTemplatePrompter(Prompter):
             for t in conversation
         ]
 
+        if self.drop_system_message and turns[0]["role"] == "system":
+            turns = turns[1:]
+
         return self.tokenizer.apply_chat_template(
             turns,
             truncation=True,

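drop_system_message strips a leading system turn before the chat template is applied, which matters for templates like Gemma's that reject a system role outright. The added logic in isolation, on made-up data:

turns = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi!"},
]

drop_system_message = True
if drop_system_message and turns[0]["role"] == "system":
    turns = turns[1:]  # leaves only the user turn
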
@@ -111,6 +116,11 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         else "value"
     )
     roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
+    drop_system_message = (
+        ds_cfg["drop_system_message"]
+        if ds_cfg and "drop_system_message" in ds_cfg
+        else False
+    )
 
     strategy = ChatTemplateStrategy(
         ChatTemplatePrompter(

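Both new options are read defensively from the per-dataset config. A hypothetical ds_cfg exercising them (the dataset path and field values are made up):

ds_cfg = {
    "path": "some/dataset",
    "type": "chat_template",
    "roles": {"user": ["human"], "assistant": ["gpt"]},
    "drop_system_message": True,
}

roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
drop_system_message = (
    ds_cfg["drop_system_message"]
    if ds_cfg and "drop_system_message" in ds_cfg
    else False
)
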
@@ -119,6 +129,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
             message_field_role=message_field_role,
             message_field_content=message_field_content,
             roles=roles,
+            drop_system_message=drop_system_message,
         ),
         tokenizer,
         cfg.train_on_inputs,

@@ -116,6 +116,7 @@ class SFTDataset(BaseModel):
     message_field_content: Optional[str] = None
 
     roles: Optional[Dict[str, List[str]]] = None
+    drop_system_message: Optional[bool] = None
 
 
 class UserDefinedDPOType(BaseModel):