Compare commits
fa-261...dpo-spawn-
5 commits
| Author | SHA1 | Date |
|---|---|---|
| | e86dd76154 | |
| | 5f58555bd0 | |
| | cfc533a7f7 | |
| | e1725aef2b | |
| | 78e12f8ca5 | |
.github/workflows/tests.yml (6 changes, vendored)
@@ -57,6 +57,10 @@ jobs:
       run: |
         pytest --ignore=tests/e2e/ tests/
+
+    - name: cleanup pip cache
+      run: |
+        find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;

   docker-e2e-tests:
     if: github.repository_owner == 'axolotl-ai-cloud'
     # this job needs to be run on self-hosted GPU runners...
@@ -99,7 +103,7 @@ jobs:
     - name: Install Modal
       run: |
         python -m pip install --upgrade pip
-        pip install modal jinja2
+        pip install modal==0.63.64 jinja2
     - name: Update env vars
       run: |
         echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
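Two workflow changes here: a new step prunes pip's HTTP cache (the `http-v2` directory used by recent pip releases) of files untouched for more than 14 days, and `modal` is pinned to 0.63.64, presumably to shield the self-hosted e2e job from breaking upstream releases. For illustration only, a rough Python equivalent of the `find` one-liner (the workflow itself keeps the shell version; `-mtime +14` semantics are approximated as a 14-day cutoff):

```python
# Illustrative sketch: approximate Python equivalent of
#   find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
import subprocess
import time
from pathlib import Path

# Resolve the cache root the same way the workflow does: `pip cache dir`.
cache_dir = subprocess.run(
    ["pip", "cache", "dir"], capture_output=True, text=True, check=True
).stdout.strip()
cutoff = time.time() - 14 * 24 * 3600  # roughly -mtime +14

for path in Path(cache_dir, "http-v2").rglob("*"):
    if path.is_file() and path.stat().st_mtime < cutoff:
        path.unlink()
```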
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
     fi

 # So we can test the Docker image
@@ -2,5 +2,5 @@
 set -e

-pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
-pytest /workspace/axolotl/tests/e2e/patched/
+pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
+pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
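This appears to be the in-container test script: the plain unit-test run is dropped (the workflow above already runs `pytest --ignore=tests/e2e/ tests/` in CI), the patched e2e tests now run under pytest-xdist (`-n1` uses a single worker subprocess, `--dist loadfile` keeps all tests from one file on the same worker), and the remaining e2e tests get their own pass. A hedged sketch of the same xdist invocation driven from Python:

```python
# Hedged sketch: the new patched-tests invocation from Python. The -n and
# --dist options come from the pytest-xdist plugin, which must be installed.
import pytest

exit_code = pytest.main(
    ["-n1", "--dist", "loadfile", "-v", "/workspace/axolotl/tests/e2e/patched/"]
)
raise SystemExit(exit_code)
```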
@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
     fi

 # So we can test the Docker image
setup.py (6 changes)

@@ -104,5 +104,11 @@ setup(
         "galore": [
             "galore_torch",
         ],
+        "optimizers": [
+            "galore_torch",
+            "lion-pytorch==0.1.2",
+            "lomo-optim==0.1.1",
+            "torch-optimi==0.2.1",
+        ],
     },
 )
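The existing `galore` extra stays, and a broader `optimizers` extra is added that bundles GaLore with pinned Lion, LOMO, and torch-optimi releases; the Dockerfile hunks above switch their installs from `galore` to `optimizers` to match. A hedged smoke check that the bundled packages resolved after installing the extra (distribution names are assumed to normalize as shown, e.g. `galore_torch` publishing as `galore-torch`):

```python
# Hedged smoke check for the new "optimizers" extra: print the installed
# version of each bundled distribution. Adjust the names if any of them
# resolve differently on your index.
from importlib.metadata import version

for dist in ("galore-torch", "lion-pytorch", "lomo-optim", "torch-optimi"):
    print(dist, version(dist))
```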
@@ -13,6 +13,7 @@ from abc import abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
+from multiprocessing import set_start_method
 from pathlib import Path
 from typing import Dict, List, Literal, Optional, Type, Union

@@ -226,6 +227,12 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "whether to use sequential sampling for curriculum learning"},
     )
+    alternate_optimizer: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate optimizer to the HF trainer"
+        },
+    )


 @dataclass
@@ -284,26 +291,72 @@ class AxolotlTrainer(Trainer):
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

     def _wrap_model(self, model, training=True, dataloader=None):
         if self.args.torch_compile:
             torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
                 256
             )
             model = torch.compile(
                 model,
                 backend=self.args.torch_compile_backend,
                 mode=self.args.torch_compile_mode,
             )
         return super()._wrap_model(model, training=training, dataloader=dataloader)

     def create_optimizer(self):
-        if self.args.loraplus_lr_ratio is None:
+        if (
+            self.args.loraplus_lr_ratio is None
+            and self.args.alternate_optimizer != "optimi_adamw"
+        ):
             return super().create_optimizer()

         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.optimizer is None:  # pylint: disable=access-member-before-definition
             decay_parameters = self.get_decay_parameter_names(opt_model)
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n not in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
+
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args,
                 opt_model,
             )

-            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
-            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
-            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
-                opt_model,
-                optimizer_cls,
-                optimizer_kwargs,
-                loraplus_lr_ratio,
-                loraplus_lr_embedding,
-            )
+            if self.args.loraplus_lr_ratio is not None:
+                loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+                loraplus_lr_embedding = getattr(
+                    self.args, "loraplus_lr_embedding", None
+                )
+                self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                    opt_model,
+                    optimizer_cls,
+                    optimizer_kwargs,
+                    loraplus_lr_ratio,
+                    loraplus_lr_embedding,
+                )
+            elif self.args.alternate_optimizer == "optimi_adamw":
+                from optimi import AdamW
+
+                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                    AdamW(
+                        optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
+                    )
+                )

         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
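The new `optimizer_grouped_parameters` mirrors the HF Trainer default: weight decay is applied only to parameters named by `get_decay_parameter_names` (biases and norm weights are excluded), and frozen parameters are skipped entirely. A minimal standalone sketch of that split, with `decay_parameters` standing in for the trainer helper's output:

```python
# Minimal standalone sketch of the decay/no-decay grouping used in
# create_optimizer above. `decay_parameters` is a list of parameter names,
# normally produced by Trainer.get_decay_parameter_names(model).
import torch

def group_parameters(model: torch.nn.Module, decay_parameters, weight_decay):
    return [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if n in decay_parameters and p.requires_grad
            ],
            "weight_decay": weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if n not in decay_parameters and p.requires_grad
            ],
            "weight_decay": 0.0,
        },
    ]
```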
@@ -1396,6 +1449,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

         trainer_kwargs = {}

+        if self.cfg.optimizer == "optimi_adamw":
+            # Set default so transformers doesn't throw
+            training_arguments_kwargs["optim"] = "adamw_hf"
+            training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
+
         if self.cfg.optimizer == "lion_pytorch":
             from lion_pytorch import Lion
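Since transformers' `OptimizerNames` has no `optimi_adamw` member, the builder sets a valid placeholder (`adamw_hf`) on the `optim` training argument and routes the real choice through the new `alternate_optimizer` field, which `create_optimizer` above intercepts. A hedged sketch of the optimizer that ends up constructed (torch-optimi's `AdamW`; `foreach=False` follows the diff, while the learning rate and model here are arbitrary example values, not something the diff specifies):

```python
# Hedged sketch: building torch-optimi's AdamW directly, as create_optimizer
# does when alternate_optimizer == "optimi_adamw". foreach=False matches the
# diff; the tiny model and lr are placeholder examples.
import torch
from optimi import AdamW

model = torch.nn.Linear(8, 8)
optimizer = AdamW(model.parameters(), lr=1e-5, foreach=False)
```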
@@ -1713,6 +1771,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
             dpo_trainer.add_callback(callback)

+        # prevents multiprocessing issues for datasets on multiple GPUs
+        set_start_method("spawn")
+
         return dpo_trainer
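Switching the start method to "spawn" avoids fork-related problems when `datasets` worker processes touch CUDA on multi-GPU DPO runs. One caveat worth knowing: `set_start_method` raises `RuntimeError` if the start method was already fixed for the process; the diff calls it unguarded, so a defensive variant (my suggestion, not what the diff does) would look like:

```python
# Hedged variant: set_start_method raises RuntimeError when the start method
# has already been set in this process, so a guarded call looks like this.
from multiprocessing import set_start_method

try:
    set_start_method("spawn")
except RuntimeError:
    pass  # start method already chosen; leave it in effect
```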
@@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model):
         set_module_name(model, name, qkv)


+def patch_llama_cross_entropy():
+    from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+    LOG.info("patching with flash_attn.losses.cross_entropy")
+    transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+        CrossEntropyLoss, inplace_backward=True
+    )
+
+
+def patch_llama_rms_norm():
+    try:
+        from flash_attn.ops.rms_norm import RMSNorm
+
+        class LlamaRMSNorm(RMSNorm):
+            """Patched LLamaRMSNorm"""
+
+            def __init__(self, hidden_size, eps=1e-6):
+                super().__init__(hidden_size, eps=eps)
+
+        LOG.info("patching with flash_attn.ops.rms_norm")
+        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+    except ImportError:
+        LOG.warning(
+            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+        )
+
+
 def replace_llama_attn_with_flash_attn(
     packed: Optional[bool] = False,
     cross_entropy: Optional[bool] = False,
@@ -104,30 +131,11 @@ def replace_llama_attn_with_flash_attn(

     # skip only if explicitly disabled
     if cross_entropy:
-        from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-        LOG.info("patching with flash_attn.losses.cross_entropy")
-        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-            CrossEntropyLoss, inplace_backward=True
-        )
+        patch_llama_cross_entropy()

     # skip only if explicitly disabled
     if rms_norm:
-        try:
-            from flash_attn.ops.rms_norm import RMSNorm
-
-            class LlamaRMSNorm(RMSNorm):
-                """Patched LLamaRMSNorm"""
-
-                def __init__(self, hidden_size, eps=1e-6):
-                    super().__init__(hidden_size, eps=eps)
-
-            LOG.info("patching with flash_attn.ops.rms_norm")
-            transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
-        except ImportError:
-            LOG.warning(
-                "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
-            )
+        patch_llama_rms_norm()


 class FusedAttention(LlamaAttention):
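The refactor lifts the two patch bodies out of `replace_llama_attn_with_flash_attn` into standalone `patch_llama_cross_entropy` and `patch_llama_rms_norm`, so they can be applied without the full attention hijack. Usage then looks like the `load_model` hunk later in this diff (assuming flash-attn is installed):

```python
# Usage sketch, mirroring the load_model hunk further down: the extracted
# helpers can now be applied independently of the flash-attention hijack.
from axolotl.monkeypatch.llama_attn_hijack_flash import (
    patch_llama_cross_entropy,
    patch_llama_rms_norm,
)

patch_llama_cross_entropy()  # swap in flash-attn's fused CrossEntropyLoss
patch_llama_rms_norm()  # swap in flash-attn's fused RMSNorm, warns if missing
```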
@@ -10,6 +10,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
 from axolotl.monkeypatch.utils import get_unpad_data

 SUPPORTED_MULTIPACK_MODEL_TYPES = [
+    "llama",
     "mixtral",
     "qwen2",
     "qwen2_moe",
@@ -30,6 +31,10 @@ def patch_for_multipack(model_type, model_name=None):
         )
         if is_deepspeed_zero3_enabled():
             patch_mixtral_moe_forward_zero3()
+    elif model_type == "llama":
+        transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
     elif model_type == "qwen2":
         transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
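Adding `"llama"` to `SUPPORTED_MULTIPACK_MODEL_TYPES` plus the matching `elif` branch routes llama models through the same `_get_unpad_data` override already used for qwen2 and friends. Stripped of the dispatch, the patch is a single module-attribute rebind, exactly as in the hunk above:

```python
# The core of the multipack patch as applied for llama: rebind the private
# _get_unpad_data helper inside transformers so packed batches are unpadded
# with axolotl's implementation.
import transformers

from axolotl.monkeypatch.utils import get_unpad_data

transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
    get_unpad_data
)
```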
@@ -52,6 +52,13 @@ class TrainDatasetMeta:
 def train(
     *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
 ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
+    # enable expandable segments for cuda allocation to improve VRAM usage
+    # torch_version = torch.__version__.split(".")
+    # torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
+    # if torch_major == 2 and torch_minor >= 2:
+    #     if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
+    #         os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
     # load the tokenizer first
     LOG.debug(
         f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
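This block lands commented out, so it changes nothing yet; if enabled it would opt torch >= 2.2 into the expandable-segments CUDA allocator while respecting any `PYTORCH_CUDA_ALLOC_CONF` the user already exported. Uncommented for readability, it reads:

```python
# The commented-out block from the hunk above, uncommented: on torch >= 2.2,
# enable expandable CUDA allocator segments to reduce fragmentation, but never
# override a PYTORCH_CUDA_ALLOC_CONF the user already set.
import os

import torch

torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
    if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
```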
@@ -341,7 +341,7 @@ class HyperparametersConfig(BaseModel):
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = 0.0
     optimizer: Optional[
-        Union[OptimizerNames, Literal["lion_pytorch"]]
+        Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
     ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}
@@ -1112,6 +1112,31 @@ class AxolotlInputConfig(
             raise ValueError("either datasets or pretraining_dataset is required")
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_xentropy_patch_conflicts(cls, data):
+        if data.get("flash_attn_cross_entropy") and data.get(
+            "unsloth_cross_entropy_loss"
+        ):
+            raise ValueError(
+                "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
+            )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_qlora_unsloth(cls, data):
+        if (
+            data.get("unsloth_lora_mlp")
+            or data.get("unsloth_lora_qkv")
+            or data.get("unsloth_lora_o")
+        ):
+            if data.get("adapter") == "lora" or data.get("load_in_8bit"):
+                raise ValueError(
+                    "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
+                )
+        return data
+

 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""
@@ -1163,3 +1188,18 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         if data.get("deepspeed") and data.get("fsdp"):
             raise ValueError("deepspeed and fsdp cannot be used together.")
         return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_multigpu_unsloth(cls, data):
+        if (
+            data.get("unsloth_lora_mlp")
+            or data.get("unsloth_lora_qkv")
+            or data.get("unsloth_lora_o")
+        ):
+            capabilities = data.get("capabilities")
+            if capabilities and capabilities.get("num_gpus") > 1:
+                raise ValueError(
+                    "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
+                )
+        return data
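All three new checks use the same pydantic v2 idiom: a `mode="before"` model validator inspects the raw input mapping and raises on incompatible flag combinations before field parsing. A minimal self-contained sketch of the pattern (the class name here is illustrative):

```python
# Minimal self-contained sketch of the validator pattern added above
# (pydantic v2): reject mutually exclusive flags before field parsing.
from pydantic import BaseModel, model_validator


class ExampleConfig(BaseModel):
    flash_attn_cross_entropy: bool = False
    unsloth_cross_entropy_loss: bool = False

    @model_validator(mode="before")
    @classmethod
    def check_xentropy_patch_conflicts(cls, data):
        if data.get("flash_attn_cross_entropy") and data.get(
            "unsloth_cross_entropy_loss"
        ):
            raise ValueError(
                "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
            )
        return data


# ExampleConfig(flash_attn_cross_entropy=True, unsloth_cross_entropy_loss=True)
# raises a ValidationError wrapping the ValueError above.
```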
@@ -347,6 +347,27 @@ def load_model(
         and cfg.sample_packing
     ):
         patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
+
+        if cfg.is_llama_derived_model:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_llama_cross_entropy,
+                patch_llama_rms_norm,
+            )
+
+            if cfg.flash_attn_cross_entropy:
+                patch_llama_cross_entropy()
+            if cfg.flash_attn_rms_norm:
+                patch_llama_rms_norm()
+            if cfg.unsloth_cross_entropy_loss:
+                from axolotl.monkeypatch.unsloth_ import (
+                    integrate_cross_entropy_loss_patch,
+                )
+
+                integrate_cross_entropy_loss_patch()
+            if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
+                from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
+
+                patch_self_attn_lora()

     elif cfg.is_llama_derived_model:
         # Modify all llama derived models in one block
@@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase):
             "sequence_len": 1024,
             "load_in_8bit": True,
             "adapter": "lora",
-            "lora_r": 32,
-            "lora_alpha": 64,
+            "lora_r": 8,
+            "lora_alpha": 16,
             "lora_dropout": 0.05,
             "lora_target_linear": True,
             "val_set_size": 0.1,
@@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase):
                     "type": "alpaca",
                 },
             ],
-            "num_epochs": 2,
+            "num_epochs": 1,
             "micro_batch_size": 8,
             "gradient_accumulation_steps": 1,
             "output_dir": temp_dir,
tests/e2e/test_optimizers.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+"""
+E2E tests for custom optimizers using Llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestCustomOptimizers(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_optimi_adamw(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "optimi_adamw",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()