Compare commits
1 Commit
activation...fix/dpo-la

| Author | SHA1 | Date |
|---|---|---|
|  | fc1900761b |  |
@@ -15,7 +15,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
 from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.evaluate import evaluate
-from axolotl.utils import patch_optimized_env
+from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger(__name__)
@@ -32,7 +32,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
         cli_args: CLI arguments.
     """
     # Enable expandable segments for cuda allocation to improve VRAM usage
-    patch_optimized_env()
+    set_pytorch_cuda_alloc_conf()
 
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
@@ -29,7 +29,7 @@ from axolotl.cli.utils import (
     filter_none_kwargs,
 )
 from axolotl.integrations.lm_eval.cli import lm_eval
-from axolotl.utils import patch_optimized_env
+from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.schemas.config import AxolotlInputConfig
 
 
@@ -55,8 +55,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
         kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
             config options.
     """
-    patch_optimized_env()
-
     if cloud:
         from axolotl.cli.cloud import do_cli_preprocess
 
@@ -102,7 +100,7 @@ def train(
             config options.
     """
     # Enable expandable segments for cuda allocation to improve VRAM usage
-    patch_optimized_env()
+    set_pytorch_cuda_alloc_conf()
 
     if "use_ray" in kwargs and kwargs["use_ray"]:
         accelerate = False
@@ -18,7 +18,7 @@ from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.integrations.base import PluginManager
 from axolotl.train import train
-from axolotl.utils import patch_optimized_env
+from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.config import normalize_config, resolve_dtype
 from axolotl.utils.dict import DictDefault
 
@@ -36,7 +36,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
         cli_args: Training-specific CLI arguments.
     """
     # Enable expandable segments for cuda allocation to improve VRAM usage
-    patch_optimized_env()
+    set_pytorch_cuda_alloc_conf()
 
     print_axolotl_text_art()
     check_accelerate_default_config()
@@ -610,15 +610,3 @@ class AxolotlTrainer(
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)
-
-    def compute_loss_context_manager(self):
-        from contextlib import ExitStack
-
-        from torchtune.training import OffloadActivations
-
-        stack = ExitStack()
-
-        stack.enter_context(super().compute_loss_context_manager())
-        stack.enter_context(OffloadActivations())
-
-        return stack
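Note: the method removed above layered torchtune's activation offloading onto the trainer's loss context via an `ExitStack`. A minimal sketch of the underlying pattern, assuming only that `torchtune.training.OffloadActivations` is available as a context manager (constructor left at defaults, as in the removed code; `model` and `batch` are placeholders):

```python
import torch
from torchtune.training import OffloadActivations


def loss_with_offload(model, batch):
    # Forward under offloading: saved activations are staged to CPU and
    # reloaded for backward, lowering peak VRAM at some transfer cost.
    with OffloadActivations():  # constructor defaults, as the removed code used
        loss = model(**batch).loss
    loss.backward()  # backward runs outside the context, as usual
    return loss
```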
@@ -151,30 +151,6 @@ class LigerPlugin(BasePlugin):
                 rms_norm=cfg.liger_rms_norm,
                 layer_norm=cfg.liger_layer_norm,
             )
-        elif cfg.model_config_type == "qwen3":
-            from axolotl.integrations.liger.models.qwen3 import (
-                apply_liger_kernel_to_qwen3,
-            )
-
-            apply_liger_kernel_to_qwen3(
-                cross_entropy=cfg.liger_cross_entropy,
-                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
-                glu_activation=cfg.liger_glu_activation,
-                rms_norm=cfg.liger_rms_norm,
-                layer_norm=cfg.liger_layer_norm,
-            )
-        elif cfg.model_config_type == "qwen3_moe":
-            from axolotl.integrations.liger.models.qwen3_moe import (
-                apply_liger_kernel_to_qwen3_moe,
-            )
-
-            apply_liger_kernel_to_qwen3_moe(
-                cross_entropy=cfg.liger_cross_entropy,
-                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
-                glu_activation=cfg.liger_glu_activation,
-                rms_norm=cfg.liger_rms_norm,
-                layer_norm=cfg.liger_layer_norm,
-            )
         else:
             logging.warning(
                 f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
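Note: the plugin dispatch above is keyed on `cfg.model_config_type` and the `liger_*` options. A sketch of the option set it consumes, with attribute names taken from the hunk and `DictDefault` from the imports shown earlier (the values are illustrative, not defaults):

```python
from axolotl.utils.dict import DictDefault

# Illustrative values only; liger_cross_entropy and
# liger_fused_linear_cross_entropy are mutually exclusive (the patch
# functions assert that both are not True at once).
cfg = DictDefault(
    {
        "model_config_type": "qwen3",
        "liger_cross_entropy": False,
        "liger_fused_linear_cross_entropy": True,
        "liger_glu_activation": True,
        "liger_rms_norm": True,
        "liger_layer_norm": False,
    }
)
```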
@@ -1,160 +0,0 @@
-"""
-Liger FLCE for Qwen3. Based on transformers v4.51.3.
-"""
-
-import sys
-from typing import Optional, Tuple, Union
-
-import torch
-from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
-from transformers.cache_utils import Cache
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
-
-def lce_forward(
-    self,
-    input_ids: Optional[torch.LongTensor] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Cache] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    logits_to_keep: Union[int, torch.Tensor] = 0,
-    **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
-    r"""
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        logits_to_keep (`int` or `torch.Tensor`, *optional*):
-            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-            This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-    Returns:
-    """
-
-    # pylint: disable=duplicate-code
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        cache_position=cache_position,
-        **kwargs,
-    )
-
-    hidden_states = outputs[0]
-
-    logits = None
-    loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
-            lm_head_weight=self.lm_head.weight,
-            labels=labels,
-            hidden_size=self.config.hidden_size,
-            **kwargs,
-        )
-
-    else:  # if in inference mode materialize logits
-        slice_indices = (
-            slice(-logits_to_keep, None)
-            if isinstance(logits_to_keep, int)
-            else logits_to_keep
-        )
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
-        if labels is not None:
-            loss = self.loss_function(
-                logits=logits,
-                labels=labels,
-                vocab_size=self.config.vocab_size,
-                **kwargs,
-            )
-
-    return CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
-
-
-def apply_liger_kernel_to_qwen3(
-    cross_entropy: bool = False,
-    fused_linear_cross_entropy: bool = False,
-    rms_norm: bool = False,
-    glu_activation: bool = False,
-    layer_norm: bool = False,
-    **kwargs,  # pylint: disable=unused-argument
-) -> None:
-    # pylint: disable=duplicate-code
-    """
-    Apply Liger kernels to replace the original implementation in HuggingFace Qwen3 models.
-
-    Args:
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
-        fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused linear cross entropy loss. Default is False.
-            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
-            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
-        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
-        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
-        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
-    """
-
-    import transformers.models.qwen3.modeling_qwen3  # noqa: F401 # pylint: disable=unused-import
-    from liger_kernel.transformers.functional import liger_cross_entropy
-    from liger_kernel.transformers.layer_norm import LigerLayerNorm
-    from liger_kernel.transformers.rms_norm import LigerRMSNorm
-    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
-
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
-
-    modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"]
-
-    if rms_norm:
-        modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
-
-    if glu_activation:
-        modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
-
-    if layer_norm:
-        modeling_qwen3.nn.LayerNorm = LigerLayerNorm
-
-    if cross_entropy:
-        from transformers.loss.loss_utils import nn
-
-        nn.functional.cross_entropy = liger_cross_entropy
-
-    if fused_linear_cross_entropy:
-        modeling_qwen3.Qwen3ForCausalLM.forward = lce_forward
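Note: `lce_forward` in the deleted file above skips materializing logits during training and lets `LigerForCausalLMLoss` fuse the `lm_head` projection into the loss. A back-of-envelope sketch of the logits tensor that fusion avoids (shapes are illustrative, not taken from this diff):

```python
# Memory for a full logits tensor [batch, seq_len, vocab] in bf16.
batch, seq_len, vocab = 8, 4096, 151_936  # illustrative Qwen-style vocab size
bytes_per_elem = 2                        # bf16
logits_gib = batch * seq_len * vocab * bytes_per_elem / 2**30
print(f"~{logits_gib:.1f} GiB")           # ~9.3 GiB never allocated under FLCE
```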
@@ -1,191 +0,0 @@
-"""
-Liger FLCE for Qwen3 MoE. Based on transformers v4.51.3.
-"""
-
-import sys
-from copy import deepcopy
-from typing import List, Optional, Union
-
-import torch
-from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
-from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func
-
-
-def lce_forward(
-    self,
-    input_ids: Optional[torch.LongTensor] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    output_router_logits: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    logits_to_keep: Union[int, torch.Tensor] = 0,
-    **kwargs,
-) -> MoeCausalLMOutputWithPast:
-    r"""
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        logits_to_keep (`int` or `torch.Tensor`, *optional*):
-            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-            This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-    Returns:
-    """
-
-    # pylint: disable=duplicate-code
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        output_router_logits=output_router_logits,
-        cache_position=cache_position,
-        **kwargs,
-    )
-
-    hidden_states = outputs[0]
-
-    logits = None
-    loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
-            lm_head_weight=self.lm_head.weight,
-            labels=labels,
-            hidden_size=self.config.hidden_size,
-            **kwargs,
-        )
-
-    else:  # if in inference mode materialize logits
-        slice_indices = (
-            slice(-logits_to_keep, None)
-            if isinstance(logits_to_keep, int)
-            else logits_to_keep
-        )
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
-        if labels is not None:
-            loss = self.loss_function(
-                logits=logits,
-                labels=labels,
-                vocab_size=self.config.vocab_size,
-                **kwargs,
-            )
-
-    aux_loss = None
-    if output_router_logits:
-        aux_loss = load_balancing_loss_func(
-            outputs.router_logits,
-            self.num_experts,
-            self.num_experts_per_tok,
-            attention_mask,
-        )
-        if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(
-                loss.device
-            )  # make sure to reside in the same device
-
-    return MoeCausalLMOutputWithPast(
-        loss=loss,
-        aux_loss=aux_loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
-
-
-def apply_liger_kernel_to_qwen3_moe(
-    cross_entropy: bool = False,
-    fused_linear_cross_entropy: bool = False,
-    rms_norm: bool = False,
-    glu_activation: bool = False,
-    layer_norm: bool = False,
-    **kwargs,  # pylint: disable=unused-argument
-) -> None:
-    # pylint: disable=duplicate-code
-    """
-    Apply Liger kernels to replace the original implementation in HuggingFace Qwen3 MoE models.
-
-    Args:
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
-        fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused linear cross entropy loss. Default is False.
-            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
-            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
-        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
-        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
-        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
-    """
-
-    import transformers.models.qwen3_moe.modeling_qwen3_moe  # noqa: F401 # pylint: disable=unused-import
-    from liger_kernel.transformers.functional import liger_cross_entropy
-    from liger_kernel.transformers.layer_norm import LigerLayerNorm
-    from liger_kernel.transformers.rms_norm import LigerRMSNorm
-    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
-
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
-
-    modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]
-
-    if rms_norm:
-        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
-
-    if glu_activation:
-
-        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
-            "Accepts intermediate_size to pass to LigerSwiGLUMLP"
-            # clone config to avoid modifying the original
-            config = deepcopy(config)
-            if intermediate_size:
-                setattr(config, "intermediate_size", intermediate_size)
-            return LigerSwiGLUMLP(config, **kwargs)
-
-        modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper
-
-    if layer_norm:
-        modeling_qwen3_moe.nn.LayerNorm = LigerLayerNorm
-
-    if cross_entropy:
-        from transformers.loss.loss_utils import nn
-
-        nn.functional.cross_entropy = liger_cross_entropy
-
-    if fused_linear_cross_entropy:
-        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = lce_forward
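Note: both deleted modules follow a patch-before-instantiate pattern: the `apply_liger_kernel_to_*` functions swap class attributes on the `transformers` module, so they only affect models constructed afterwards. A hedged sketch of how they were intended to be called (the model id is an example, not from this diff):

```python
from transformers import AutoModelForCausalLM

# Patch first, then instantiate: the class swaps only affect models built
# after apply_liger_kernel_to_qwen3_moe() has replaced the module attributes.
apply_liger_kernel_to_qwen3_moe(
    fused_linear_cross_entropy=True,
    glu_activation=True,
    rms_norm=True,
)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")  # example checkpoint
```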
@@ -18,8 +18,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "mixtral",
     "qwen2",
     "qwen2_moe",
-    "qwen3",
-    "qwen3_moe",
    "falcon",
     "phi",
     "phi3",
@@ -43,12 +43,3 @@ def set_pytorch_cuda_alloc_conf():
     os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
         "expandable_segments:True,roundup_power2_divisions:16"
     )
-
-
-def patch_optimized_env():
-    """
-    Patch environment variables to improve VRAM usage and increase download speed
-    """
-    if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
-        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-    set_pytorch_cuda_alloc_conf()
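Note: both helpers in the hunk above only take effect if they run before the first CUDA allocation, since PyTorch reads `PYTORCH_CUDA_ALLOC_CONF` when its caching allocator initializes. A minimal standalone sketch of the same pattern (env var names and values copied from the hunk; `setdefault` mirrors the removed helper's `os.getenv(...) is None` check):

```python
import os

# Must be exported before torch initializes its CUDA caching allocator.
os.environ.setdefault(
    "PYTORCH_CUDA_ALLOC_CONF",
    "expandable_segments:True,roundup_power2_divisions:16",
)
# The removed patch_optimized_env() additionally enabled hf_transfer downloads.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

import torch  # noqa: E402  # imported after the env vars are set

if torch.cuda.is_available():
    _ = torch.zeros(1, device="cuda")  # first allocation picks up the config
```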
@@ -59,7 +59,7 @@ def choose_device(cfg):
 
 def resolve_dtype(cfg):
     if (
-        not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray
+        cfg.bf16 == "auto" and not cfg.use_ray
     ):  # if we use ray we want to defer this check to the worker node
         if is_torch_bf16_gpu_available():
             LOG.debug("bf16 support detected, enabling for this configuration.")
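Note: the condition change above drops the `not cfg.fp16` guard, so `bf16: auto` now resolves from hardware support alone (outside Ray workers, where the check is deferred). A simplified sketch of that resolution, assuming `is_torch_bf16_gpu_available` from `transformers.utils` is the probe shown in the hunk's context (not the exact axolotl implementation):

```python
from types import SimpleNamespace

from transformers.utils import is_torch_bf16_gpu_available


def resolve_bf16(cfg):
    # With bf16 == "auto" (and no Ray deferral), enable bf16 whenever the
    # GPU supports it; the fp16 setting no longer short-circuits the check.
    if cfg.bf16 == "auto" and not cfg.use_ray:
        cfg.bf16 = is_torch_bf16_gpu_available()
    return cfg


cfg = resolve_bf16(SimpleNamespace(bf16="auto", use_ray=False))
print(cfg.bf16)  # True on bf16-capable GPUs (Ampere+), False otherwise
```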
@@ -2,13 +2,6 @@
 
 from functools import partial
 
-import torch
-from torch.utils.checkpoint import (
-    CheckpointPolicy,
-    checkpoint,
-    create_selective_checkpoint_contexts,
-)
-
 from axolotl.utils.gradient_checkpointing.unsloth import (
     Unsloth_Offloaded_Gradient_Checkpointer,
 )
@@ -25,32 +18,3 @@ def hf_grad_checkpoint_offload_wrapper(
         ),
         *args,
     )
-
-
-aten = torch.ops.aten
-compute_intensive_ops = [
-    aten.mm.default,
-    aten.bmm.default,
-    aten.addmm.default,
-]
-
-
-def policy_fn(ctx, op, *args, **kwargs):
-    if op in compute_intensive_ops:
-        return CheckpointPolicy.MUST_SAVE
-    else:
-        return CheckpointPolicy.PREFER_RECOMPUTE
-
-
-context_fn = partial(create_selective_checkpoint_contexts, policy_fn)
-
-
-def checkpoint_w_policy(
-    decoder_layer, *args, use_reentrant=None
-):  # pylint: disable=unused-argument
-    return checkpoint(
-        decoder_layer,
-        *args,
-        use_reentrant=use_reentrant,
-        context_fn=context_fn,
-    )
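Note: the block removed above implemented selective activation checkpointing: save the outputs of compute-heavy matmul ops (`MUST_SAVE`), recompute everything else (`PREFER_RECOMPUTE`). A hypothetical call site for the removed `checkpoint_w_policy` helper, assuming its definitions above are in scope (the block and inputs are placeholders; PyTorch requires the non-reentrant checkpoint path when a `context_fn` is supplied):

```python
import torch
import torch.nn as nn

# Checkpoint a block while the policy keeps matmul results and recomputes
# the cheaper ops during backward.
block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(2, 64, requires_grad=True)
out = checkpoint_w_policy(block, x, use_reentrant=False)
out.sum().backward()
```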
@@ -190,7 +190,7 @@ class MultipackBatchSampler(BatchSampler):
         self.len_across_ranks = None
 
         if self.sequential and not isinstance(sampler, SequentialSampler):
-            LOG.warning(
+            LOG.warn(
                 "using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
             )
 