Compare commits

..

5 Commits

Author SHA1 Message Date
Wing Lian
e86dd76154 attempt to set start method to spawn to prevent cuda issues for DPO 2024-07-17 09:29:15 -04:00
Wing Lian
5f58555bd0 support for llama multipack using updated code/patches (#1754)
* support for llama multipack using updated code/patches

* also support unsloth patches

* incorrect arg

* add config validation for unsloth

* add missing return to validation

* add another missing return to validation
2024-07-16 17:36:29 -04:00
Wing Lian
cfc533a7f7 torch compile and cuda alloc improvements (#1755)
* enable experimental expandable_segments

* hf trainer seems to be missing torch compile

* disable PYTORCH_CUDA_ALLOC_CONF to see if that fixes cicd
2024-07-16 16:00:23 -04:00
Wing Lian
e1725aef2b update modal package and don't cache pip install (#1757)
* update modal package and cleanup pip cache

* more verbosity on the test
2024-07-16 14:45:38 -04:00
Wing Lian
78e12f8ca5 add basic support for the optimi adamw optimizer (#1727)
* add support for optimi_adamw optimizer w kahan summation

* pydantic validator for optimi_adamw

* workaround for setting optimizer for fsdp

* make sure to install optimizer packages

* make sure to have parity for model parameters passed to optimizer

* add smoke test for optimi_adamw optimizer

* don't use foreach optimi by default
2024-07-14 19:12:57 -04:00
13 changed files with 260 additions and 41 deletions

View File

@@ -57,6 +57,10 @@ jobs:
run: | run: |
pytest --ignore=tests/e2e/ tests/ pytest --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests: docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud' if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
@@ -99,7 +103,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal jinja2 pip install modal==0.63.64 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi fi
# So we can test the Docker image # So we can test the Docker image

View File

@@ -2,5 +2,5 @@
set -e set -e
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest /workspace/axolotl/tests/e2e/patched/ pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/ pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/

View File

@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi fi
# So we can test the Docker image # So we can test the Docker image

View File

@@ -104,5 +104,11 @@ setup(
"galore": [ "galore": [
"galore_torch", "galore_torch",
], ],
"optimizers": [
"galore_torch",
"lion-pytorch==0.1.2",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
}, },
) )

View File

@@ -13,6 +13,7 @@ from abc import abstractmethod
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import wraps from functools import wraps
from multiprocessing import set_start_method
from pathlib import Path from pathlib import Path
from typing import Dict, List, Literal, Optional, Type, Union from typing import Dict, List, Literal, Optional, Type, Union
@@ -226,6 +227,12 @@ class AxolotlTrainingMixins:
default=None, default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"}, metadata={"help": "whether to use sequential sampling for curriculum learning"},
) )
alternate_optimizer: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
@dataclass @dataclass
@@ -284,26 +291,72 @@ class AxolotlTrainer(Trainer):
if self.args.orpo_alpha: if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
    """Optionally run ``torch.compile`` on the model, then defer to the parent.

    When ``self.args.torch_compile`` is truthy the model is compiled with the
    backend and mode taken from the training arguments; either way the
    (possibly compiled) model is handed to the parent ``_wrap_model``.
    """
    args = self.args
    if args.torch_compile:
        # raise dynamo's accumulated cache size limit to 256 before compiling
        # pylint: disable-next=protected-access
        torch._dynamo.config.accumulated_cache_size_limit = 256
        model = torch.compile(
            model,
            backend=args.torch_compile_backend,
            mode=args.torch_compile_mode,
        )
    return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer(self): def create_optimizer(self):
if self.args.loraplus_lr_ratio is None: if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer != "optimi_adamw"
):
return super().create_optimizer() return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args, self.args,
opt_model, opt_model,
) )
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) if self.args.loraplus_lr_ratio is not None:
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init loraplus_lr_embedding = getattr(
opt_model, self.args, "loraplus_lr_embedding", None
optimizer_cls, )
optimizer_kwargs, self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
loraplus_lr_ratio, opt_model,
loraplus_lr_embedding, optimizer_cls,
) optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -1396,6 +1449,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {} trainer_kwargs = {}
if self.cfg.optimizer == "optimi_adamw":
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
if self.cfg.optimizer == "lion_pytorch": if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion from lion_pytorch import Lion
@@ -1713,6 +1771,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
for callback in self.get_post_trainer_create_callbacks(dpo_trainer): for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback) dpo_trainer.add_callback(callback)
# prevents multiprocessing issues for datasets on multiple GPUs
set_start_method("spawn")
return dpo_trainer return dpo_trainer

View File

@@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model):
set_module_name(model, name, qkv) set_module_name(model, name, qkv)
def patch_llama_cross_entropy():
    """Rebind the llama modeling module's ``CrossEntropyLoss`` to flash-attn's.

    The flash-attn loss is imported lazily inside the function and wrapped in a
    ``partial`` that always passes ``inplace_backward=True``.
    """
    from flash_attn.losses.cross_entropy import (
        CrossEntropyLoss as FlashCrossEntropyLoss,
    )

    LOG.info("patching with flash_attn.losses.cross_entropy")
    inplace_loss = partial(FlashCrossEntropyLoss, inplace_backward=True)
    transformers.models.llama.modeling_llama.CrossEntropyLoss = inplace_loss
def patch_llama_rms_norm():
    """Rebind the llama modeling module's ``LlamaRMSNorm`` to a flash-attn subclass.

    If the flash-attn RMSNorm op cannot be imported, a warning with the install
    command is logged and nothing is patched.
    """
    try:
        from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm

        class LlamaRMSNorm(FlashRMSNorm):
            """Patched LlamaRMSNorm"""

            def __init__(self, hidden_size, eps=1e-6):
                super().__init__(hidden_size, eps=eps)

        LOG.info("patching with flash_attn.ops.rms_norm")
        llama_modeling = transformers.models.llama.modeling_llama
        llama_modeling.LlamaRMSNorm = LlamaRMSNorm
    except ImportError:
        LOG.warning(
            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
        )
def replace_llama_attn_with_flash_attn( def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False, packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False, cross_entropy: Optional[bool] = False,
@@ -104,30 +131,11 @@ def replace_llama_attn_with_flash_attn(
# skip only if explicitly disabled # skip only if explicitly disabled
if cross_entropy: if cross_entropy:
from flash_attn.losses.cross_entropy import CrossEntropyLoss patch_llama_cross_entropy()
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
# skip only if explicitly disabled # skip only if explicitly disabled
if rms_norm: if rms_norm:
try: patch_llama_rms_norm()
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
class FusedAttention(LlamaAttention): class FusedAttention(LlamaAttention):

View File

@@ -10,6 +10,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"llama",
"mixtral", "mixtral",
"qwen2", "qwen2",
"qwen2_moe", "qwen2_moe",
@@ -30,6 +31,10 @@ def patch_for_multipack(model_type, model_name=None):
) )
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3() patch_mixtral_moe_forward_zero3()
elif model_type == "llama":
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2": elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data

View File

@@ -52,6 +52,13 @@ class TrainDatasetMeta:
def train( def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# enable expandable segments for cuda allocation to improve VRAM usage
# torch_version = torch.__version__.split(".")
# torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
# if torch_major == 2 and torch_minor >= 2:
# if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# load the tokenizer first # load the tokenizer first
LOG.debug( LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",

View File

@@ -341,7 +341,7 @@ class HyperparametersConfig(BaseModel):
learning_rate: Union[str, float] learning_rate: Union[str, float]
weight_decay: Optional[float] = 0.0 weight_decay: Optional[float] = 0.0
optimizer: Optional[ optimizer: Optional[
Union[OptimizerNames, Literal["lion_pytorch"]] Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
] = OptimizerNames.ADAMW_HF.value ] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field( optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."} default=None, metadata={"help": "Optional arguments to supply to optimizer."}
@@ -1112,6 +1112,31 @@ class AxolotlInputConfig(
raise ValueError("either datasets or pretraining_dataset is required") raise ValueError("either datasets or pretraining_dataset is required")
return data return data
@model_validator(mode="before")
@classmethod
def check_xentropy_patch_conflicts(cls, data):
    """Refuse configs that turn on both cross-entropy loss patches together.

    Returns ``data`` untouched unless ``flash_attn_cross_entropy`` and
    ``unsloth_cross_entropy_loss`` are both truthy, in which case a
    ``ValueError`` is raised.
    """
    both_patches_enabled = data.get("flash_attn_cross_entropy") and data.get(
        "unsloth_cross_entropy_loss"
    )
    if not both_patches_enabled:
        return data
    raise ValueError(
        "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
    )
@model_validator(mode="before")
@classmethod
def check_qlora_unsloth(cls, data):
    """Reject the unsloth LoRA kernel patches when combined with 8-bit LoRA.

    Args:
        data: raw config mapping handed to this "before" validator.

    Returns:
        ``data`` unchanged when the combination is acceptable.

    Raises:
        ValueError: if any of ``unsloth_lora_mlp``, ``unsloth_lora_qkv`` or
            ``unsloth_lora_o`` is set while ``adapter`` is ``"lora"`` and
            ``load_in_8bit`` is enabled.
    """
    uses_unsloth_lora = (
        data.get("unsloth_lora_mlp")
        or data.get("unsloth_lora_qkv")
        or data.get("unsloth_lora_o")
    )
    # Both conditions must hold: the error targets 8-bit LoRA specifically.
    # Using `or` here would reject every LoRA adapter regardless of precision.
    is_8bit_lora = data.get("adapter") == "lora" and data.get("load_in_8bit")
    if uses_unsloth_lora and is_8bit_lora:
        raise ValueError(
            "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
        )
    return data
class AxolotlConfigWCapabilities(AxolotlInputConfig): class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options""" """wrapper to valdiate gpu capabilities with the configured options"""
@@ -1163,3 +1188,18 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("deepspeed") and data.get("fsdp"): if data.get("deepspeed") and data.get("fsdp"):
raise ValueError("deepspeed and fsdp cannot be used together.") raise ValueError("deepspeed and fsdp cannot be used together.")
return data return data
@model_validator(mode="before")
@classmethod
def check_multigpu_unsloth(cls, data):
    """Reject the unsloth LoRA kernel patches when more than one GPU is reported.

    Args:
        data: raw config mapping handed to this "before" validator; the GPU
            count is read from ``data["capabilities"]["num_gpus"]`` when present.

    Returns:
        ``data`` unchanged when the check passes or cannot be evaluated.

    Raises:
        ValueError: if any of ``unsloth_lora_mlp``, ``unsloth_lora_qkv`` or
            ``unsloth_lora_o`` is set and ``num_gpus`` is greater than 1.
    """
    if (
        data.get("unsloth_lora_mlp")
        or data.get("unsloth_lora_qkv")
        or data.get("unsloth_lora_o")
    ):
        capabilities = data.get("capabilities")
        num_gpus = capabilities.get("num_gpus") if capabilities else None
        # A missing/None GPU count is treated as "unknown" and skipped; comparing
        # None > 1 would raise TypeError rather than a validation error.
        if num_gpus is not None and num_gpus > 1:
            raise ValueError(
                "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
            )
    return data

View File

@@ -347,6 +347,27 @@ def load_model(
and cfg.sample_packing and cfg.sample_packing
): ):
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model) patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
if cfg.is_llama_derived_model:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_llama_cross_entropy,
patch_llama_rms_norm,
)
if cfg.flash_attn_cross_entropy:
patch_llama_cross_entropy()
if cfg.flash_attn_rms_norm:
patch_llama_rms_norm()
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import (
integrate_cross_entropy_loss_patch,
)
integrate_cross_entropy_loss_patch()
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
elif cfg.is_llama_derived_model: elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block # Modify all llama derived models in one block

View File

@@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase):
"sequence_len": 1024, "sequence_len": 1024,
"load_in_8bit": True, "load_in_8bit": True,
"adapter": "lora", "adapter": "lora",
"lora_r": 32, "lora_r": 8,
"lora_alpha": 64, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.1,
@@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase):
"type": "alpaca", "type": "alpaca",
}, },
], ],
"num_epochs": 2, "num_epochs": 1,
"micro_batch_size": 8, "micro_batch_size": 8,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": temp_dir, "output_dir": temp_dir,

View File

@@ -0,0 +1,67 @@
"""
E2E tests for custom optimizers using Llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomOptimizers(unittest.TestCase):
    """
    E2E smoke tests for training a small Llama LoRA with custom optimizers
    """

    @with_temp_dir
    def test_optimi_adamw(self, temp_dir):
        """Train with ``optimizer: optimi_adamw`` and check an adapter file is written."""
        # pylint: disable=duplicate-code
        config = {
            "base_model": "JackFram/llama-68m",
            "tokenizer_type": "LlamaTokenizer",
            "sequence_len": 1024,
            "load_in_8bit": True,
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lora_target_linear": True,
            "val_set_size": 0.1,
            "special_tokens": {
                "unk_token": "<unk>",
                "bos_token": "<s>",
                "eos_token": "</s>",
            },
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                },
            ],
            "num_epochs": 1,
            "micro_batch_size": 8,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "optimi_adamw",
            "lr_scheduler": "cosine",
        }
        cfg = DictDefault(config)
        normalize_config(cfg)

        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

        adapter_weights = Path(temp_dir) / "adapter_model.bin"
        assert adapter_weights.exists()