diff --git a/docs/faq.qmd b/docs/faq.qmd
index 57c2c81e6..08d439af7 100644
--- a/docs/faq.qmd
+++ b/docs/faq.qmd
@@ -136,3 +136,14 @@ description: Frequently asked questions
 > dynamic: false
 > mode: max-autotune-no-cudagraphs
 > ```
+
+**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`**
+
+> A: This can happen due to edge cases in the stream-based `OffloadActivations` context manager (the CUDA-streams implementation). If you encounter this error, you may have success with the naive implementation by setting `activation_offloading: legacy` in your YAML.
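+
+> A minimal example (note that activation offloading requires gradient checkpointing to be enabled):
+>
+> ```yaml
+> gradient_checkpointing: true
+> activation_offloading: legacy
+> ```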
diff --git a/src/axolotl/core/trainers/mixins/activation_checkpointing.py b/src/axolotl/core/trainers/mixins/activation_checkpointing.py
index 9488186cd..1bfdb49f7 100644
--- a/src/axolotl/core/trainers/mixins/activation_checkpointing.py
+++ b/src/axolotl/core/trainers/mixins/activation_checkpointing.py
@@ -4,13 +4,22 @@ Trainer mixin for activation checkpointing w offloading
 
 import contextlib
 
+from peft import PeftModel
 from torch import nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
 )
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from transformers import GradientCheckpointingLayer, Trainer
-from trl.models.activation_offloading import get_act_offloading_ctx_manager
+from trl.models.activation_offloading import (
+    NoOpManager,
+    OffloadActivations,
+    get_act_offloading_ctx_manager,
+)
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
 
 
 class ActivationOffloadingMixin(Trainer):
@@ -21,9 +30,14 @@ class ActivationOffloadingMixin(Trainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.args.activation_offloading:
-            self.activation_offload_context = get_act_offloading_ctx_manager(
-                self.model, use_streams=True
-            )
+            if isinstance(self.model, PeftModel):
+                self.activation_offload_context = get_lora_act_offloading_ctx_manager(
+                    self.model, use_streams=True
+                )
+            else:
+                self.activation_offload_context = get_act_offloading_ctx_manager(
+                    self.model, use_streams=True
+                )
         else:
             self.activation_offload_context = contextlib.nullcontext()
 
@@ -35,3 +49,169 @@
 def ac_wrap_hf_model(model: nn.Module, **kwargs):
     auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
     apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
+
+
+def get_lora_act_offloading_ctx_manager(
+    model: nn.Module,
+    use_pin_memory: bool = True,
+    use_streams: bool = True,
+    min_offload_size: int = 1024,
+    max_fwd_stash_size: int = 5,
+    warn_if_no_head: bool = True,
+) -> OffloadActivations:
+    """
+    Returns the activation offloading context manager for the model. All but the last output Linear in every step will
+    be offloaded.
+
+    If activation offloading is enabled, we return the OffloadActivations context manager. If activation offloading is
+    disabled, we return a NoOpManager context manager.
+
+    Args:
+        model (`nn.Module`):
+            Model to wrap with the activation offloading context manager.
+        use_pin_memory (`bool`, *optional*, defaults to `True`):
+            Whether the offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor
+            to be moved back onto the GPU more quickly, but is a limited resource.
+        use_streams (`bool`, *optional*, defaults to `True`):
+            Whether to use streams for performance optimization where the communications get overlapped with the
+            computation. Requires a torch build after torch-2.5.0.
+        min_offload_size (`int`, *optional*, defaults to `1024`):
+            Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small,
+            we do not want to waste bandwidth and resources moving it to CPU and back.
+        max_fwd_stash_size (`int`, *optional*, defaults to `5`):
+            Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
+            the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
+            more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
+            alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
+            runtime.
+        warn_if_no_head (`bool`, *optional*, defaults to `True`):
+            Whether to warn if no output head is detected. If set to `False`, no warning will be raised if no output
+            head is detected.
+
+    Returns:
+        `contextlib.ContextDecorator`:
+            Activation offloading context manager for the model.
+    """
+    # pylint: disable=unnecessary-dunder-call
+    activations_handling_ctx = OffloadActivations(
+        use_pin_memory=use_pin_memory,
+        use_streams=use_streams,
+        min_offload_size=min_offload_size,
+        max_fwd_stash_size=max_fwd_stash_size,
+    )
+
+    # Below is our hack to disable offloading the last output Linear in every
+    # step, as the cost for offloading the activation and then soon after bringing
+    # it back is expensive.
+    output_head_detected = False
+    noop_ctx = NoOpManager()
+
+    # Try to get the actual model if it's wrapped
+    unwrapped_model = model
+    if hasattr(unwrapped_model, "module"):
+        unwrapped_model = unwrapped_model.module
+    # check for PEFT models
+    if hasattr(unwrapped_model, "base_model") and hasattr(
+        unwrapped_model, "peft_config"
+    ):
+        unwrapped_model = unwrapped_model.base_model
+
+    # Check for different types of output heads
+    if hasattr(unwrapped_model, "output"):
+        if isinstance(unwrapped_model.output, nn.Module):
+            unwrapped_model.output.register_forward_pre_hook(
+                lambda *args: noop_ctx.__enter__()
+            )
+            unwrapped_model.output.register_forward_hook(
+                lambda *args: noop_ctx.__exit__(), always_call=True
+            )
+            output_head_detected = True
+        elif hasattr(unwrapped_model.output, "linear") and isinstance(
+            unwrapped_model.output.linear, nn.Module
+        ):
+            unwrapped_model.output.linear.register_forward_pre_hook(
+                lambda *args: noop_ctx.__enter__()
+            )
+            unwrapped_model.output.linear.register_forward_hook(
+                lambda *args: noop_ctx.__exit__(), always_call=True
+            )
+            output_head_detected = True
+
+    # Check for HuggingFace model output heads
+    elif hasattr(unwrapped_model, "lm_head"):
+        unwrapped_model.lm_head.register_forward_pre_hook(
+            lambda *args: noop_ctx.__enter__()
+        )
+        unwrapped_model.lm_head.register_forward_hook(
+            lambda *args: noop_ctx.__exit__(), always_call=True
+        )
+        output_head_detected = True
+
+    # Check for decoder-based models
+    elif hasattr(unwrapped_model, "decoder"):
+        decoder = unwrapped_model.decoder
+        if hasattr(decoder, "output"):
+            decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
+            decoder.output.register_forward_hook(
+                lambda *args: noop_ctx.__exit__(), always_call=True
+            )
+            output_head_detected = True
+        # Some models have lm_head in the decoder
+        elif hasattr(decoder, "lm_head"):
+            decoder.lm_head.register_forward_pre_hook(
+                lambda *args: noop_ctx.__enter__()
+            )
+            decoder.lm_head.register_forward_hook(
+                lambda *args: noop_ctx.__exit__(), always_call=True
+            )
+            output_head_detected = True
+
+    # Check for transformer models with final layer norm
+    elif hasattr(unwrapped_model, "final_layer_norm") or hasattr(
+        unwrapped_model, "ln_f"
+    ):
+        final_norm = (
+            getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f
+        )
+        final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
+        final_norm.register_forward_hook(
+            lambda *args: noop_ctx.__exit__(), always_call=True
+        )
+        output_head_detected = True
+
+    # Check for models with head module
+    elif hasattr(unwrapped_model, "head") and isinstance(
+        unwrapped_model.head, nn.Module
+    ):
+        unwrapped_model.head.register_forward_pre_hook(
+            lambda *args: noop_ctx.__enter__()
+        )
+        unwrapped_model.head.register_forward_hook(
+            lambda *args: noop_ctx.__exit__(), always_call=True
+        )
+        output_head_detected = True
+
+    if not output_head_detected and warn_if_no_head:
+        LOG.warning(
+            "During activation offloading, no output head was detected. If your model has an output head, it will be "
+            "offloaded. This usually greatly slows training, given the large vocabulary size. To change this "
+            "behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by "
+            "passing `warn_if_no_head=False`."
+        )
+
+    for name, module in unwrapped_model.named_modules():
+        # Disable offloading for any Liger modules
+        if "liger" in name.lower():
+            module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
+            module.register_forward_hook(
+                lambda *args: noop_ctx.__exit__(), always_call=True
+            )
+        # disable offloading for any submodules to fix LoRA training
+        if name.endswith("._checkpoint_wrapped_module"):
+            for _, sub_module in module.named_modules():
+                sub_module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
+                sub_module.register_forward_hook(
+                    lambda *args: noop_ctx.__exit__(), always_call=True
+                )
+
+    return activations_handling_ctx
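The context manager returned above wraps only the forward pass, which is exactly how `ActivationOffloadingMixin` consumes `self.activation_offload_context`; the offloaded activations are fetched back from CPU on demand during `backward()`. A minimal usage sketch with a toy model (`TinyLM`, the tensor shapes, and the loss below are illustrative, not part of this patch; `use_streams=True` needs a CUDA device):

```python
import torch
from torch import nn
from trl.models.activation_offloading import get_act_offloading_ctx_manager


class TinyLM(nn.Module):
    """Toy model exposing `lm_head`, so the hooks above keep the head un-offloaded."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
        self.lm_head = nn.Linear(64, 128)

    def forward(self, x):
        return self.lm_head(self.backbone(x))


model = TinyLM().cuda()
offload_ctx = get_act_offloading_ctx_manager(model, use_streams=True)

inputs = torch.randn(8, 64, device="cuda", requires_grad=True)
with offload_ctx:  # activations saved during this forward are offloaded to CPU
    loss = model(inputs).pow(2).mean()
loss.backward()  # offloaded activations are brought back as backward needs them
```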
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 0c1a97fcd..cfa759cad 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -1104,16 +1104,10 @@ class ModelCompatibilityValidationMixin:
                 "`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or `activation_offloading: legacy`"
             )
             self.gradient_checkpointing = True
-            if self.adapter and "lora" in self.adapter:
-                LOG.warning(
-                    "offloading with CUDA streams is not supported for LoRA adapters, using the `activation_offloading: legacy` implementation."
-                )
-                self.activation_offloading = "legacy"
-            else:
-                LOG.warning(
-                    "`offload` uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
-                )
-                self.activation_offloading = True
+            LOG.warning(
+                "`offload` now uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
+            )
+            self.activation_offloading = True
         if self.gradient_checkpointing == "offload_disk":
             LOG.warning(
                 "`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
             )
@@ -1122,19 +1116,6 @@ class ModelCompatibilityValidationMixin:
             self.activation_offloading = "disk"
         return self
 
-    @model_validator(mode="after")
-    def check_activation_offloading_w_lora(self):
-        if (
-            self.activation_offloading is True
-            and self.adapter
-            and "lora" in self.adapter
-        ):
-            LOG.warning(
-                "activation_offloading with CUDA streams is not supported for LoRA adapters. Setting `activation_offloading: legacy`"
-            )
-            self.activation_offloading = "legacy"
-        return self
-
     @model_validator(mode="after")
     def check_activation_offloading_wo_gc(self):
         if self.activation_offloading and not self.gradient_checkpointing:
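With the LoRA carve-out removed, `validate_config` now keeps the stream implementation for LoRA/QLoRA configs as well. A sketch of a companion schema test (the test name is hypothetical, and the `min_base_cfg` fixture is assumed from the existing test module):

```python
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


def test_gc_offload_keeps_streams_w_lora(min_base_cfg):
    # Deprecated `offload` now maps to the stream implementation even with LoRA.
    cfg = (
        DictDefault(
            gradient_checkpointing="offload",
            adapter="lora",
        )
        | min_base_cfg
    )

    cfg = validate_config(cfg)
    assert cfg.gradient_checkpointing is True
    assert cfg.activation_offloading is True
```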
diff --git a/tests/e2e/test_activation_offloading.py b/tests/e2e/test_activation_offloading.py
new file mode 100644
index 000000000..06c5c0656
--- /dev/null
+++ b/tests/e2e/test_activation_offloading.py
@@ -0,0 +1,83 @@
+"""
+E2E tests for activation offloading
+"""
+
+import pytest
+
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import check_model_output_exists
+
+# pylint: disable=duplicate-code
+
+
+class TestActivationOffloading:
+    """
+    E2E test cases for activation offloading
+    """
+
+    @pytest.mark.parametrize(
+        "adapter",
+        ["lora", "qlora", None],
+    )
+    def test_activation_offloading(
+        self,
+        temp_dir,
+        adapter,
+    ):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 1024,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                    "eos_token": "<|im_end|>",
+                },
+                "datasets": [
+                    {
+                        "chat_template": "chatml",
+                        "path": "mlabonne/FineTome-100k",
+                        "type": "chat_template",
+                        "split": "train[:10%]",
+                        "field_messages": "conversations",
+                        "message_field_role": "from",
+                        "message_field_content": "value",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 2,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "sample_packing": True,
+                "bf16": "auto",
+                "save_safetensors": True,
+                "gradient_checkpointing": True,
+                "activation_offloading": True,
+                "save_first_step": False,
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_target_linear": True,
+            }
+        )
+        if adapter == "lora":
+            cfg["adapter"] = "lora"
+        if adapter == "qlora":
+            cfg["adapter"] = "qlora"
+            cfg["load_in_4bit"] = True
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        dataset_meta = load_datasets(cfg=cfg)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(temp_dir, cfg)
diff --git a/tests/utils/schemas/validation/test_activation_offloading.py b/tests/utils/schemas/validation/test_activation_offloading.py
index 92ac8f45c..433133a80 100644
--- a/tests/utils/schemas/validation/test_activation_offloading.py
+++ b/tests/utils/schemas/validation/test_activation_offloading.py
@@ -21,62 +21,6 @@ class TestActivationOffloading:
         assert cfg.gradient_checkpointing is True
         assert cfg.activation_offloading is True
 
-    def test_gc_converts_offload_w_lora(self, min_base_cfg):
-        cfg = (
-            DictDefault(
-                gradient_checkpointing="offload",
-                adapter="lora",
-            )
-            | min_base_cfg
-        )
-
-        cfg = validate_config(cfg)
-        assert cfg.gradient_checkpointing is True
-        assert cfg.activation_offloading == "legacy"
-
-    def test_gc_converts_offload_w_qlora(self, min_base_cfg):
-        cfg = (
-            DictDefault(
-                gradient_checkpointing="offload",
-                adapter="qlora",
-                load_in_4bit=True,
-            )
-            | min_base_cfg
-        )
-
-        cfg = validate_config(cfg)
-        assert cfg.gradient_checkpointing is True
-        assert cfg.activation_offloading == "legacy"
-
-    def test_ac_impl_changes_w_lora(self, min_base_cfg):
-        cfg = (
-            DictDefault(
-                gradient_checkpointing=True,
-                activation_offloading=True,
-                adapter="lora",
-            )
-            | min_base_cfg
-        )
-
-        cfg = validate_config(cfg)
-        assert cfg.gradient_checkpointing is True
-        assert cfg.activation_offloading == "legacy"
-
-    def test_ac_impl_changes_w_qlora(self, min_base_cfg):
-        cfg = (
-            DictDefault(
-                gradient_checkpointing=True,
-                activation_offloading=True,
-                adapter="qlora",
-                load_in_4bit=True,
-            )
-            | min_base_cfg
-        )
-
-        cfg = validate_config(cfg)
-        assert cfg.gradient_checkpointing is True
-        assert cfg.activation_offloading == "legacy"
-
     def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
         cfg = (
             DictDefault(