multipack for gemma (#1313)
* multipack for gemma * chore: lint * handle cache_position kwarg in updated llama modeling * add position_ids to rotary embed call for updated llama modeling
This commit is contained in:
@@ -1,49 +1,49 @@
|
||||
# use google/gemma-7b if you have access
|
||||
base_model: mhenrichsen/gemma-7b
|
||||
base_model: mhenrichsen/gemma-7b
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
|
||||
# huggingface repo
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
val_set_size: 0.1
|
||||
output_dir: ./out
|
||||
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
|
||||
|
||||
gradient_accumulation_steps: 3
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
@@ -51,7 +51,7 @@ local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
|
||||
tokenizers==0.15.0
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate==0.26.1
|
||||
|
||||
@@ -275,7 +275,9 @@ def flashattn_forward_with_s2attn(
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
cos, sin = self.rotary_emb(
|
||||
value_states, seq_len=kv_seq_len, position_ids=position_ids
|
||||
)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
@@ -425,7 +427,9 @@ def flashattn_forward(
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
cos, sin = self.rotary_emb(
|
||||
value_states, seq_len=kv_seq_len, position_ids=position_ids
|
||||
)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
@@ -688,6 +692,9 @@ def llama_model_forward(
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[ # pylint: disable=unused-argument
|
||||
torch.LongTensor
|
||||
] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
|
||||
@@ -6,7 +6,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type):
|
||||
@@ -28,3 +28,7 @@ def patch_for_multipack(model_type):
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemma":
|
||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user