Compare commits

3 commits: 6dc0f4dac6...shampoo-lo
| Author | SHA1 | Date |
|---|---|---|
|  | f1b4030cdd |  |
|  | 035e9f9dd7 |  |
|  | 02ce520b7e |  |
@@ -562,7 +562,8 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
 liger_swiglu: true
 liger_glu_activation: true
+liger_layer_norm: true
 liger_fused_linear_cross_entropy: true
 ```
@@ -35,3 +35,7 @@ RUN git lfs install --skip-repo && \
     pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10
+
+RUN if [ "$PYTORCH_VERSION" != "2.5.1" ] ; then \
+        pip3 install flash-attn==2.6.3; \
+    fi
@@ -183,8 +183,6 @@ test_datasets:

 # use RL training: 'dpo', 'ipo', 'kto'
 rl:
-# whether to perform weighting if doing DPO training. Boolean.
-dpo_use_weighting:

 # The name of the chat template to use for training, following values are supported:
 # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
@@ -9,7 +9,7 @@ strict: false
 plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
 liger_fused_linear_cross_entropy: true

 chat_template: deepseek_v2
@@ -4,7 +4,7 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
 liger_fused_linear_cross_entropy: true

 strict: false
@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.13.2
-transformers==4.46.0
+transformers==4.46.1
 tokenizers>=0.20.1
 bitsandbytes==0.44.1
-accelerate==1.0.1
+accelerate==1.1.0
 datasets==3.0.1
 deepspeed==0.15.3
 pydantic==2.6.3
@@ -34,7 +34,7 @@ tensorboard
 python-dotenv==1.0.1
 autoawq>=0.2.5
 triton>=2.3.0
-liger-kernel==0.3.0
+liger-kernel==0.4.0

 mamba-ssm==1.2.0.post1
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs

-trl @ git+https://github.com/huggingface/trl.git@5e90682836969310e16ed8aa711dd429f85863b7
+trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
 zstandard==0.22.0
 fastcore
@@ -896,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         for key, value in metrics.items():
             self._stored_metrics[train_eval][key].append(value)

-    def _save_checkpoint(self, model, trial, metrics=None):
+    def _save_checkpoint(self, model, trial, **kwargs):
         # make sure the checkpoint dir exists, since trainer is flakey
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
-        return super()._save_checkpoint(model, trial, metrics=metrics)
+        return super()._save_checkpoint(model, trial, **kwargs)


 class AxolotlMambaTrainer(AxolotlTrainer):
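The `**kwargs` change above is a forward-compatibility fix: `Trainer._save_checkpoint` is a private `transformers` hook whose parameters have shifted between releases (newer versions no longer pass `metrics`), so pinning `metrics=None` in the override breaks when the upstream call site changes. A minimal sketch of the same pattern, assuming only that the hook keeps its `model` and `trial` positional arguments:

```python
# Hedged sketch, not part of the diff: accept and forward whatever keyword
# arguments the installed transformers version passes to _save_checkpoint.
from transformers import Trainer


class ForwardCompatTrainer(Trainer):
    def _save_checkpoint(self, model, trial, **kwargs):
        # custom pre-save work (directory checks, logging, ...) goes here
        return super()._save_checkpoint(model, trial, **kwargs)
```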
@@ -1890,18 +1890,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                 # default to saving each epoch if not defined
                 training_args_kwargs["save_strategy"] = "epoch"

+        training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
         if self.cfg.rl_beta:
             training_args_kwargs["beta"] = self.cfg.rl_beta
         if self.cfg.orpo_alpha:
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha

-            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
-
+        training_args_cls = AxolotlDPOConfig
         if self.cfg.rpo_alpha is not None:
             training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha

-        training_args_cls = None
         if self.cfg.rl == "simpo":
             training_args_cls = AxolotlCPOConfig
             training_args_kwargs["loss_type"] = "simpo"
@@ -1910,13 +1909,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if self.cfg.cpo_alpha is not None:
                 training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha

-        elif self.cfg.rl == "orpo":
+        if self.cfg.rl == "orpo":
             training_args_cls = AxolotlORPOConfig
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

-        elif self.cfg.rl == "kto":
+        if self.cfg.rl == "kto":
             training_args_cls = AxolotlKTOConfig

             training_args_kwargs["desirable_weight"] = (
@@ -1926,32 +1925,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                 self.cfg.kto_undesirable_weight or 1.0
             )

-            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

-        else:
-            training_args_cls = AxolotlDPOConfig
-
-            training_args_kwargs["max_length"] = self.cfg.sequence_len
-            training_args_kwargs["max_target_length"] = None
-            if self.cfg.max_prompt_len is not None:
-                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
-
-            if self.cfg.dpo_use_weighting is not None:
-                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
-
-            if self.cfg.rl == "ipo":
-                training_args_kwargs["loss_type"] = "ipo"
-            if self.cfg.dpo_label_smoothing:
-                training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
-
-            if self.cfg.precompute_ref_log_probs is not None:
-                training_args_kwargs["precompute_ref_log_probs"] = self.cfg.precompute_ref_log_probs
-
-            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb

         training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
             output_dir=self.cfg.output_dir,
             per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1971,16 +1949,27 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
     def build(self, total_num_steps):
         training_args = self.build_training_arguments(total_num_steps)
         dpo_trainer_kwargs = {}
+        if self.cfg.rl == "ipo":
+            dpo_trainer_kwargs["loss_type"] = "ipo"
+            if self.cfg.dpo_label_smoothing:
+                dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         if self.eval_dataset:
             dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
             dpo_trainer_kwargs["peft_config"] = self.peft_config
+        if self.cfg.precompute_ref_log_probs is not None:
+            dpo_trainer_kwargs[
+                "precompute_ref_log_probs"
+            ] = self.cfg.precompute_ref_log_probs
         if self.cfg.rl in ["dpo", "ipo"]:
             trainer_cls = AxolotlDPOTrainer
             trainer_cls_args = [self.model, self.model_ref]
+
+            # these aren't used for the ORPO trainer
+            dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["max_target_length"] = None
+            dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]
@@ -18,20 +18,23 @@ Module for the Plugin for LIGER integraton with Axolotl.
 Liger Kernel is the collection of Triton-native kernels for LLM Training.
 It is designed to be performant, correct, and light-weight.
 """
+import inspect
 import logging
+import sys
-from functools import partial

 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
-from liger_kernel.transformers.geglu import LigerGEGLUMLP
+from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

 from axolotl.integrations.base import BasePlugin

+from ...utils.distributed import zero_only
 from .args import LigerArgs  # pylint: disable=unused-import. # noqa: F401

+LOG = logging.getLogger("axolotl.integrations.liger")


 class LigerPlugin(BasePlugin):
     """
@@ -42,59 +45,31 @@ class LigerPlugin(BasePlugin):
         return "axolotl.integrations.liger.LigerArgs"

     def pre_model_load(self, cfg):
-        if cfg.model_config_type == "llama":
-            from liger_kernel.transformers.model.llama import (
-                lce_forward as llama_lce_forward,
-            )
-            from transformers.models.llama import modeling_llama
-
-            if cfg.liger_rope:
-                modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_llama.LlamaRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_llama.LlamaMLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
-            elif cfg.liger_fused_linear_cross_entropy:
-                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
-
-        elif cfg.model_config_type == "mistral":
-            from liger_kernel.transformers.model.mistral import (
-                lce_forward as mistral_lce_forward,
-            )
-            from transformers.models.mistral import modeling_mistral
-
-            if cfg.liger_rope:
-                modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_mistral.MistralRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_mistral.MistralMLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
-
-        elif cfg.model_config_type == "gemma":
-            from liger_kernel.transformers.model.gemma import (
-                lce_forward as gemma_lce_forward,
-            )
-            from transformers.models.gemma import modeling_gemma
-
-            if cfg.liger_rope:
-                modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_gemma.GemmaRMSNorm = partial(
-                    LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-                )
-            if cfg.liger_swiglu:
-                modeling_gemma.GemmaMLP = LigerGEGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+        if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
+            liger_fn_sig = inspect.signature(apply_liger_fn)
+            kwargs = {}
+            if "rope" in liger_fn_sig.parameters:
+                kwargs["rope"] = cfg.liger_rope
+            if "cross_entropy" in liger_fn_sig.parameters:
+                kwargs["cross_entropy"] = cfg.liger_cross_entropy
+            if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
+                kwargs[
+                    "fused_linear_cross_entropy"
+                ] = cfg.liger_fused_linear_cross_entropy
+            if "rms_norm" in liger_fn_sig.parameters:
+                kwargs["rms_norm"] = cfg.liger_rms_norm
+            if "layer_norm" in liger_fn_sig.parameters:
+                kwargs["layer_norm"] = cfg.liger_layer_norm
+            if "geglu" in liger_fn_sig.parameters:
+                kwargs["geglu"] = cfg.liger_glu_activation
+            elif "swiglu" in liger_fn_sig.parameters:
+                kwargs["swiglu"] = cfg.liger_glu_activation
+            with zero_only():
+                LOG.info(
+                    f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
+                )

+            apply_liger_fn(**kwargs)
         elif cfg.model_config_type == "jamba":
             from transformers.models.jamba import modeling_jamba
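The refactor above drops the hand-rolled per-architecture patching in favor of liger-kernel's own registry: `MODEL_TYPE_TO_APPLY_LIGER_FN` maps a model type to its `apply_liger_kernel_to_*` function, and `inspect.signature` forwards only the options that the particular patch function accepts (for example `geglu` for Gemma-style MLPs versus `swiglu` for Llama-style ones). A self-contained sketch of that dispatch pattern, with made-up patch functions standing in for the liger-kernel entries:

```python
# Illustrative only: apply_to_llama/apply_to_gemma and PATCH_FNS are
# placeholders for liger-kernel's MODEL_TYPE_TO_APPLY_LIGER_FN registry.
import inspect


def apply_to_llama(rope=False, rms_norm=False, swiglu=False):
    print(f"llama patched: rope={rope} rms_norm={rms_norm} swiglu={swiglu}")


def apply_to_gemma(rope=False, rms_norm=False, geglu=False):
    print(f"gemma patched: rope={rope} rms_norm={rms_norm} geglu={geglu}")


PATCH_FNS = {"llama": apply_to_llama, "gemma": apply_to_gemma}
requested = {"rope": True, "rms_norm": True, "swiglu": True, "geglu": True}

for model_type, patch_fn in PATCH_FNS.items():
    accepted = inspect.signature(patch_fn).parameters
    kwargs = {k: v for k, v in requested.items() if k in accepted}
    patch_fn(**kwargs)  # each model type only receives options it understands
```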
@@ -104,30 +79,12 @@ class LigerPlugin(BasePlugin):
                 modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
             if cfg.liger_rms_norm:
                 modeling_jamba.JambaRMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
+            if cfg.liger_glu_activation:
                 modeling_jamba.JambaMLP = LigerSwiGLUMLP
             if cfg.liger_cross_entropy:
                 modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
-
-        elif cfg.model_config_type == "qwen2":
-            from liger_kernel.transformers.model.qwen2 import (
-                lce_forward as qwen2_lce_forward,
-            )
-            from transformers.models.qwen2 import modeling_qwen2
-
-            if cfg.liger_rope:
-                modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward

         elif cfg.model_config_type == "deepseek_v2":
             from accelerate import init_empty_weights
             from transformers import AutoModelForCausalLM
@@ -146,44 +103,9 @@ class LigerPlugin(BasePlugin):
                 logging.warning("Fused liger_rope is not supported for DeepseekV2.")
             if cfg.liger_rms_norm:
                 modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
+            if cfg.liger_glu_activation:
                 modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
             if cfg.liger_cross_entropy:
                 modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
-
-        elif cfg.model_config_type == "gemma2":
-            from transformers.models.gemma2 import modeling_gemma2
-
-            if cfg.liger_rope:
-                modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_gemma2.Gemma2RMSNorm = partial(
-                    LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-                )
-            if cfg.liger_swiglu:
-                modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                logging.warning(
-                    "Fused linear cross entropy is not supported for Gemma 2."
-                )
-
-        elif cfg.model_config_type == "phi3":
-            from liger_kernel.transformers.model.phi3 import (
-                lce_forward as phi3_lce_forward,
-            )
-            from transformers.models.phi3 import modeling_phi3
-
-            if cfg.liger_rope:
-                modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-                modeling_phi3.Phi3RMSNorm = LigerRMSNorm
-            if cfg.liger_swiglu:
-                modeling_phi3.Phi3MLP = LigerSwiGLUMLP
-            if cfg.liger_cross_entropy:
-                modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
-            if cfg.liger_fused_linear_cross_entropy:
-                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
@@ -15,9 +15,12 @@
 """
 Module for handling LIGER input arguments.
 """
+import logging
 from typing import Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator

+LOG = logging.getLogger("axolotl.integrations.liger.args")


 class LigerArgs(BaseModel):
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):

     liger_rope: Optional[bool] = None
     liger_rms_norm: Optional[bool] = None
+    liger_layer_norm: Optional[bool] = None
     liger_swiglu: Optional[bool] = None
+    liger_glu_activation: Optional[bool] = None
     liger_cross_entropy: Optional[bool] = None
     liger_fused_linear_cross_entropy: Optional[bool] = None

+    @model_validator(mode="before")
+    @classmethod
+    def check_deprecated_swiglu(cls, data):
+        if data.get("liger_swiglu") is not None:
+            if data.get("liger_glu_activation") is not None:
+                raise ValueError(
+                    "You cannot have both `liger_swiglu` and `liger_glu_activation` set."
+                )
+
+            LOG.warning(
+                "The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
+                "Please use 'liger_glu_activation' instead."
+            )
+            data["liger_glu_activation"] = data.pop("liger_swiglu")
+        return data
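For illustration, here is the deprecation shim above stripped down to a runnable, standalone form (the class name and message text are placeholders): a pydantic `mode="before"` validator sees the raw input mapping before field validation, which is what lets it rename the deprecated key in place.

```python
# Standalone sketch of the liger_swiglu -> liger_glu_activation remapping.
import logging
from typing import Optional

from pydantic import BaseModel, model_validator

LOG = logging.getLogger(__name__)


class LigerArgsSketch(BaseModel):
    liger_swiglu: Optional[bool] = None
    liger_glu_activation: Optional[bool] = None

    @model_validator(mode="before")
    @classmethod
    def check_deprecated_swiglu(cls, data):
        if data.get("liger_swiglu") is not None:
            if data.get("liger_glu_activation") is not None:
                raise ValueError("cannot set both options")
            LOG.warning("liger_swiglu is deprecated; use liger_glu_activation")
            data["liger_glu_activation"] = data.pop("liger_swiglu")
        return data


print(LigerArgsSketch(liger_swiglu=True))
# liger_swiglu=None liger_glu_activation=True
```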
@@ -588,9 +588,6 @@ class AxolotlInputConfig(

     rl: Optional[RLType] = None
     reward_model: Optional[bool] = None
-    dpo_use_weighting: Optional[
-        bool
-    ] = None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.

     datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
     test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
src/axolotl/utils/optimizers/__init__.py (new file, 0 lines)
src/axolotl/utils/optimizers/shampoo.py (new file, 250 lines)
@@ -0,0 +1,250 @@
from typing import Optional

import torch
from torch import Tensor
from torch.distributed._tensor import DTensor
from torch.optim import Optimizer
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8


class _ShampooBase(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size,
        quantization_bits,
        optimizer_state_class,
    ):
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if eps < 0.0:
            raise ValueError(f"Invalid eps value: {eps}")
        if update_freq < 1:
            raise ValueError(f"Invalid update_freq value: {update_freq}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            eps=eps,
            update_freq=update_freq,
        )
        super().__init__(params, defaults)
        self.block_size = block_size
        self.quantization_bits = quantization_bits
        self.optimizer_state_class = optimizer_state_class

    def step(self, closure: Optional[callable] = None) -> Optional[float]:
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["momentum_buffer"] = self._new_buffer(grad, True)
                    state["preconds"] = []
                    state["inv_preconds"] = []
                    for dim in grad.size():
                        state["preconds"].append(
                            self.optimizer_state_class.zeros(
                                (dim, dim),
                                signed=False,
                                block_size=self.block_size,
                                device=grad.device,
                            )
                        )
                        state["inv_preconds"].append(
                            torch.zeros((dim, dim), device=grad.device)
                        )

                state["step"] += 1
                beta = group["momentum"]
                weight_decay = group["weight_decay"]
                lr = group["lr"]
                eps = group["eps"]
                update_freq = group["update_freq"]

                # Apply momentum
                if beta > 0:
                    state["momentum_buffer"].mul_(beta).add_(grad, alpha=1 - beta)
                    grad = state["momentum_buffer"]

                # Apply weight decay
                if weight_decay > 0:
                    grad = grad.add(p.data, alpha=weight_decay)

                # Preconditioning
                order = grad.ndimension()
                original_size = grad.size()
                for dim_id, dim in enumerate(grad.size()):
                    precond = state["preconds"][dim_id]
                    inv_precond = state["inv_preconds"][dim_id]

                    # Reshape grad so the current dimension leads
                    grad = grad.transpose(0, dim_id).contiguous()
                    transposed_size = grad.size()
                    grad = grad.view(dim, -1)

                    grad_t = grad.t()

                    # Update preconditioner
                    precond_fp32 = precond.dequantize()
                    precond_update = grad @ grad_t
                    precond_fp32.add_(precond_update)

                    # Quantize preconditioner back
                    precond.copy_(precond_fp32)

                    # Update inverse preconditioner
                    if state["step"] % update_freq == 0:
                        inv_precond.copy_(
                            self._compute_inv_precond(precond_fp32, eps, order)
                        )

                    # Precondition grad
                    if dim_id == order - 1:
                        # Last dimension
                        grad = grad_t @ inv_precond
                        grad = grad.view(original_size)
                    else:
                        grad = inv_precond @ grad
                        grad = grad.view(transposed_size)

                # Update parameter
                p.data.add_(grad, alpha=-lr)

        return loss

    def _compute_inv_precond(self, precond: Tensor, eps: float, order: int):
        # Add eps for numerical stability
        precond = precond + torch.eye(precond.size(0), device=precond.device) * eps

        # Compute matrix power
        inv_precond = self._matrix_power(precond, -1.0 / (2 * order))

        return inv_precond

    def _matrix_power(self, matrix: Tensor, power: float) -> Tensor:
        # Compute matrix power using SVD
        u, s, v = torch.svd(matrix)
        s_pow = s.pow(power)
        return u @ torch.diag(s_pow) @ v.t()

    # create a zero-filled quantized subclass via the configured state class
    def _subclass_zeros(self, p: Tensor, signed: bool, block_size: int):
        return self.optimizer_state_class.zeros(
            p.shape, signed=signed, block_size=block_size, device=p.device
        )

    # follow bitsandbytes, only quantize tensors >= 4096 values
    # also wrap subclass in DTensor when needed
    def _new_buffer(self, p: Tensor, signed: bool):
        if p.numel() >= 4096 and p.numel() % self.block_size == 0:
            if isinstance(p, DTensor):
                out = DTensor.from_local(
                    local_tensor=self._subclass_zeros(
                        p.to_local(), signed, self.block_size
                    ),
                    device_mesh=p.device_mesh,
                    placements=p.placements,
                    run_check=False,
                )
            else:
                out = self._subclass_zeros(p, signed, self.block_size)
        else:
            out = torch.zeros_like(p)
        return out


class Shampoo8bit(_ShampooBase):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size=256,
    ):
        super().__init__(
            params,
            lr,
            momentum,
            weight_decay,
            eps,
            update_freq,
            block_size=block_size,
            quantization_bits=8,
            optimizer_state_class=OptimState8bit,
        )


class Shampoo4bit(_ShampooBase):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size=128,
    ):
        super().__init__(
            params,
            lr,
            momentum,
            weight_decay,
            eps,
            update_freq,
            block_size=block_size,
            quantization_bits=4,
            optimizer_state_class=OptimState4bit,
        )


class ShampooFp8(_ShampooBase):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size=256,
    ):
        super().__init__(
            params,
            lr,
            momentum,
            weight_decay,
            eps,
            update_freq,
            block_size=block_size,
            quantization_bits=8,  # FP8 uses 8 bits
            optimizer_state_class=OptimStateFp8,
        )
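A hedged smoke-test for the new optimizer (the import path is the one this compare adds; the model, shapes, and hyperparameters are illustrative, and torchao must be installed). Per dimension of each gradient, the optimizer accumulates a `(dim, dim)` preconditioner `G += grad @ grad.T` in a quantized buffer, and every `update_freq` steps recomputes the inverse root `G^(-1/(2*order))` via SVD before applying it to the gradient:

```python
# Illustrative usage sketch; momentum is left at its 0.0 default.
import torch

from axolotl.utils.optimizers.shampoo import Shampoo8bit

model = torch.nn.Linear(512, 512)
opt = Shampoo8bit(model.parameters(), lr=1e-2, update_freq=1, block_size=256)

for _ in range(3):
    loss = model(torch.randn(8, 512)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
```

Note that the quantized state subclasses are only used for tensors with at least 4096 elements whose size divides evenly by `block_size`; smaller states (the bias here) fall back to plain `torch.zeros_like` buffers.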
test.yml (deleted file, 59 lines)
@@ -1,59 +0,0 @@
base_model: JackFram/llama-68m

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: arcee-ai/distilabel-intel-orca-dpo-pairs-binarized
    type: chatml.ultra
    split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/out

sequence_len: 2048
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true

rl: dpo
dpo_use_weighting: true

warmup_steps: 10
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
test2.yml (deleted file, 43 lines)
@@ -1,43 +0,0 @@
base_model: JackFram/llama-68m

load_in_8bit: true

datasets:
  - path: arcee-ai/distilabel-intel-orca-dpo-pairs-binarized
    type: chatml.ultra
    split: train
output_dir: ./outputs/lora-out

sequence_len: 1024

adapter: lora
lora_r: 64
lora_alpha: 32
lora_dropout: 0.1
lora_target_linear: true

rl: dpo
dpo_use_weighting: true

wandb_project: check_dpotrainer
wandb_entity: axolotl-ai
wandb_watch:
wandb_name: baseline/dpo_base/dpo_use_weighting
wandb_log_model:


num_epochs: 1
micro_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 0.00001
optimizer: paged_adamw_8bit
lr_scheduler: cosine
max_steps: 20
save_steps: 10
warmup_steps: 5
gradient_checkpointing: True
gradient_checkpointing_kwargs:
  use_reentrant: false
#special_tokens:
#  pad_token: <|end_of_text|>
@@ -1,7 +1,6 @@
 """
 Simple end-to-end test for Liger integration
 """

 import unittest
 from pathlib import Path
@@ -115,51 +115,6 @@ class TestDPOLlamaLora(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

-    @with_temp_dir
-    def test_dpo_use_weighting(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
-                "sequence_len": 1024,
-                "load_in_8bit": True,
-                "adapter": "lora",
-                "lora_r": 64,
-                "lora_alpha": 32,
-                "lora_dropout": 0.1,
-                "lora_target_linear": True,
-                "special_tokens": {},
-                "rl": "dpo",
-                "dpo_use_weighting": True,
-                "datasets": [
-                    {
-                        "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
-                        "type": "chatml.ultra",
-                        "split": "train",
-                    },
-                ],
-                "num_epochs": 1,
-                "micro_batch_size": 4,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "paged_adamw_8bit",
-                "lr_scheduler": "cosine",
-                "max_steps": 20,
-                "save_steps": 10,
-                "warmup_steps": 5,
-                "gradient_checkpointing": True,
-                "gradient_checkpointing_kwargs": {"use_reentrant": True},
-            }
-        )
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

     @pytest.mark.skip("kto_pair no longer supported in trl")
     @with_temp_dir
     def test_kto_pair_lora(self, temp_dir):
tests/integrations/__init__.py (new file, 0 lines)
tests/integrations/liger.py (new file, 80 lines)
@@ -0,0 +1,80 @@
"""
config validation tests for swiglu args
"""
# pylint: disable=duplicate-code
import logging
from typing import Optional

import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
    return DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
            "learning_rate": 0.000001,
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                }
            ],
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
        }
    )


class BaseValidation:
    """
    Base validation module to setup the log capture
    """

    _caplog: Optional[pytest.LogCaptureFixture] = None

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        self._caplog = caplog


# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
    """
    Test the validation module for liger
    """

    def test_deprecated_swiglu(self, minimal_cfg):
        test_cfg = DictDefault(
            {
                "liger_swiglu": False,
            }
            | minimal_cfg
        )

        with self._caplog.at_level(logging.WARNING):
            updated_cfg = validate_config(test_cfg)
            assert (
                "The 'liger_swiglu' argument is deprecated"
                in self._caplog.records[0].message
            )
            assert updated_cfg.liger_swiglu is None
            assert updated_cfg.liger_glu_activation is False

    def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
        test_cfg = DictDefault(
            {
                "liger_swiglu": False,
                "liger_glu_activation": True,
            }
            | minimal_cfg
        )

        with pytest.raises(
            ValueError,
            match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
        ):
            validate_config(test_cfg)
@@ -306,6 +306,10 @@ class TestDatasetPreparation(unittest.TestCase):
         """Verify that processing data from the hub works with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             prepared_path = Path(tmp_dir) / "prepared"
+
+            # make sure prepared_path is empty
+            shutil.rmtree(prepared_path, ignore_errors=True)

             cfg = DictDefault(
                 {
                     "tokenizer_config": "huggyllama/llama-7b",