diff --git a/_quarto.yml b/_quarto.yml index 3e773a748..dab1ee363 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -268,6 +268,8 @@ website: - docs/batch_vs_grad.qmd - docs/dataset_preprocessing.qmd - docs/multipack.qmd + - docs/mixed_precision.qmd + - docs/gradient_accumulation.qmd - section: "Advanced Features" contents: diff --git a/docs/mixed_precision.qmd b/docs/mixed_precision.qmd new file mode 100644 index 000000000..7b77cd4bb --- /dev/null +++ b/docs/mixed_precision.qmd @@ -0,0 +1,149 @@ +--- +title: "Mixed Precision Training" +format: + html: + toc: true + toc-depth: 3 + number-sections: true + code-tools: true +execute: + enabled: false +--- + +Mixed precision training uses lower precision data types to reduce memory usage and increase training speed while maintaining model quality. Axolotl supports several mixed precision formats: + +- **FP16** - Half precision 16-bit (Pascal generation+) +- **BF16** - Brain Float 16-bit (Ampere generation+) +- **FP8** - 8-bit floating point (Hopper generation+) + +## FP16 Mixed Precision {#sec-fp16} + +### Overview {#sec-fp16-overview} + +FP16 is the traditional half-precision format, supported on older GPUs but can be less numerically stable than BF16. + +### Configuration {#sec-fp16-config} + +```{.yaml} +fp16: true +``` + +### FP16 Considerations {#sec-fp16-considerations} + +- May require gradient scaling to prevent underflow +- Less numerically stable than BF16 +- Can cause training instability with some model architectures +- Consider using BF16 if your hardware supports it + +## BF16 Mixed Precision {#sec-bf16} + +### Overview {#sec-bf16-overview} + +BF16 (Brain Float 16) offers better numerical stability than FP16 and is the recommended mixed precision format for modern GPUs. It provides the same dynamic range as FP32 while using half the memory. + +### Configuration {#sec-bf16-config} + +```{.yaml} +# Automatic BF16 detection (recommended) +bf16: auto + +# Or explicitly enable +bf16: true + +# For evaluation with BF16 +bf16: full # Equivalent to bf16_full_eval in the HF trainer +``` + +## FP8 Mixed Precision {#sec-fp8} + +::: {.callout-note} +FP8 support is experimental and requires compatible hardware (H100, H200) and recent PyTorch versions with TorchAO. +::: + +### What is FP8? {#sec-fp8-overview} + +FP8 (8-bit floating point) can provide significant time savings compared to FP16/BF16 while maintaining training stability. Axolotl's implementation uses PyTorch's TorchAO library with "tensorwise" scaling strategy. + +### Requirements {#sec-fp8-software} + +- Hopper+ GPUs (H100/H200) +- PyTorch 2.7+ (+ compatible TorchAO version) +- CUDA 12.4+ + +### Configuration {#sec-fp8-config} + +Add to your YAML config: + +```{.yaml} +# Enable FP8 mixed precision +fp8: true + +# Optional: Enable FP8 for FSDP all-gather operations +fp8_enable_fsdp_float8_all_gather: true + +# Enable torch.compile (almost always necessary for FP8 speedups) +torch_compile: true +``` + +::: {.callout-important} +**torch.compile is critical for FP8 performance** + +FP8 training requires `torch_compile: true` to see meaningful speedups. Without compilation, FP8 may actually be slower and use more memory than FP16/BF16. +::: + +### Advanced FP8 Configs {#sec-fp8-advanced} + +For [FSDP](multi-gpu.qmd#sec-fsdp) (Fully Sharded Data Parallel) training: + +```{.yaml} +fp8: true +fp8_enable_fsdp_float8_all_gather: true + +torch_compile: true + +# FSDP configuration +fsdp_version: 2 +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: LlamaDecoderLayer + state_dict_type: FULL_STATE_DICT + reshard_after_forward: true +``` + +## Best Practices {#sec-best-practices} + +### Choosing Precision Format {#sec-choosing-format} + +- **Start with automatic detection**: `bf16: auto` +- **For Hopper+ (H100/H200)**: Try FP8 + torch.compile for maximum speed +- **For Ampere (A100/RTX 30/40)**: Use BF16 +- **For older Pascal/Turing GPUs**: Use FP16 with caution +- **For very old or unsupported GPUs**: Use FP32 + +### Validation and Testing {#sec-validation} + +Always validate your mixed precision setup: + +- **Start with a small dataset** to verify stability +- **Monitor loss curves** for irregularities +- **Compare with FP32 baseline** when possible +- **Test evaluation metrics** match expectations + +### FP8 Particulars {#sec-fp8-details} + +- Use cases + - Single GPU training + - Multi GPU training with FSDP2 or Deepspeed +- Speedups + - Please refer to the [TorchAO FP8 training benchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling) for expected matmul speedups for different (M, K, N) settings + - Concrete number for LLaMA 3 8B training can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks) +- Known issues: + - FP8 + DDP + `torch.compile` (causes [error](https://gist.github.com/djsaunde/0c1664c32e44a64d31b5e01b4aafe5c4)) + - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing tends to be _slower_ than the BF16 equivalent training + - Flash Attention 2 does not play nicely with `torch.compile` + +See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 mixed precision + FP8 all-gather training results in ~10% faster iterations per second vs. BF16 for a relatively small (3B param) model + +For more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd). diff --git a/examples/llama-3/3b-fp8-fsdp2.yaml b/examples/llama-3/3b-fp8-fsdp2.yaml new file mode 100644 index 000000000..bea698c0e --- /dev/null +++ b/examples/llama-3/3b-fp8-fsdp2.yaml @@ -0,0 +1,76 @@ +base_model: meta-llama/Llama-3.2-3B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true + +datasets: + - path: yahma/alpaca-cleaned + type: alpaca + +output_dir: ./outputs/fp8_out/ + +sample_packing: true +pad_to_sequence_len: true +sequence_len: 512 + +flex_attention: true +flex_attn_compile_kwargs: + dynamic: false + mode: max-autotune-no-cudagraphs + +torch_compile: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 +num_epochs: 1 +optimizer: adamw_torch_fused + +cosine_constant_lr_ratio: 0 +cosine_min_lr_ratio: 1.0 +learning_rate: 2e-5 +save_only_model: true + +fp8: true +fp8_enable_fsdp_float8_all_gather: true + +resume_from_checkpoint: +logging_steps: 1 + +evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_steps: 10 +weight_decay: 0.0 + +fsdp_version: 2 +fsdp_config: + offload_params: false + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: LlamaDecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: false + +special_tokens: + pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index b983f1076..3dfaf47ce 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -7,7 +7,7 @@ from __future__ import annotations import os from collections import defaultdict from functools import partial, wraps -from typing import Callable, Literal, Optional +from typing import Any, Callable, Literal, Optional import datasets import torch @@ -522,15 +522,25 @@ class AxolotlTrainer( return res + # pylint: disable=unused-argument def additional_accelerator_args( - self, fp8=None, **kwargs - ): # pylint: disable=unused-argument + self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs + ) -> dict[str, Any]: ret_kwargs = {} if fp8: from accelerate.utils import AORecipeKwargs + from torchao.float8 import Float8LinearConfig + + # By default, Float8LinearConfig is instantiated using the "tensorwise" + # scaling strategy. See more details here: + # https://github.com/pytorch/ao/tree/main/torchao/float8. + config = Float8LinearConfig( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True, + ) ret_kwargs["mixed_precision"] = "fp8" - ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()] + ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8" return ret_kwargs diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index f346c56e0..533bd0f7a 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -154,7 +154,9 @@ class PatchManager: patch_create_accelerate_code_for_fp8, ) - patch_create_accelerate_code_for_fp8() + patch_create_accelerate_code_for_fp8( + self.cfg.fp8_enable_fsdp_float8_all_gather + ) def _apply_flash_attention_peft_patches(self): """Apply patches for Flash Attention with PEFT.""" diff --git a/src/axolotl/monkeypatch/trainer_accelerator_args.py b/src/axolotl/monkeypatch/trainer_accelerator_args.py index 0a5b27c13..819a66255 100644 --- a/src/axolotl/monkeypatch/trainer_accelerator_args.py +++ b/src/axolotl/monkeypatch/trainer_accelerator_args.py @@ -18,7 +18,7 @@ ORIGINAL_TRAINER_CODE = """ PATCHED_TRAINER_CODE = """ if hasattr(self, "additional_accelerator_args"): - additional_args = self.additional_accelerator_args(fp8=True, **args) + additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={enable_fsdp_float8_all_gather}, **args) if additional_args: args.update(additional_args) @@ -38,9 +38,9 @@ def check_create_accelerate_code_is_patchable() -> bool: return ORIGINAL_TRAINER_CODE in create_code -def patch_create_accelerate_code_for_fp8(): +def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool): """ - monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs + Monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs. """ try: @@ -54,7 +54,10 @@ def patch_create_accelerate_code_for_fp8(): if ORIGINAL_TRAINER_CODE not in create_code: return - create_code = create_code.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE) + patched_trainer_code = PATCHED_TRAINER_CODE.format( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather + ) + create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code) create_code = create_code.replace( "def create_accelerator_and_postprocess(", "def fixed_create_accelerator_and_postprocess(", diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index de928d11c..96b694043 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -343,7 +343,20 @@ class AxolotlInputConfig( fp16: bool | None = Field( default=None, json_schema_extra={"description": "Use CUDA fp16"} ) - fp8: bool | None = None + fp8: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Enable FP8 mixed precision training using TorchAO. Best " + "used in combination with torch.compile." + }, + ) + fp8_enable_fsdp_float8_all_gather: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Enable FSDP float8 all-gather optimization for FP8 training. Can " + "improve training speed by 10-15% when FSDP is enabled." + }, + ) bfloat16: bool | None = Field( default=None, json_schema_extra={ diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 64dbb2529..0c1a97fcd 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -360,6 +360,36 @@ class TrainingValidationMixin: # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half return self + @model_validator(mode="before") + @classmethod + def check_fp8_config(cls, data): + if data.get("fp8") and not data.get("torch_compile"): + LOG.warning( + "torch_compile is strongly recommended for FP8 training in order to " + "see speed improvements. Please consider setting `torch_compile: " + "true` in your config." + ) + if data.get("fp8") and ( + data.get("fsdp_config", {}).get("activation_checkpointing", False) is True + or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False) + is True + ): + LOG.warning( + "FP8 + FSDP2 + activation checkpointing may be slower than BF16 " + "training. Please considering setting `activation_checkpointing: false` " + "in your FSDP config." + ) + if ( + data.get("fp8_enable_fsdp_float8_all_gather") + and not data.get("fsdp_version", None) == 2 + ): + raise ValueError( + "fp8_enable_fsdp_float8_all_gather requires FSDP2 (fsdp_version: 2) " + "to be used." + ) + + return data + @model_validator(mode="before") @classmethod def check_use_reentrant_mismatch(cls, data): diff --git a/tests/e2e/integrations/test_fp8.py b/tests/e2e/integrations/test_fp8.py new file mode 100644 index 000000000..0302b7e35 --- /dev/null +++ b/tests/e2e/integrations/test_fp8.py @@ -0,0 +1,62 @@ +""" +Simple end-to-end smoke tests for FP8 mixed precision training +""" + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import check_model_output_exists, require_torch_2_7_0 + + +class FP8IntegrationTestCase: + """ + e2e smoke tests for FP8 mixed precision training with Axolotl + """ + + @require_torch_2_7_0 + def test_fp8_single_gpu_smoke(self, temp_dir): + """Smoke test for single GPU FP8 + torch.compile training""" + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 512, + "val_set_size": 0.05, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, # Very short smoke test + "micro_batch_size": 1, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "sdp_attention": True, + "pad_to_seq_len": True, + "sample_packing": True, + "fp8": True, + "torch_compile": True, + "save_safetensors": True, + "save_first_step": False, + } + ) + + # pylint: disable=duplicate-code + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/multigpu/test_fp8_fsdp2.py b/tests/e2e/multigpu/test_fp8_fsdp2.py new file mode 100644 index 000000000..6423f5e2e --- /dev/null +++ b/tests/e2e/multigpu/test_fp8_fsdp2.py @@ -0,0 +1,120 @@ +"""Test module for FP8 mixed precision with FSDP2 multi-GPU functionality.""" + +# pylint: disable=duplicate-code + +import os +from pathlib import Path + +import torch +import yaml +from accelerate.test_utils import execute_subprocess_async +from tbparse import SummaryReader +from transformers.testing_utils import get_torch_dist_unique_port + +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0 + +AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent + + +def verify_fp8_training_success(temp_dir): + """Verify that FP8 training completed successfully by checking artifacts and loss.""" + output_path = Path(temp_dir) + + model_files = list(output_path.glob("*.bin")) + list( + output_path.glob("*.safetensors") + ) + assert len(model_files) > 0, "No model files found - training may have failed" + + checkpoint_files = list(output_path.glob("checkpoint-*")) + assert ( + len(checkpoint_files) > 0 + ), "No checkpoint files found - training may have failed" + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + if tb_log_path: + event_files = sorted(os.listdir(tb_log_path)) + if event_files: + event_file = os.path.join(tb_log_path, event_files[0]) + reader = SummaryReader(event_file) + df = reader.scalars + train_loss_df = df[df.tag == "train/train_loss"] + if len(train_loss_df) > 0: + final_loss = train_loss_df.value.values[-1] + assert not torch.isnan( + torch.tensor(final_loss) + ), f"Training loss is NaN: {final_loss}" + + +class TestFP8FSDP2: + """Test class for FP8 mixed precision with FSDP2 functionality.""" + + @require_torch_2_7_0 + def test_fp8_fsdp2_smoke(self, temp_dir): + """Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training""" + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 512, + "val_set_size": 0.05, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, # Very short smoke test + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", # Use standard optimizer for stability + "lr_scheduler": "cosine", + "sdp_attention": True, + "pad_to_seq_len": True, + "sample_packing": True, + # FP8 configuration + "fp8": True, + "fp8_enable_fsdp_float8_all_gather": True, + "torch_compile": True, + # FSDP2 configuration + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "LlamaDecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "save_safetensors": True, + "save_first_step": False, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_fp8_training_success(temp_dir) diff --git a/tests/monkeypatch/test_trainer_accelerator_args.py b/tests/monkeypatch/test_trainer_accelerator_args.py new file mode 100644 index 000000000..fab2597f0 --- /dev/null +++ b/tests/monkeypatch/test_trainer_accelerator_args.py @@ -0,0 +1,26 @@ +""" +Unit tests for trainer accelerator args monkeypatch +""" + +import unittest + +from axolotl.monkeypatch.trainer_accelerator_args import ( + check_create_accelerate_code_is_patchable, +) + + +class TestTrainerAcceleratorArgs(unittest.TestCase): + """ + Unit test class for trainer accelerator args monkeypatch + """ + + def test_check_create_accelerate_code_is_patchable(self): + """ + Test that the upstream transformers code is still patchable. + This will fail if the patched code changes upstream. + """ + assert check_create_accelerate_code_is_patchable() + + +if __name__ == "__main__": + unittest.main()