Add QAT NVFP4 configs for blogpost (#3280) [skip ci]
* add configs for blogpost * fix configs * fixing baseline configs
This commit is contained in:
67
examples/qat_nvfp4/Gemma3-12B_baseline.yml
Normal file
67
examples/qat_nvfp4/Gemma3-12B_baseline.yml
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# Baseline run (no `qat` section) for google/gemma-3-12b-it on tatsu-lab/alpaca.
# Counterpart of Gemma3-12B_qat.yml in this directory.
base_model: google/gemma-3-12b-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

# Liger kernel toggles (used together with the LigerPlugin entry above).
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca

output_dir: ./outputs/out_gemma/

# NOTE(review): 8096 is not a power of two; confirm 8192 was not intended.
sequence_len: 8096
sample_packing: true
flash_attention: true

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 4e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

# evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

# FSDP (version 2) settings; auto-wrap targets Gemma3DecoderLayer modules.
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
72
examples/qat_nvfp4/Gemma3-12B_qat.yml
Normal file
72
examples/qat_nvfp4/Gemma3-12B_qat.yml
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# NVFP4 quantization-aware training (QAT) run for google/gemma-3-12b-it on
# tatsu-lab/alpaca. Counterpart of Gemma3-12B_baseline.yml in this directory.
base_model: google/gemma-3-12b-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

# Liger kernel toggles (used together with the LigerPlugin entry above).
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca

output_dir: ./outputs/qat_out_gemma/

# NOTE(review): 8096 is not a power of two; confirm 8192 was not intended.
sequence_len: 8096
sample_packing: true
flash_attention: true

# This section is the only functional difference from the baseline config.
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # only group_size of 16 is supported with nvfp4

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 4e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

# Kept disabled to match the baseline config. This file sets no val_set_size
# or test_datasets, so presumably there is nothing to evaluate on (TODO confirm).
# evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

# FSDP (version 2) settings; auto-wrap targets Gemma3DecoderLayer modules.
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
67
examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml
Normal file
67
examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# Math finetuning configuration for Gemma3-12B.
# Baseline run (no `qat` section) on AI-MO/NuminaMath-CoT; counterpart of
# Math-Gemma3-12B_qat.yml in this directory.
base_model: google/gemma-3-12b-it
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

# Liger kernel toggles (used together with the LigerPlugin entry above).
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
  - path: AI-MO/NuminaMath-CoT
    type: chat_template

output_dir: ./outputs/out_math_gemma/

sequence_len: 4096
sample_packing: true
flash_attention: true

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 8

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 3e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

# evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

# FSDP (version 2) settings; auto-wrap targets Gemma3DecoderLayer modules.
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
72
examples/qat_nvfp4/Math-Gemma3-12B_qat.yml
Normal file
72
examples/qat_nvfp4/Math-Gemma3-12B_qat.yml
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# Math finetuning configuration for Gemma3-12B.
# NVFP4 quantization-aware training (QAT) run on AI-MO/NuminaMath-CoT;
# counterpart of Math-Gemma3-12B_baseline.yml in this directory.
base_model: google/gemma-3-12b-it
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

# Liger kernel toggles (used together with the LigerPlugin entry above).
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
  - path: AI-MO/NuminaMath-CoT
    type: chat_template

output_dir: ./outputs/qat_out_math_gemma/

sequence_len: 4096
sample_packing: true
flash_attention: true

# This section is the only functional difference from the baseline config.
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # only group_size of 16 is supported with nvfp4

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 8

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 3e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

# evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

# FSDP (version 2) settings; auto-wrap targets Gemma3DecoderLayer modules.
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
68
examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml
Normal file
68
examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# Math finetuning configuration for Gemma3-27B.
# Baseline run (no `qat` section) on AI-MO/NuminaMath-CoT; counterpart of
# Math-Gemma3-27B_qat.yml in this directory.
base_model: google/gemma-3-27b-it
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

# Liger kernel toggles (used together with the LigerPlugin entry above).
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
  - path: AI-MO/NuminaMath-CoT
    type: chat_template

output_dir: ./outputs/out_math_gemma27/

sequence_len: 4096
sample_packing: true
flash_attention: true

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
# NOTE(review): confirm that a top-level `eta_min` is honored by the cosine
# scheduler; with `strict: false` an unrecognized key is presumably ignored
# without an error.
eta_min: 7e-7

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

# evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

# FSDP (version 2) settings; auto-wrap targets Gemma3DecoderLayer modules.
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
73
examples/qat_nvfp4/Math-Gemma3-27B_qat.yml
Normal file
73
examples/qat_nvfp4/Math-Gemma3-27B_qat.yml
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# Math finetuning configuration for Gemma3-27B.
# NVFP4 quantization-aware training (QAT) run on AI-MO/NuminaMath-CoT;
# counterpart of Math-Gemma3-27B_baseline.yml in this directory.
base_model: google/gemma-3-27b-it
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

# Liger kernel toggles (used together with the LigerPlugin entry above).
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
  - path: AI-MO/NuminaMath-CoT
    type: chat_template

output_dir: ./outputs/qat_out_math_gemma27/

sequence_len: 4096
sample_packing: true
flash_attention: true

# This section is the only functional difference from the baseline config.
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # only group_size of 16 is supported with nvfp4

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
# NOTE(review): confirm that a top-level `eta_min` is honored by the cosine
# scheduler; with `strict: false` an unrecognized key is presumably ignored
# without an error.
eta_min: 7e-7

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

# evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

# FSDP (version 2) settings; auto-wrap targets Gemma3DecoderLayer modules.
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
67
examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml
Normal file
67
examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# Math finetuning configuration for Qwen2.5-72B (non-instruct).
# Baseline run (no `qat` section) on AI-MO/NuminaMath-CoT; counterpart of
# Math-Qwen2.5-72B_qat.yml in this directory.
base_model: Qwen/Qwen2.5-72B
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

# Liger kernel toggles (used together with the LigerPlugin entry above).
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
  - path: AI-MO/NuminaMath-CoT
    type: chat_template

output_dir: ./outputs/out_math_72b/

sequence_len: 4096
sample_packing: true
flash_attention: true

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
# NOTE(review): confirm that a top-level `eta_min` is honored by the cosine
# scheduler; with `strict: false` an unrecognized key is presumably ignored
# without an error.
eta_min: 7e-7

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

# evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

# FSDP (version 2) settings; auto-wrap targets Qwen2DecoderLayer modules.
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
72
examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml
Normal file
72
examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# Math finetuning configuration for Qwen2.5-72B (non-instruct).
# NVFP4 quantization-aware training (QAT) run on AI-MO/NuminaMath-CoT;
# counterpart of Math-Qwen2.5-72B_baseline.yml in this directory.
base_model: Qwen/Qwen2.5-72B
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

# Liger kernel toggles (used together with the LigerPlugin entry above).
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
  - path: AI-MO/NuminaMath-CoT
    type: chat_template

output_dir: ./outputs/qat_out_math_72b/

sequence_len: 4096
sample_packing: true
flash_attention: true

# This section is the only functional difference from the baseline config.
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # only group_size of 16 is supported with nvfp4

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
# NOTE(review): confirm that a top-level `eta_min` is honored by the cosine
# scheduler; with `strict: false` an unrecognized key is presumably ignored
# without an error.
eta_min: 7e-7

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

# evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

# FSDP (version 2) settings; auto-wrap targets Qwen2DecoderLayer modules.
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
67
examples/qat_nvfp4/Qwen2.5-72B_baseline.yml
Normal file
67
examples/qat_nvfp4/Qwen2.5-72B_baseline.yml
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# Alpaca finetuning configuration for Qwen2.5-72B.
# Baseline run (no `qat` section) on tatsu-lab/alpaca; counterpart of
# Qwen2.5-72B_qat.yml in this directory.
base_model: Qwen/Qwen2.5-72B
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

# Liger kernel toggles (used together with the LigerPlugin entry above).
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca

output_dir: ./outputs/out_qwen72b/

# NOTE(review): 8096 is not a power of two; confirm 8192 was not intended.
sequence_len: 8096
sample_packing: true
flash_attention: true

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

# evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

# FSDP (version 2) settings; auto-wrap targets Qwen2DecoderLayer modules.
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
72
examples/qat_nvfp4/Qwen2.5-72B_qat.yml
Normal file
72
examples/qat_nvfp4/Qwen2.5-72B_qat.yml
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# Alpaca finetuning configuration for Qwen2.5-72B.
# NVFP4 quantization-aware training (QAT) run on tatsu-lab/alpaca;
# counterpart of Qwen2.5-72B_baseline.yml in this directory.
base_model: Qwen/Qwen2.5-72B
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

# Liger kernel toggles (used together with the LigerPlugin entry above).
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca

output_dir: ./outputs/qat_out_qwen72b/

# NOTE(review): 8096 is not a power of two; confirm 8192 was not intended.
sequence_len: 8096
sample_packing: true
flash_attention: true

# This section is the only functional difference from the baseline config.
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # only group_size of 16 is supported with nvfp4

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

# evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

# FSDP (version 2) settings; auto-wrap targets Qwen2DecoderLayer modules.
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
Reference in New Issue
Block a user