support for gemma2 w sample packing (#1718)

This commit is contained in:
Wing Lian
2024-06-29 01:38:55 -04:00
committed by GitHub
parent f2480a1d91
commit 5370cedf0c
9 changed files with 97 additions and 5 deletions

68
examples/gemma2/qlora.yml Normal file
View File

@@ -0,0 +1,68 @@
base_model: google/gemma-2-9b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
chat_template: gemma
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: gemma
drop_system_message: true
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.11.1
transformers==4.41.1
transformers==4.42.3
tokenizers==0.19.1
bitsandbytes==0.43.1
accelerate==0.30.1

View File

@@ -1091,6 +1091,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_steps = min(int(0.03 * total_num_steps), 100)
if warmup_steps == 1:
warmup_steps = 2
logging_steps = (
self.cfg.logging_steps

View File

@@ -112,7 +112,7 @@ def replace_llama_attn_with_flash_attn(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
LOG.warning(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
@@ -130,7 +130,7 @@ def replace_llama_attn_with_flash_attn(
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
LOG.warning(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
@@ -826,7 +826,6 @@ def llama_model_forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

View File

@@ -145,7 +145,7 @@ def flashattn_forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -422,6 +422,9 @@ def mistral_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions

View File

@@ -16,6 +16,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"falcon",
"phi",
"gemma",
"gemma2",
"gemmoe",
"starcoder2",
"deepseek_v2",
@@ -49,6 +50,10 @@ def patch_for_multipack(model_type, model_name=None):
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma2":
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data

View File

@@ -23,6 +23,7 @@ class ChatTemplatePrompter(Prompter):
message_field_role: str = "from",
message_field_content: str = "value",
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
if roles:
self.roles = {s: t for t, sources in roles.items() for s in sources}
@@ -39,6 +40,7 @@ class ChatTemplatePrompter(Prompter):
self.tokenizer = tokenizer
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
def build_prompt(self, conversation, add_generation_prompt=False):
turns = [
@@ -49,6 +51,9 @@ class ChatTemplatePrompter(Prompter):
for t in conversation
]
if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
return self.tokenizer.apply_chat_template(
turns,
truncation=True,
@@ -111,6 +116,11 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
else "value"
)
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
drop_system_message = (
ds_cfg["drop_system_message"]
if ds_cfg and "drop_system_message" in ds_cfg
else False
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
@@ -119,6 +129,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
message_field_role=message_field_role,
message_field_content=message_field_content,
roles=roles,
drop_system_message=drop_system_message,
),
tokenizer,
cfg.train_on_inputs,

View File

@@ -116,6 +116,7 @@ class SFTDataset(BaseModel):
message_field_content: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
class UserDefinedDPOType(BaseModel):

View File

@@ -7,6 +7,8 @@ import os
import unittest
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip(reason="FIXME?")
class TestLlamaShiftedSparseAttention(unittest.TestCase):
"""
Test case for Llama models using S2 Attn