diff --git a/examples/pixtral/lora-12b.yaml b/examples/pixtral/lora-12b.yaml
new file mode 100644
index 000000000..88fb02be5
--- /dev/null
+++ b/examples/pixtral/lora-12b.yaml
@@ -0,0 +1,63 @@
+base_model: mistral-community/pixtral-12b
+processor_type: AutoProcessor
+strict: false
+
+# these 3 lines are needed for now to handle vision chat templates with images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+chat_template: llama3_2_vision
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+    field_messages: messages
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 8192
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'language_model\.model\.layers\.\d+\.(mlp|cross_attn|self_attn)\.(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+local_rank:
+logging_steps: 1
+flash_attention: true
+eager_attention:
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c18af9760..b6e2a3e1d 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1114,7 +1114,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
         fan_in_fan_out=cfg.lora_fan_in_fan_out,
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
         bias="none",
-        task_type="CAUSAL_LM",
+        task_type="CONDITIONAL_GENERATION" if cfg.is_multimodal else "CAUSAL_LM",
**lora_config_kwargs, )