TiledMLP support for FSDP2 (#2950)
* make TiledMLP work with FSDP
* cleanup/gc at start of train to prevent large VRAM spike
* chore: lint
* generic function for non-deepspeed training
* unify patch to fix imports
* update readme for ALST and add examples
* make deepspeed attribute on params check more robust
* update with new info from PR review
@@ -25,6 +25,7 @@
 
 ## 🎉 Latest Updates
 
+- 2025/07: TiledMLP support for single-GPU and multi-GPU training with DDP, DeepSpeed, and FSDP has been added to enable Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
 - 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
 - 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
 - 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
9
examples/alst/README.md
Normal file
@@ -0,0 +1,9 @@
+# Arctic Long Sequence Training (ALST)
+
+Arctic Long Sequence Training (ALST) is a technique for training long-context models using a variety of optimization
+techniques. It is a combination of:
+- TiledMLP: Tiles MLP layers over the sequence dimension to reduce memory usage
+- Tiled Loss: Uses optimized loss functions like Liger-Kernel or Cut Cross Entropy to reduce memory usage
+- Activation Offloading: Offloads activations to CPU RAM to reduce memory usage
+
+For more information, see the [ALST paper](https://www.arxiv.org/abs/2506.13996).
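The TiledMLP bullet in the README relies on the fact that an MLP acts on each sequence position independently, so the sequence can be processed in slices and the outputs concatenated. The following is a minimal, framework-free sketch of that idea. `toy_mlp`, `apply_untiled`, and `apply_tiled` are hypothetical names invented for illustration and are not part of Axolotl; the shard size uses ceiling division, which is an assumption chosen to resemble how the sequence is chunked.

```python
from typing import Callable, List

Vector = List[float]


def toy_mlp(token: Vector) -> Vector:
    """Hypothetical position-wise MLP: ReLU followed by a fixed scaling."""
    return [2.0 * max(v, 0.0) for v in token]


def apply_untiled(seq: List[Vector], fn: Callable[[Vector], Vector]) -> List[Vector]:
    """Process every position of the sequence in one pass."""
    return [fn(tok) for tok in seq]


def apply_tiled(
    seq: List[Vector], fn: Callable[[Vector], Vector], num_shards: int
) -> List[Vector]:
    """Process the sequence one shard at a time and concatenate the outputs."""
    # Ceiling division; at least 1 so that an empty sequence is handled.
    shard_len = max(1, -(-len(seq) // num_shards))
    out: List[Vector] = []
    for start in range(0, len(seq), shard_len):
        shard = seq[start : start + shard_len]
        # Only this shard's intermediate values are alive at this point.
        out.extend(fn(tok) for tok in shard)
    return out


sequence = [[float(i), -float(i)] for i in range(10)]
# Tiling changes how much is processed at once, not the result.
assert apply_tiled(sequence, toy_mlp, num_shards=4) == apply_untiled(sequence, toy_mlp)
```

In a real model the saving comes from the MLP's wide intermediate activations, which only need to exist for one shard at a time.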
53
examples/alst/llama3-8b-deepspeed-alst.yaml
Normal file
@@ -0,0 +1,53 @@
+base_model: meta-llama/Llama-3.1-8B
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+datasets:
+  - path: togethercomputer/Long-Data-Collections
+    type: completion
+    field: text
+    data_files:
+      - pretrain/rp_sub.jsonl.zst
+  - path: princeton-nlp/TextbookChapters
+    type: completion
+    field: chapter
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+sequence_len: 500_000
+min_sample_len: 200_000
+sample_packing: true
+
+tiled_mlp: true
+sequence_parallel_degree: 8
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch_8bit
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+activation_offloading: legacy
+
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_steps: 100
+saves_per_epoch: 1
+evals_per_epoch: 2
+weight_decay: 0.0
+special_tokens:
+  pad_token: <|end_of_text|>
+
+deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json
+
+# save_first_step: true  # uncomment this to validate checkpoint saving works with your config
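As a rough sanity check on the numbers in this config, the snippet below works out how many tokens each GPU would hold if the 500,000-token sequence is split evenly across the 8-way sequence parallel group. The even split and the divisibility check are assumptions of this sketch, and `tokens_per_rank` is a hypothetical helper, not an Axolotl function.

```python
def tokens_per_rank(sequence_len: int, parallel_degree: int) -> int:
    """Tokens held by each rank, assuming an even split of the sequence."""
    if parallel_degree < 1:
        raise ValueError("parallel_degree must be at least 1")
    if sequence_len % parallel_degree != 0:
        # Assumption of this sketch: require an even split.
        raise ValueError("sequence_len is not divisible by parallel_degree")
    return sequence_len // parallel_degree


# Values taken from the example config: sequence_len and sequence_parallel_degree.
print(tokens_per_rank(500_000, 8))  # 62500
```

Each rank would then tile its own slice of the sequence further when `tiled_mlp: true` is set.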
59
examples/alst/llama3-8b-fsdp2-alst.yaml
Normal file
@@ -0,0 +1,59 @@
+base_model: meta-llama/Llama-3.1-8B
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+datasets:
+  - path: togethercomputer/Long-Data-Collections
+    type: completion
+    field: text
+    data_files:
+      - pretrain/rp_sub.jsonl.zst
+  - path: princeton-nlp/TextbookChapters
+    type: completion
+    field: chapter
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+sequence_len: 500_000
+min_sample_len: 200_000
+sample_packing: true
+
+tiled_mlp: true
+context_parallel_size: 8
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch_8bit
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+activation_offloading: legacy
+
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_steps: 100
+saves_per_epoch: 1
+evals_per_epoch: 2
+weight_decay: 0.0
+special_tokens:
+  pad_token: <|end_of_text|>
+
+fsdp_version: 2
+fsdp_config:
+  offload_params: false  # offloading is currently not compatible with SP + torchao optimizer
+  state_dict_type: SHARDED_STATE_DICT
+  auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  reshard_after_forward: true
+
+# save_first_step: true  # uncomment this to validate checkpoint saving works with your config
@@ -57,8 +57,12 @@ class LigerArgs(BaseModel):
     @model_validator(mode="before")
     @classmethod
     def check_tiled_mlp_conflict(cls, data):
-        if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True:
+        if (
+            data.get("liger_glu_activation") is True
+            and data.get("tiled_mlp") is True
+            and not data.get("tiled_mlp_use_original_mlp")
+        ):
             raise ValueError(
-                "You cannot have both `liger_glu_activation` and `tiled_mlp` set."
+                "You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
             )
         return data
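To make the relaxed conflict rule easier to follow, here is the same condition as a plain function over a config dict, without pydantic. This is a sketch under the assumption that the validator receives the raw config mapping; the function name `check_tiled_mlp_conflict_sketch` is hypothetical.

```python
def check_tiled_mlp_conflict_sketch(data: dict) -> dict:
    """Reject liger_glu_activation + tiled_mlp unless the original MLP is used."""
    if (
        data.get("liger_glu_activation") is True
        and data.get("tiled_mlp") is True
        and not data.get("tiled_mlp_use_original_mlp")
    ):
        raise ValueError(
            "You cannot have both `liger_glu_activation` and `tiled_mlp` set "
            "without `tiled_mlp_use_original_mlp: true`."
        )
    return data


# Allowed: both features on, with the original MLP forward kept.
check_tiled_mlp_conflict_sketch(
    {"liger_glu_activation": True, "tiled_mlp": True, "tiled_mlp_use_original_mlp": True}
)

# Rejected: both features on and the flag missing from the mapping.
try:
    check_tiled_mlp_conflict_sketch({"liger_glu_activation": True, "tiled_mlp": True})
except ValueError as err:
    print(f"rejected: {err}")
```

Note that in this sketch a missing `tiled_mlp_use_original_mlp` key counts as falsy, so the flag has to be present and true in the mapping for the combination to pass.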
@@ -162,6 +162,7 @@ class ModelLoader:
 
         # Build the model
         PLUGIN_MANAGER.pre_model_load(self.cfg)
+        self.patch_manager.apply_post_plugin_pre_model_load_patches()
         skip_move_to_device = self._build_model()
         PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
 
@@ -66,6 +66,9 @@ class PatchManager:
         self._apply_self_attention_lora_patch()
         self._apply_gemma3_conditional_generation_forward_patch()
         self._apply_sequence_parallel_patches()
+
+    def apply_post_plugin_pre_model_load_patches(self):
+        """Apply post plugin-pre_model_load load patches based on config."""
         self._apply_tiled_mlp(self.cfg.model_config_type)
 
     def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -272,7 +275,9 @@ class PatchManager:
 
     def _apply_tiled_mlp(self, model_type: str):
         if self.cfg.tiled_mlp:
-            from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
+            from axolotl.monkeypatch.tiled_mlp import (
+                patch_tiled_mlp,
+            )
 
             patch_tiled_mlp(
                 model_type,
11
src/axolotl/monkeypatch/tiled_mlp/__init__.py
Normal file
@@ -0,0 +1,11 @@
+"""
+TiledMLP monkey patches
+"""
+
+from .patch import (
+    patch_tiled_mlp,
+)
+
+__all__ = [
+    "patch_tiled_mlp",
+]
153
src/axolotl/monkeypatch/tiled_mlp/base.py
Normal file
@@ -0,0 +1,153 @@
+"""
+TiledMLP support for DDP, FSDP, and single GPU
+"""
+
+import threading
+from typing import List
+
+import torch
+
+
+class TiledMLP(torch.autograd.Function):
+    """
+    TiledMLP implementation using gradient hooks
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        fn,
+        self,
+        x,
+        shards,
+        compute_params,
+    ) -> torch.Tensor:
+        ctx.fn = fn
+        ctx.self = self
+        ctx.shards = shards
+        ctx.compute_params = [p for p in compute_params if p.requires_grad]
+        ctx.save_for_backward(x)
+
+        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+        with torch.no_grad():
+            output_shards = [fn(self, x_shard) for x_shard in x_shards]
+        output_unsharded = torch.cat(output_shards, dim=1)
+
+        return output_unsharded
+
+    @staticmethod
+    def backward(ctx, *grads) -> torch.Tensor:
+        fn = ctx.fn
+        (x,) = ctx.saved_tensors
+        self = ctx.self
+        shards = ctx.shards
+        compute_params = ctx.compute_params
+
+        x_requires_grad = x.requires_grad
+        x = x.detach()
+        x.requires_grad_(x_requires_grad)
+
+        incoming_grad = grads[0]
+        x_grad = torch.zeros_like(x)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+
+        # Create a gradient accumulator for parameters
+        grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
+
+        shard_step = x_shards[0].numel()
+        for i, x_shard in enumerate(x_shards):
+            x_shard.requires_grad_(x_requires_grad)
+
+            shard_offset = i * shard_step
+            x_shard.grad = (
+                x_grad.view(-1)
+                .narrow(0, shard_offset, x_shard.numel())
+                .view_as(x_shard)
+            )
+            incoming_grad_shard = (
+                incoming_grad.view(-1)
+                .narrow(0, shard_offset, x_shard.numel())
+                .view_as(x_shard)
+            )
+
+            # Install hooks for this shard
+            is_last_shard = i + 1 == shards
+            grad_accumulator.install_hooks(is_last_shard)
+
+            with torch.enable_grad():
+                output = fn(self, x_shard)
+            torch.autograd.backward(output, incoming_grad_shard)
+
+        # Clean up hooks
+        grad_accumulator.cleanup()
+        del grad_accumulator
+
+        return (None, None, x_grad, None, None)
+
+
+class GradientAccumulator:
+    """
+    Manual gradient accumulator for TiledMLP with configurable precision
+    Accumulates in specified dtype and rescales the gradient at the end
+    """
+
+    def __init__(
+        self,
+        params: List[torch.nn.Parameter],
+        total_shards: int,
+        dtype: torch.dtype | None = None,
+    ):
+        self.params = params
+        self.total_shards = total_shards
+        self.grad_accumulation_dtype = dtype or torch.float32
+        self.accumulated_grads = {}
+        self.hooks = []
+        self.lock = threading.Lock()
+        self.gradient_scale = 1.0 / total_shards
+
+        # Initialize accumulated gradients in the specified dtype
+        for param in self.params:
+            if param.grad is not None:
+                self.accumulated_grads[param] = param.grad.to(
+                    self.grad_accumulation_dtype
+                )
+                param.grad = None
+            else:
+                self.accumulated_grads[param] = torch.zeros_like(
+                    param, dtype=self.grad_accumulation_dtype
+                )
+
+    def install_hooks(self, is_last_shard: bool):
+        """Install gradient hooks that accumulate gradients in higher precision"""
+
+        def create_hook(param):
+            def hook(grad):
+                with self.lock:
+                    grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype)
+                    scaled_grad = grad_to_accum_dtype * self.gradient_scale
+
+                    if param in self.accumulated_grads:
+                        self.accumulated_grads[param] += scaled_grad
+                    else:
+                        self.accumulated_grads[param] = scaled_grad.clone()
+
+                    # Only assign the averaged gradient on the last shard
+                    if is_last_shard:
+                        param.grad = self.accumulated_grads[param].to(param.dtype)
+                        return param.grad
+                    return None
+
+            return hook
+
+        # Install hooks on all parameters
+        for param in self.params:
+            if param.requires_grad:
+                hook = param.register_hook(create_hook(param))
+                self.hooks.append(hook)
+
+    def cleanup(self):
+        """Remove all installed hooks"""
+        for hook in self.hooks:
+            hook.remove()
+        self.hooks.clear()
+        del self.accumulated_grads
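The backward pass in `base.py` locates each shard inside the flattened gradient buffer with `shard_offset = i * shard_step`. The sketch below checks that index arithmetic in plain Python, assuming a row-major tensor of shape `(batch, seq, hidden)`. It suggests the flat offset matches a slice along the sequence dimension when the batch size is 1, as in the example configs (`micro_batch_size: 1`), and that with a larger batch the same slice is no longer one contiguous run. The helper names are hypothetical.

```python
from typing import List


def flat_indices_of_shard(
    batch: int, seq: int, hidden: int, start: int, length: int
) -> List[int]:
    """Row-major flat indices of x[:, start:start+length, :] for shape (batch, seq, hidden)."""
    return [
        (b * seq + s) * hidden + h
        for b in range(batch)
        for s in range(start, start + length)
        for h in range(hidden)
    ]


def is_contiguous_run(indices: List[int]) -> bool:
    """True if the indices form one unbroken ascending range."""
    return indices == list(range(indices[0], indices[0] + len(indices)))


seq, hidden, shard_len = 8, 3, 2
shard_step = shard_len * hidden  # numel of one shard when batch == 1

# Batch size 1: shard i starts exactly at i * shard_step in the flat buffer.
for i in range(seq // shard_len):
    idx = flat_indices_of_shard(1, seq, hidden, start=i * shard_len, length=shard_len)
    assert idx == list(range(i * shard_step, i * shard_step + shard_step))

# Batch size 2: the same sequence slice is split into two separate runs.
assert not is_contiguous_run(flat_indices_of_shard(2, seq, hidden, start=2, length=2))
```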
@@ -12,8 +12,12 @@ from axolotl.utils.logging import get_logger
 LOG = get_logger(__name__)
 
 
-def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
-    from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
+def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
+    from deepspeed.runtime.sequence_parallel.ulysses_sp import (
+        TiledMLP as DeepSpeedTiledMLP,
+    )
+
+    from axolotl.monkeypatch.tiled_mlp.base import TiledMLP
 
     try:
         # Dynamically import the module and MLP class
@@ -36,6 +40,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
         is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
 
         def tiled_mlp_forward(self, x):
+            # pylint: disable=protected-access
             input_shape = x.shape
             seqlen = input_shape[-2]
             hidden = input_shape[-1]
@@ -48,14 +53,23 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
             else:
                 num_shards = cfg_num_shards
 
-            if not self._compute_params:  # pylint: disable=protected-access
-                self._compute_params = [  # pylint: disable=protected-access
-                    p for p in self.parameters() if p.requires_grad
-                ]
+            if not self._compute_params:
+                self._compute_params = [p for p in self.parameters() if p.requires_grad]
 
-            compute_params = self._compute_params  # pylint: disable=protected-access
+            compute_params = self._compute_params
+            if not self._tiled_mlp_dist_impl:
+                if (
+                    self._compute_params
+                    and any(
+                        hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group")
+                        for p in self._compute_params
+                    )
+                ) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
+                    self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
+                else:
+                    self._tiled_mlp_dist_impl = TiledMLP
 
-            down_res = TiledMLP.apply(
+            down_res = self._tiled_mlp_dist_impl.apply(
                 mlp_forward,
                 self,
                 x,
@@ -66,6 +80,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
 
         mlp_cls.forward = tiled_mlp_forward
         mlp_cls._compute_params = []  # pylint: disable=protected-access
+        mlp_cls._tiled_mlp_dist_impl = None  # pylint: disable=protected-access
         LOG.info(
             f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
             main_process_only=True,
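The patched forward picks between the DeepSpeed implementation and the generic one by inspecting parameter attributes and an environment variable. The selection rule can be isolated as a small pure function, shown here as a sketch. `select_tiled_mlp_impl` and its string return values are hypothetical; the attribute names `ds_id` and `param_idx_in_group` and the `ACCELERATE_USE_DEEPSPEED` variable are the ones the diff checks, and the stand-in parameter objects are plain namespaces rather than real tensors.

```python
import os
from types import SimpleNamespace
from typing import Iterable, Mapping


def select_tiled_mlp_impl(
    params: Iterable[object], environ: Mapping[str, str] = os.environ
) -> str:
    """Return "deepspeed" or "generic" following the rule used in the patch."""
    params = list(params)
    # Parameters carrying either attribute are treated as DeepSpeed-managed.
    has_ds_params = bool(params) and any(
        hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group") for p in params
    )
    uses_ds_env = environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true"
    return "deepspeed" if has_ds_params or uses_ds_env else "generic"


plain_params = [SimpleNamespace()]        # stand-in for DDP / FSDP / single GPU
ds_params = [SimpleNamespace(ds_id=0)]    # stand-in for a DeepSpeed-managed parameter

print(select_tiled_mlp_impl(plain_params, {}))  # generic
print(select_tiled_mlp_impl(ds_params, {}))     # deepspeed
print(select_tiled_mlp_impl(plain_params, {"ACCELERATE_USE_DEEPSPEED": "true"}))  # deepspeed
```

In the patch the result is cached on the module as `_tiled_mlp_dist_impl`, so the check runs only on the first forward call.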
@@ -867,10 +867,16 @@ class GCCallback(TrainerCallback):
         torch.cuda.empty_cache()
         gc.collect()
 
+    def on_train_begin(
+        self, args, state, control, **kwargs  # pylint: disable=unused-argument
+    ):
+        self._gc()
+
     def on_step_begin(
         self, args, state, control, **kwargs  # pylint: disable=unused-argument
     ):
-        if self.next_gc_on_begin_step == state.global_step:
+        # pylint: disable=consider-using-in
+        if self.next_gc_on_begin_step == state.global_step or state.global_step == 0:
             self._gc()
 
     def on_step_end(
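The updated `on_step_begin` condition adds an unconditional collection at step 0 on top of the scheduled one. A minimal sketch of that predicate, separated from the callback so it can be checked in isolation, is shown below. The function name is hypothetical, and using `-1` for "nothing scheduled" is an assumption of this sketch.

```python
def should_gc_on_step_begin(next_gc_on_begin_step: int, global_step: int) -> bool:
    """True when a collection is scheduled for this step or training is at step 0."""
    return global_step == next_gc_on_begin_step or global_step == 0


# Step 0 always collects, even with nothing scheduled.
assert should_gc_on_step_begin(next_gc_on_begin_step=-1, global_step=0)
# A scheduled step collects.
assert should_gc_on_step_begin(next_gc_on_begin_step=5, global_step=5)
# Any other step does not.
assert not should_gc_on_step_begin(next_gc_on_begin_step=5, global_step=3)
```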
@@ -597,7 +597,7 @@ class AxolotlInputConfig(
     )
 
     tiled_mlp_use_original_mlp: bool | None = Field(
-        default=None,
+        default=True,
         json_schema_extra={
             "description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama."
         },
@@ -512,19 +512,6 @@ class TrainingValidationMixin:
 
         return data
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_tiled_mlp_deepspeed(cls, data):
-        capabilities = data.get("capabilities")
-        n_gpu = 0
-        if capabilities and capabilities.get("n_gpu", 0) >= 1:
-            n_gpu = capabilities.get("n_gpu", 0)
-        if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")):
-            raise ValueError(
-                "tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu"
-            )
-        return data
-
 
 class LoRAValidationMixin:
     """Validation methods related to LoRA/QLoRA configuration."""