diff --git a/README.md b/README.md index f57765290..6a4e9f8f3 100644 --- a/README.md +++ b/README.md @@ -464,8 +464,8 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod ```yaml load_in_4bit: true load_in_8bit: true - bf16: true # require >=ampere - fp16: true + bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically. + fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32 tf32: true # require >=ampere bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision) float16: true # use instead of fp16 when you don't want AMP diff --git a/examples/cerebras/btlm-ft.yml b/examples/cerebras/btlm-ft.yml index d0975214b..18dd86e6b 100644 --- a/examples/cerebras/btlm-ft.yml +++ b/examples/cerebras/btlm-ft.yml @@ -53,8 +53,8 @@ lr_quadratic_warmup: true learning_rate: 0.000085 train_on_inputs: true group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: true gradient_checkpointing: false diff --git a/examples/cerebras/qlora.yml b/examples/cerebras/qlora.yml index 03155c6c2..e26aab848 100644 --- a/examples/cerebras/qlora.yml +++ b/examples/cerebras/qlora.yml @@ -36,8 +36,8 @@ lr_scheduler: cosine learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: true gradient_checkpointing: true early_stopping_patience: diff --git a/examples/code-llama/13b/lora.yml b/examples/code-llama/13b/lora.yml index 9c0df0afa..e4ffd0684 100644 --- a/examples/code-llama/13b/lora.yml +++ b/examples/code-llama/13b/lora.yml @@ -41,8 +41,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/code-llama/13b/qlora.yml b/examples/code-llama/13b/qlora.yml index 06b9ac72f..78ffd28ed 100644 --- a/examples/code-llama/13b/qlora.yml +++ 
b/examples/code-llama/13b/qlora.yml @@ -43,8 +43,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/code-llama/34b/lora.yml b/examples/code-llama/34b/lora.yml index a137d54e7..664c30884 100644 --- a/examples/code-llama/34b/lora.yml +++ b/examples/code-llama/34b/lora.yml @@ -41,8 +41,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/code-llama/34b/qlora.yml b/examples/code-llama/34b/qlora.yml index ad1e21675..ca9b14eaf 100644 --- a/examples/code-llama/34b/qlora.yml +++ b/examples/code-llama/34b/qlora.yml @@ -43,8 +43,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/code-llama/7b/lora.yml b/examples/code-llama/7b/lora.yml index 217b2a635..9f0613ede 100644 --- a/examples/code-llama/7b/lora.yml +++ b/examples/code-llama/7b/lora.yml @@ -41,8 +41,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/code-llama/7b/qlora.yml b/examples/code-llama/7b/qlora.yml index 12462dcb7..0dc485e7e 100644 --- a/examples/code-llama/7b/qlora.yml +++ b/examples/code-llama/7b/qlora.yml @@ -43,8 +43,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml index 13bad9425..7cdbb6cef 100644 --- a/examples/falcon/config-7b-lora.yml +++ b/examples/falcon/config-7b-lora.yml @@ -38,8 +38,8 @@ lr_scheduler: cosine learning_rate: 0.00003 train_on_inputs: false group_by_length: false -bf16: true -fp16: false 
+bf16: auto +fp16: tf32: true gradient_checkpointing: true early_stopping_patience: diff --git a/examples/falcon/config-7b-qlora.yml b/examples/falcon/config-7b-qlora.yml index a89124bb8..d93806dfc 100644 --- a/examples/falcon/config-7b-qlora.yml +++ b/examples/falcon/config-7b-qlora.yml @@ -64,8 +64,8 @@ lr_scheduler: cosine learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: true gradient_checkpointing: true # stop training after this many evaluation losses have increased in a row diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml index ff37dcf85..722ab0740 100644 --- a/examples/falcon/config-7b.yml +++ b/examples/falcon/config-7b.yml @@ -38,8 +38,8 @@ lr_scheduler: cosine learning_rate: 0.00003 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: true gradient_checkpointing: true early_stopping_patience: diff --git a/examples/gptj/qlora.yml b/examples/gptj/qlora.yml index 700d10e67..cd3f2e2ad 100644 --- a/examples/gptj/qlora.yml +++ b/examples/gptj/qlora.yml @@ -33,8 +33,8 @@ lr_scheduler: cosine learning_rate: 0.0001 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: true gradient_checkpointing: true early_stopping_patience: diff --git a/examples/jeopardy-bot/config.yml b/examples/jeopardy-bot/config.yml index ac8814b0b..a672c7b94 100644 --- a/examples/jeopardy-bot/config.yml +++ b/examples/jeopardy-bot/config.yml @@ -31,7 +31,7 @@ lr_scheduler: cosine learning_rate: 0.00003 train_on_inputs: false group_by_length: false -bf16: true +bf16: auto tf32: true early_stopping_patience: resume_from_checkpoint: diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml index 5530283bf..e2388ec67 100644 --- a/examples/llama-2/fft_optimized.yml +++ b/examples/llama-2/fft_optimized.yml @@ -41,8 +41,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false 
-bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index abe1c1de0..61d82a403 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -41,8 +41,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index d68882d6a..41810d56d 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -43,8 +43,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml index ff76ddbea..60bd56638 100644 --- a/examples/llama-2/relora.yml +++ b/examples/llama-2/relora.yml @@ -47,8 +47,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml index 946bbe731..9b697892a 100644 --- a/examples/mamba/config.yml +++ b/examples/mamba/config.yml @@ -34,8 +34,8 @@ learning_rate: 5e-5 train_on_inputs: false group_by_length: true -bf16: true -fp16: false +bf16: auto +fp16: tf32: true gradient_checkpointing: false diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml index ea62e9ebf..df7047867 100644 --- a/examples/mistral/config.yml +++ b/examples/mistral/config.yml @@ -34,8 +34,8 @@ learning_rate: 0.000005 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml index 11c842d4e..cb14c6745 100644 --- a/examples/mistral/mixtral.yml +++ b/examples/mistral/mixtral.yml @@ -63,8 +63,8 @@ learning_rate: 
0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml index 35c79ebf4..44ab5691b 100644 --- a/examples/mistral/qlora.yml +++ b/examples/mistral/qlora.yml @@ -50,8 +50,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/mpt-7b/config.yml b/examples/mpt-7b/config.yml index bc36b1c60..45e31266f 100644 --- a/examples/mpt-7b/config.yml +++ b/examples/mpt-7b/config.yml @@ -33,7 +33,7 @@ lr_scheduler: cosine learning_rate: 0.0000002 train_on_inputs: false group_by_length: false -bf16: true +bf16: auto tf32: true early_stopping_patience: resume_from_checkpoint: diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml index eaebd21ef..cab280c2a 100644 --- a/examples/phi/phi-ft.yml +++ b/examples/phi/phi-ft.yml @@ -46,8 +46,8 @@ learning_rate: 0.000003 train_on_inputs: false group_by_length: true -bf16: true -fp16: false +bf16: auto +fp16: tf32: true gradient_checkpointing: diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml index 691a83509..bb0ff40be 100644 --- a/examples/phi/phi-qlora.yml +++ b/examples/phi/phi-qlora.yml @@ -46,8 +46,8 @@ learning_rate: 0.000003 train_on_inputs: false group_by_length: true -bf16: true -fp16: false +bf16: auto +fp16: tf32: true gradient_checkpointing: diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml index 81df94170..af146ae64 100644 --- a/examples/phi/phi2-ft.yml +++ b/examples/phi/phi2-ft.yml @@ -49,8 +49,8 @@ learning_rate: 1e-5 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: true gradient_checkpointing: true diff --git a/examples/pythia/lora.yml b/examples/pythia/lora.yml index 10c76c973..7cb07fe25 100644 --- a/examples/pythia/lora.yml +++ b/examples/pythia/lora.yml @@ -27,7 
+27,7 @@ num_epochs: 4 learning_rate: 0.00001 train_on_inputs: false group_by_length: false -bf16: true +bf16: auto tf32: true early_stopping_patience: resume_from_checkpoint: diff --git a/examples/qwen/lora.yml b/examples/qwen/lora.yml index 0ad9fc0f1..c14e5f8d6 100644 --- a/examples/qwen/lora.yml +++ b/examples/qwen/lora.yml @@ -43,8 +43,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: false diff --git a/examples/qwen/qlora.yml b/examples/qwen/qlora.yml index 1ce0cbdc0..cb3666d25 100644 --- a/examples/qwen/qlora.yml +++ b/examples/qwen/qlora.yml @@ -43,8 +43,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: false diff --git a/examples/redpajama/config-3b.yml b/examples/redpajama/config-3b.yml index a369b6cef..5a42e2a95 100644 --- a/examples/redpajama/config-3b.yml +++ b/examples/redpajama/config-3b.yml @@ -34,7 +34,7 @@ lr_scheduler: cosine learning_rate: 0.0000002 train_on_inputs: false group_by_length: false -bf16: true +bf16: auto tf32: true early_stopping_patience: resume_from_checkpoint: diff --git a/examples/replit-3b/config-lora.yml b/examples/replit-3b/config-lora.yml index 01314acc1..bdfe1bd85 100644 --- a/examples/replit-3b/config-lora.yml +++ b/examples/replit-3b/config-lora.yml @@ -33,7 +33,7 @@ lr_scheduler: learning_rate: 0.00001 train_on_inputs: false group_by_length: false -bf16: true +bf16: auto tf32: true gradient_checkpointing: early_stopping_patience: diff --git a/examples/tiny-llama/lora.yml b/examples/tiny-llama/lora.yml index 53d50178a..67930dacf 100644 --- a/examples/tiny-llama/lora.yml +++ b/examples/tiny-llama/lora.yml @@ -41,8 +41,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/tiny-llama/pretrain.yml 
b/examples/tiny-llama/pretrain.yml index dfd1bfca2..065a32a22 100644 --- a/examples/tiny-llama/pretrain.yml +++ b/examples/tiny-llama/pretrain.yml @@ -34,8 +34,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/tiny-llama/qlora.yml b/examples/tiny-llama/qlora.yml index 53791985e..66860ee33 100644 --- a/examples/tiny-llama/qlora.yml +++ b/examples/tiny-llama/qlora.yml @@ -43,8 +43,8 @@ learning_rate: 0.0002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true diff --git a/examples/xgen-7b/xgen-7b-8k-qlora.yml b/examples/xgen-7b/xgen-7b-8k-qlora.yml index 48924e5f7..e3faa01bd 100644 --- a/examples/xgen-7b/xgen-7b-8k-qlora.yml +++ b/examples/xgen-7b/xgen-7b-8k-qlora.yml @@ -62,8 +62,8 @@ lr_scheduler: cosine learning_rate: 0.00002 train_on_inputs: false group_by_length: false -bf16: true -fp16: false +bf16: auto +fp16: tf32: false gradient_checkpointing: true # stop training after this many evaluation losses have increased in a row diff --git a/examples/yi-34B-chat/qlora.yml b/examples/yi-34B-chat/qlora.yml index 0c1a4b788..fedbc26b7 100644 --- a/examples/yi-34B-chat/qlora.yml +++ b/examples/yi-34B-chat/qlora.yml @@ -7,8 +7,8 @@ load_in_8bit: false load_in_4bit: true strict: false sequence_len: 1024 -bf16: true -fp16: false +bf16: auto +fp16: tf32: false flash_attention: true special_tokens: diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index b04564841..d6f144912 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -70,6 +70,8 @@ def normalize_config(cfg): else: LOG.debug("bf16 support not detected, disabling for this configuration.") cfg.bf16 = False + if cfg.fp16 is None: + cfg.fp16 = True if cfg.device == "mps": cfg.load_in_8bit = False @@ -79,6 +81,8 @@ def normalize_config(cfg): cfg.bf16 = False else: 
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False + if cfg.bf16: + cfg.fp16 = False if cfg.bf16 or cfg.bfloat16: cfg.torch_dtype = torch.bfloat16 diff --git a/tests/test_normalize_config.py b/tests/test_normalize_config.py index da039f6cd..9d7573ff0 100644 --- a/tests/test_normalize_config.py +++ b/tests/test_normalize_config.py @@ -78,13 +78,28 @@ class NormalizeConfigTestCase(unittest.TestCase): normalize_config(cfg) self.assertTrue(cfg.bf16) + self.assertFalse(cfg.fp16) @patch("axolotl.utils.config.is_torch_bf16_gpu_available") def test_bf16_auto_setter_not_available(self, mock_bf16_avail): cfg = self._get_base_cfg() cfg.bf16 = "auto" + cfg.fp16 = None mock_bf16_avail.return_value = False normalize_config(cfg) self.assertFalse(cfg.bf16) + self.assertTrue(cfg.fp16) + + @patch("axolotl.utils.config.is_torch_bf16_gpu_available") + def test_bf16_disables_fp16(self, mock_bf16_avail): + cfg = self._get_base_cfg() + cfg.bf16 = True + cfg.fp16 = True + mock_bf16_avail.return_value = True + + normalize_config(cfg) + + self.assertTrue(cfg.bf16) + self.assertFalse(cfg.fp16)