From 83d4d97dccd4acf98a77f7c4c32d4a6f32a1a064 Mon Sep 17 00:00:00 2001 From: salman Date: Wed, 17 Dec 2025 15:35:22 +0100 Subject: [PATCH] Add QAT NVFP4 configs for blogpost (#3280) [skip ci] * add configs for blogpost * fix configs * fixing baseline configs --- examples/qat_nvfp4/Gemma3-12B_baseline.yml | 67 +++++++++++++++++ examples/qat_nvfp4/Gemma3-12B_qat.yml | 72 ++++++++++++++++++ .../qat_nvfp4/Math-Gemma3-12B_baseline.yml | 67 +++++++++++++++++ examples/qat_nvfp4/Math-Gemma3-12B_qat.yml | 72 ++++++++++++++++++ .../qat_nvfp4/Math-Gemma3-27B_baseline.yml | 68 +++++++++++++++++ examples/qat_nvfp4/Math-Gemma3-27B_qat.yml | 73 +++++++++++++++++++ .../qat_nvfp4/Math-Qwen2.5-72B_baseline.yml | 67 +++++++++++++++++ examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml | 72 ++++++++++++++++++ examples/qat_nvfp4/Qwen2.5-72B_baseline.yml | 67 +++++++++++++++++ examples/qat_nvfp4/Qwen2.5-72B_qat.yml | 72 ++++++++++++++++++ 10 files changed, 697 insertions(+) create mode 100644 examples/qat_nvfp4/Gemma3-12B_baseline.yml create mode 100644 examples/qat_nvfp4/Gemma3-12B_qat.yml create mode 100644 examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml create mode 100644 examples/qat_nvfp4/Math-Gemma3-12B_qat.yml create mode 100644 examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml create mode 100644 examples/qat_nvfp4/Math-Gemma3-27B_qat.yml create mode 100644 examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml create mode 100644 examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml create mode 100644 examples/qat_nvfp4/Qwen2.5-72B_baseline.yml create mode 100644 examples/qat_nvfp4/Qwen2.5-72B_qat.yml diff --git a/examples/qat_nvfp4/Gemma3-12B_baseline.yml b/examples/qat_nvfp4/Gemma3-12B_baseline.yml new file mode 100644 index 000000000..be4e86635 --- /dev/null +++ b/examples/qat_nvfp4/Gemma3-12B_baseline.yml @@ -0,0 +1,67 @@ +base_model: google/gemma-3-12b-it +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: gemma3 +datasets: + - path: tatsu-lab/alpaca + type: alpaca + +output_dir: ./outputs/out_gemma/ + +sequence_len: 8096 +sample_packing: true +flash_attention: true + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 4e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Gemma3-12B_qat.yml b/examples/qat_nvfp4/Gemma3-12B_qat.yml new file mode 100644 index 000000000..7fa81163f --- /dev/null +++ b/examples/qat_nvfp4/Gemma3-12B_qat.yml @@ -0,0 +1,72 @@ +base_model: google/gemma-3-12b-it +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: 
false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: gemma3 +datasets: + - path: tatsu-lab/alpaca + type: alpaca + +output_dir: ./outputs/qat_out_gemma/ + +sequence_len: 8096 +sample_packing: true +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 4e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml b/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml new file mode 100644 index 000000000..9f209515b --- /dev/null +++ b/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml @@ -0,0 +1,67 @@ +base_model: google/gemma-3-12b-it +# Math finetuning configuration for Gemma3-12B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: gemma3 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/out_math_gemma/ + +sequence_len: 4096 +sample_packing: true +flash_attention: true + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 8 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 3e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml b/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml new file mode 100644 index 000000000..ef7e754be --- /dev/null +++ b/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml @@ -0,0 +1,72 @@ +base_model: google/gemma-3-12b-it +# Math finetuning configuration for Gemma3-12B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: 
true +seed: 42 +chat_template: gemma3 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/qat_out_math_gemma/ + +sequence_len: 4096 +sample_packing: true +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 8 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 3e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml b/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml new file mode 100644 index 000000000..3a262d342 --- /dev/null +++ b/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml @@ -0,0 +1,68 @@ +base_model: google/gemma-3-27b-it +# Math finetuning configuration for Gemma3-27B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: gemma3 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/out_math_gemma27/ + +sequence_len: 4096 +sample_packing: true +flash_attention: true + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-6 +eta_min: 7e-7 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml b/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml new file mode 100644 index 000000000..87016ae9c --- /dev/null +++ b/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml @@ -0,0 +1,73 @@ +base_model: google/gemma-3-27b-it +# Math finetuning configuration for Gemma3-27B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: gemma3 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/qat_out_math_gemma27/ + +sequence_len: 4096 +sample_packing: 
true +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-6 +eta_min: 7e-7 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml b/examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml new file mode 100644 index 000000000..efec25c54 --- /dev/null +++ b/examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml @@ -0,0 +1,67 @@ +base_model: Qwen/Qwen2.5-72B +# Math finetuning configuration for Qwen2.5-72B (non-instruct) +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: qwen_25 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/out_math_72b/ + +sequence_len: 4096 +sample_packing: true +flash_attention: true + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 8 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-6 +eta_min: 7e-7 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Qwen2DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml b/examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml new file mode 100644 index 000000000..427d7af52 --- /dev/null +++ b/examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml @@ -0,0 +1,72 @@ +base_model: Qwen/Qwen2.5-72B +# Math finetuning configuration for Qwen2.5-72B (non-instruct) +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: qwen_25 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/qat_out_math_72b/ + +sequence_len: 4096 +sample_packing: true +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + 
+wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 8 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-6 +eta_min: 7e-7 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Qwen2DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Qwen2.5-72B_baseline.yml b/examples/qat_nvfp4/Qwen2.5-72B_baseline.yml new file mode 100644 index 000000000..e1eaba61f --- /dev/null +++ b/examples/qat_nvfp4/Qwen2.5-72B_baseline.yml @@ -0,0 +1,67 @@ +base_model: Qwen/Qwen2.5-72B +# Alpaca finetuning configuration for Qwen2.5-72B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: qwen_25 +datasets: + - path: tatsu-lab/alpaca + type: alpaca + +output_dir: ./outputs/out_qwen72b/ + +sequence_len: 8096 +sample_packing: true +flash_attention: true + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Qwen2DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Qwen2.5-72B_qat.yml b/examples/qat_nvfp4/Qwen2.5-72B_qat.yml new file mode 100644 index 000000000..dad7e5422 --- /dev/null +++ b/examples/qat_nvfp4/Qwen2.5-72B_qat.yml @@ -0,0 +1,72 @@ +base_model: Qwen/Qwen2.5-72B +# Alpaca finetuning configuration for Qwen2.5-72B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: qwen_25 +datasets: + - path: tatsu-lab/alpaca + type: alpaca + +output_dir: ./outputs/qat_out_qwen72b/ + +sequence_len: 8096 +sample_packing: true +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: true +tf32: true + 
+resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Qwen2DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config
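
Usage note: each *_qat.yml above differs from its *_baseline.yml counterpart mainly in output_dir and in the added qat block, reproduced here from the configs for reference (only a group_size of 16 is supported with nvfp4):

qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16  # only group_size of 16 is supported with nvfp4

A training run would typically be launched along these lines, assuming a recent Axolotl install with torchao QAT and FSDP2 support (the exact CLI entry point varies by version, and the 27B/72B configs assume a multi-GPU setup):

accelerate launch -m axolotl.cli.train examples/qat_nvfp4/Gemma3-12B_qat.yml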