Compare commits
1 Commits
reentrant-
...
6daed7d060
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6daed7d060 |
@@ -36,6 +36,7 @@ from axolotl.utils.callbacks import (
|
|||||||
SaveModelOnFirstStepCallback,
|
SaveModelOnFirstStepCallback,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||||
|
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
|
||||||
from axolotl.utils.distributed import build_parallelism_config
|
from axolotl.utils.distributed import build_parallelism_config
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
@@ -144,6 +145,12 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
profiler_steps_start=self.cfg.profiler_steps_start,
|
profiler_steps_start=self.cfg.profiler_steps_start,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if self.cfg.include_tkps:
|
||||||
|
callbacks.append(
|
||||||
|
TokensPerSecondCallback(
|
||||||
|
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ from axolotl.utils.collators import (
|
|||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
|
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -72,12 +71,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.qat:
|
if self.cfg.qat:
|
||||||
callbacks.append(QATCallback(self.cfg.qat))
|
callbacks.append(QATCallback(self.cfg.qat))
|
||||||
|
|
||||||
if self.cfg.include_tkps:
|
|
||||||
callbacks.append(
|
|
||||||
TokensPerSecondCallback(
|
|
||||||
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
|
|||||||
@@ -342,10 +342,10 @@ class AxolotlTrainer(
|
|||||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||||
if hasattr(self.state, "num_tokens"):
|
if hasattr(self.state, "num_tokens"):
|
||||||
self.state.num_tokens = (
|
self.state.num_tokens = (
|
||||||
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
|
self.state.num_tokens + (inputs[inputs_key] != -100).sum()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
|
self.state.num_tokens = (inputs[inputs_key] != -100).sum()
|
||||||
|
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
return self.orpo_compute_loss(
|
return self.orpo_compute_loss(
|
||||||
|
|||||||
@@ -3,14 +3,11 @@ Trainer mixin for activation checkpointing w offloading
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||||
apply_activation_checkpointing,
|
apply_activation_checkpointing,
|
||||||
checkpoint_wrapper,
|
|
||||||
CheckpointImpl,
|
|
||||||
)
|
)
|
||||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||||
from transformers import GradientCheckpointingLayer, Trainer
|
from transformers import GradientCheckpointingLayer, Trainer
|
||||||
@@ -49,20 +46,9 @@ class ActivationOffloadingMixin(Trainer):
|
|||||||
return super().training_step(*args, **kwargs)
|
return super().training_step(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def ac_wrap_hf_model(model: nn.Module, use_reentrant=None, **kwargs):
|
def ac_wrap_hf_model(model: nn.Module, **kwargs):
|
||||||
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
|
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
|
||||||
if use_reentrant:
|
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
|
||||||
checkpoint_wrapper_fn = partial(
|
|
||||||
checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
checkpoint_wrapper_fn = checkpoint_wrapper
|
|
||||||
apply_activation_checkpointing(
|
|
||||||
model,
|
|
||||||
checkpoint_wrapper_fn=checkpoint_wrapper_fn,
|
|
||||||
auto_wrap_policy=auto_wrap_policy,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_lora_act_offloading_ctx_manager(
|
def get_lora_act_offloading_ctx_manager(
|
||||||
|
|||||||
@@ -224,27 +224,21 @@ class ModelLoader:
|
|||||||
):
|
):
|
||||||
self.model = self.model.merge_and_unload()
|
self.model = self.model.merge_and_unload()
|
||||||
|
|
||||||
use_reentrant = None
|
self._apply_activation_checkpointing()
|
||||||
if (
|
|
||||||
self.cfg.gradient_checkpointing_kwargs
|
|
||||||
and self.cfg.gradient_checkpointing_kwargs.get("use_reentrant", True)
|
|
||||||
):
|
|
||||||
use_reentrant = True
|
|
||||||
self._apply_activation_checkpointing(use_reentrant=use_reentrant)
|
|
||||||
self._resize_token_embeddings()
|
self._resize_token_embeddings()
|
||||||
self._adjust_model_config()
|
self._adjust_model_config()
|
||||||
self._configure_embedding_dtypes()
|
self._configure_embedding_dtypes()
|
||||||
self._configure_qat()
|
self._configure_qat()
|
||||||
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
|
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
|
||||||
|
|
||||||
def _apply_activation_checkpointing(self, use_reentrant: bool | None = None):
|
def _apply_activation_checkpointing(self):
|
||||||
if self.cfg.activation_offloading is True:
|
if self.cfg.activation_offloading is True:
|
||||||
from axolotl.core.trainers.mixins.activation_checkpointing import (
|
from axolotl.core.trainers.mixins.activation_checkpointing import (
|
||||||
ac_wrap_hf_model,
|
ac_wrap_hf_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ^^ importing this at the module level breaks plugins
|
# ^^ importing this at the module level breaks plugins
|
||||||
ac_wrap_hf_model(self.model, use_reentrant=use_reentrant)
|
ac_wrap_hf_model(self.model)
|
||||||
|
|
||||||
def _resize_token_embeddings(self):
|
def _resize_token_embeddings(self):
|
||||||
"""Resize token embeddings if needed."""
|
"""Resize token embeddings if needed."""
|
||||||
|
|||||||
@@ -178,6 +178,38 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
|
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
def cast_lora_module(module):
|
||||||
|
base_layer_dtype = module.base_layer.weight.dtype
|
||||||
|
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||||
|
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
|
||||||
|
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
|
||||||
|
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
|
||||||
|
log_bias_dtype_mismatch = True
|
||||||
|
module.base_layer.bias.data = module.base_layer.bias.data.to(
|
||||||
|
module.base_layer.weight.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
for active_adapter in module.active_adapters:
|
||||||
|
if module.lora_A:
|
||||||
|
module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
|
||||||
|
if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
|
||||||
|
module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
|
||||||
|
if module.lora_B:
|
||||||
|
module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
|
||||||
|
if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
|
||||||
|
module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
|
||||||
|
if module.lora_embedding_A:
|
||||||
|
module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
|
||||||
|
if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
|
||||||
|
module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
|
||||||
|
if module.lora_embedding_B:
|
||||||
|
module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
|
||||||
|
if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
|
||||||
|
module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
|
||||||
|
if module.lora_magnitude_vector:
|
||||||
|
module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
|
||||||
|
if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
|
||||||
|
module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
|
||||||
|
|
||||||
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||||
"""Helper function to process LoRA modules for FSDP2."""
|
"""Helper function to process LoRA modules for FSDP2."""
|
||||||
@@ -324,10 +356,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
if auto_wrap_policy is not None:
|
if auto_wrap_policy is not None:
|
||||||
for module in get_module_children_bottom_up(model)[:-1]:
|
for module in get_module_children_bottom_up(model)[:-1]:
|
||||||
if is_peft_model and isinstance(module, LoraLayer):
|
if is_peft_model and isinstance(module, LoraLayer):
|
||||||
module_log_bias_mismatch = _process_lora_module_for_fsdp(
|
cast_lora_module(module)
|
||||||
module, fsdp2_kwargs
|
# module_log_bias_mismatch = _process_lora_module_for_fsdp(
|
||||||
)
|
# module, fsdp2_kwargs
|
||||||
log_bias_dtype_mismatch |= module_log_bias_mismatch
|
# )
|
||||||
|
# log_bias_dtype_mismatch |= module_log_bias_mismatch
|
||||||
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
|
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
|
||||||
fully_shard(module, **fsdp2_kwargs)
|
fully_shard(module, **fsdp2_kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -43,12 +43,11 @@ class TokensPerSecondCallback(TrainerCallback):
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
if hasattr(state, "num_tokens"):
|
step_time = time.perf_counter() - self.start_time
|
||||||
step_time = time.perf_counter() - self.start_time
|
num_tokens_per_device = state.num_tokens.clone()
|
||||||
num_tokens_per_device = state.num_tokens.clone()
|
# non data parallel groups have duplicated tokens, so we avoid double-counting
|
||||||
# non data parallel groups have duplicated tokens, so we avoid double-counting
|
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
|
||||||
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
|
state.last_tokens_per_second = num_tokens_per_device / step_time
|
||||||
state.last_tokens_per_second = num_tokens_per_device / step_time
|
|
||||||
|
|
||||||
def on_log(
|
def on_log(
|
||||||
self,
|
self,
|
||||||
@@ -59,6 +58,5 @@ class TokensPerSecondCallback(TrainerCallback):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
# after logging, clear the running metrics
|
# after logging, clear the running metrics
|
||||||
if hasattr(state, "last_tokens_per_second"):
|
state.last_tokens_per_second.zero_()
|
||||||
state.last_tokens_per_second.zero_()
|
state.num_tokens = 0
|
||||||
state.num_tokens = torch.zeros(1)
|
|
||||||
|
|||||||
@@ -855,9 +855,9 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
include_tkps: bool | None = Field(
|
include_tkps: bool | None = Field(
|
||||||
default=True,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "bool of whether to report tokens per second per-gpu during training by measuring throughput of non-padding tokens."
|
"description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
neftune_noise_alpha: float | None = Field(
|
neftune_noise_alpha: float | None = Field(
|
||||||
|
|||||||
Reference in New Issue
Block a user