support for gemma2 w sample packing (#1718)
This commit is contained in:
68
examples/gemma2/qlora.yml
Normal file
68
examples/gemma2/qlora.yml
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
base_model: google/gemma-2-9b
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: gemma
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
chat_template: gemma
|
||||||
|
drop_system_message: true
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch:
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.11.1
|
peft==0.11.1
|
||||||
transformers==4.41.1
|
transformers==4.42.3
|
||||||
tokenizers==0.19.1
|
tokenizers==0.19.1
|
||||||
bitsandbytes==0.43.1
|
bitsandbytes==0.43.1
|
||||||
accelerate==0.30.1
|
accelerate==0.30.1
|
||||||
|
|||||||
@@ -1091,6 +1091,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||||
else:
|
else:
|
||||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||||
|
if warmup_steps == 1:
|
||||||
|
warmup_steps = 2
|
||||||
|
|
||||||
logging_steps = (
|
logging_steps = (
|
||||||
self.cfg.logging_steps
|
self.cfg.logging_steps
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
CrossEntropyLoss, inplace_backward=True
|
CrossEntropyLoss, inplace_backward=True
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LOG.info(
|
LOG.warning(
|
||||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LOG.info(
|
LOG.warning(
|
||||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -826,7 +826,6 @@ def llama_model_forward(
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
padding_mask=padding_mask,
|
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ def flashattn_forward(
|
|||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
@@ -422,6 +422,9 @@ def mistral_model_forward(
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[ # pylint: disable=unused-argument
|
||||||
|
torch.LongTensor
|
||||||
|
] = None,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions
|
output_attentions
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"falcon",
|
"falcon",
|
||||||
"phi",
|
"phi",
|
||||||
"gemma",
|
"gemma",
|
||||||
|
"gemma2",
|
||||||
"gemmoe",
|
"gemmoe",
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
"deepseek_v2",
|
"deepseek_v2",
|
||||||
@@ -49,6 +50,10 @@ def patch_for_multipack(model_type, model_name=None):
|
|||||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
|
elif model_type == "gemma2":
|
||||||
|
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
elif model_type == "starcoder2":
|
elif model_type == "starcoder2":
|
||||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
message_field_role: str = "from",
|
message_field_role: str = "from",
|
||||||
message_field_content: str = "value",
|
message_field_content: str = "value",
|
||||||
roles: Optional[Dict[str, List[str]]] = None,
|
roles: Optional[Dict[str, List[str]]] = None,
|
||||||
|
drop_system_message: bool = False,
|
||||||
):
|
):
|
||||||
if roles:
|
if roles:
|
||||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||||
@@ -39,6 +40,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
self.drop_system_message = drop_system_message
|
||||||
|
|
||||||
def build_prompt(self, conversation, add_generation_prompt=False):
|
def build_prompt(self, conversation, add_generation_prompt=False):
|
||||||
turns = [
|
turns = [
|
||||||
@@ -49,6 +51,9 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
for t in conversation
|
for t in conversation
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if self.drop_system_message and turns[0]["role"] == "system":
|
||||||
|
turns = turns[1:]
|
||||||
|
|
||||||
return self.tokenizer.apply_chat_template(
|
return self.tokenizer.apply_chat_template(
|
||||||
turns,
|
turns,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
@@ -111,6 +116,11 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
else "value"
|
else "value"
|
||||||
)
|
)
|
||||||
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
||||||
|
drop_system_message = (
|
||||||
|
ds_cfg["drop_system_message"]
|
||||||
|
if ds_cfg and "drop_system_message" in ds_cfg
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
@@ -119,6 +129,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
message_field_role=message_field_role,
|
message_field_role=message_field_role,
|
||||||
message_field_content=message_field_content,
|
message_field_content=message_field_content,
|
||||||
roles=roles,
|
roles=roles,
|
||||||
|
drop_system_message=drop_system_message,
|
||||||
),
|
),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ class SFTDataset(BaseModel):
|
|||||||
message_field_content: Optional[str] = None
|
message_field_content: Optional[str] = None
|
||||||
|
|
||||||
roles: Optional[Dict[str, List[str]]] = None
|
roles: Optional[Dict[str, List[str]]] = None
|
||||||
|
drop_system_message: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedDPOType(BaseModel):
|
class UserDefinedDPOType(BaseModel):
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
|||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="FIXME?")
|
||||||
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using S2 Attn
|
Test case for Llama models using S2 Attn
|
||||||
|
|||||||
Reference in New Issue
Block a user