diff --git a/examples/starcoder2/qlora.yml b/examples/starcoder2/qlora.yml new file mode 100644 index 000000000..1efdfbc8e --- /dev/null +++ b/examples/starcoder2/qlora.yml @@ -0,0 +1,69 @@ +base_model: bigcode/starcoder2-3b + +load_in_8bit: false +load_in_4bit: true +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca + + +dataset_prepared_path: +val_set_size: 0.2 +output_dir: ./qlora + +adapter: qlora +lora_model_dir: + +sequence_len: 8192 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 8 +micro_batch_size: 2 +num_epochs: 3 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 2e-5 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 20 +evals_per_epoch: 4 +eval_steps: +eval_table_size: +saves_per_epoch: 4 +save_steps: +save_total_limit: 2 +debug: +deepspeed: +weight_decay: +fsdp: +fsdp_config: +special_tokens: diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 65a79a878..964b41f70 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -6,7 +6,14 @@ from transformers.integrations import is_deepspeed_zero3_enabled from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 from axolotl.monkeypatch.utils import get_unpad_data -SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"] +SUPPORTED_MULTIPACK_MODEL_TYPES = [ + "mixtral", + "qwen2", + "falcon", + "phi", + "gemma", + "starcoder2", +] def patch_for_multipack(model_type): @@ -32,3 +39,7 @@ def patch_for_multipack(model_type): transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data ) + elif model_type == "starcoder2": + transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + )