Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
e1c7a61243 fix reentrant when using offloading 2025-09-14 10:42:15 -04:00
salman
9640338d37 Default include_tkps to true (#3134)
* default true

* force e2e

* causal trainer only

* fix eval logging [skip-ci]

* revert setup.py

* force tests

* guarding

* guarding

* fix test case

* use evaluate [skip-e2e]

* use evaluate [skip-e2e]

* kick off ci

* fixing

* reverting
2025-09-09 10:50:21 -04:00
8 changed files with 49 additions and 60 deletions

View File

@@ -36,7 +36,6 @@ from axolotl.utils.callbacks import (
SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.distributed import build_parallelism_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -145,12 +144,6 @@ class TrainerBuilderBase(abc.ABC):
profiler_steps_start=self.cfg.profiler_steps_start,
)
)
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks

View File

@@ -39,6 +39,7 @@ from axolotl.utils.collators import (
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
@@ -71,6 +72,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):

View File

@@ -342,10 +342,10 @@ class AxolotlTrainer(
inputs_key = "labels" if "labels" in inputs else "input_ids"
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum()
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
)
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum()
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
if self.args.orpo_alpha:
return self.orpo_compute_loss(

View File

@@ -3,11 +3,14 @@ Trainer mixin for activation checkpointing w offloading
"""
import contextlib
from functools import partial
from peft import PeftModel
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers import GradientCheckpointingLayer, Trainer
@@ -46,9 +49,20 @@ class ActivationOffloadingMixin(Trainer):
return super().training_step(*args, **kwargs)
def ac_wrap_hf_model(model: nn.Module, **kwargs):
def ac_wrap_hf_model(model: nn.Module, use_reentrant=None, **kwargs):
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
if use_reentrant:
checkpoint_wrapper_fn = partial(
checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
)
else:
checkpoint_wrapper_fn = checkpoint_wrapper
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=checkpoint_wrapper_fn,
auto_wrap_policy=auto_wrap_policy,
**kwargs,
)
def get_lora_act_offloading_ctx_manager(

View File

@@ -224,21 +224,27 @@ class ModelLoader:
):
self.model = self.model.merge_and_unload()
self._apply_activation_checkpointing()
use_reentrant = None
if (
self.cfg.gradient_checkpointing_kwargs
and self.cfg.gradient_checkpointing_kwargs.get("use_reentrant", True)
):
use_reentrant = True
self._apply_activation_checkpointing(use_reentrant=use_reentrant)
self._resize_token_embeddings()
self._adjust_model_config()
self._configure_embedding_dtypes()
self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
def _apply_activation_checkpointing(self):
def _apply_activation_checkpointing(self, use_reentrant: bool | None = None):
if self.cfg.activation_offloading is True:
from axolotl.core.trainers.mixins.activation_checkpointing import (
ac_wrap_hf_model,
)
# ^^ importing this at the module level breaks plugins
ac_wrap_hf_model(self.model)
ac_wrap_hf_model(self.model, use_reentrant=use_reentrant)
def _resize_token_embeddings(self):
"""Resize token embeddings if needed."""

View File

@@ -178,38 +178,6 @@ def get_state_dict(self, model, unwrap=True):
return state_dict
def cast_lora_module(module):
base_layer_dtype = module.base_layer.weight.dtype
# Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to(
module.base_layer.weight.dtype
)
for active_adapter in module.active_adapters:
if module.lora_A:
module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_B:
module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_embedding_A:
module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_embedding_B:
module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_magnitude_vector:
module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
@@ -356,11 +324,10 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
if auto_wrap_policy is not None:
for module in get_module_children_bottom_up(model)[:-1]:
if is_peft_model and isinstance(module, LoraLayer):
cast_lora_module(module)
# module_log_bias_mismatch = _process_lora_module_for_fsdp(
# module, fsdp2_kwargs
# )
# log_bias_dtype_mismatch |= module_log_bias_mismatch
module_log_bias_mismatch = _process_lora_module_for_fsdp(
module, fsdp2_kwargs
)
log_bias_dtype_mismatch |= module_log_bias_mismatch
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
fully_shard(module, **fsdp2_kwargs)

View File

@@ -43,11 +43,12 @@ class TokensPerSecondCallback(TrainerCallback):
control: TrainerControl,
**kwargs,
): # pylint: disable=unused-argument
step_time = time.perf_counter() - self.start_time
num_tokens_per_device = state.num_tokens.clone()
# non data parallel groups have duplicated tokens, so we avoid double-counting
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
state.last_tokens_per_second = num_tokens_per_device / step_time
if hasattr(state, "num_tokens"):
step_time = time.perf_counter() - self.start_time
num_tokens_per_device = state.num_tokens.clone()
# non data parallel groups have duplicated tokens, so we avoid double-counting
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
state.last_tokens_per_second = num_tokens_per_device / step_time
def on_log(
self,
@@ -58,5 +59,6 @@ class TokensPerSecondCallback(TrainerCallback):
**kwargs,
): # pylint: disable=unused-argument
# after logging, clear the running metrics
state.last_tokens_per_second.zero_()
state.num_tokens = 0
if hasattr(state, "last_tokens_per_second"):
state.last_tokens_per_second.zero_()
state.num_tokens = torch.zeros(1)

View File

@@ -855,9 +855,9 @@ class AxolotlInputConfig(
},
)
include_tkps: bool | None = Field(
default=None,
default=True,
json_schema_extra={
"description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens."
"description": "bool of whether to report tokens per second per-gpu during training by measuring throughput of non-padding tokens."
},
)
neftune_noise_alpha: float | None = Field(