Compare commits


1 commit

Author      SHA1        Message                          Date
Wing Lian   2680421081  bump deepspeed to latest 0.14.4  2024-07-13 14:36:18 -04:00
17 changed files with 51 additions and 359 deletions

View File

@@ -57,10 +57,6 @@ jobs:
         run: |
           pytest --ignore=tests/e2e/ tests/
-      - name: cleanup pip cache
-        run: |
-          find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
   docker-e2e-tests:
     if: github.repository_owner == 'axolotl-ai-cloud'
     # this job needs to be run on self-hosted GPU runners...
@@ -103,7 +99,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==0.63.64 jinja2
+          pip install modal jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -24,13 +24,13 @@ RUN git fetch origin +$GITHUB_REF && \
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
     fi

 # So we can test the Docker image
-RUN pip install -r requirements-tests.txt
+RUN pip install pytest

 # fix so that git fetch/pull from remote works
 RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -2,5 +2,5 @@
 set -e

 pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
-pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
+pytest /workspace/axolotl/tests/e2e/patched/
 pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/

View File

@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
     fi

 # So we can test the Docker image

View File

@@ -1,2 +1 @@
 pytest
-pytest-xdist

View File

@@ -5,14 +5,14 @@ transformers==4.42.3
 tokenizers==0.19.1
 bitsandbytes==0.43.1
 accelerate==0.32.0
-deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
+deepspeed==0.14.4
 pydantic==2.6.3
 addict
 fire
 PyYAML>=6.0
 requests
 datasets==2.19.1
-flash-attn==2.6.1
+flash-attn==2.5.8
 sentencepiece
 wandb
 einops

View File

@@ -80,10 +80,10 @@ setup(
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn==2.6.1",
+            "flash-attn==2.5.8",
         ],
         "fused-dense-lib": [
-            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
+            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
         ],
         "deepspeed": [
             "deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
@@ -104,11 +104,5 @@ setup(
         "galore": [
             "galore_torch",
         ],
-        "optimizers": [
-            "galore_torch",
-            "lion-pytorch==0.1.2",
-            "lomo-optim==0.1.1",
-            "torch-optimi==0.2.1",
-        ],
     },
 )

View File

@@ -13,7 +13,6 @@ from abc import abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
-from multiprocessing import set_start_method
 from pathlib import Path
 from typing import Dict, List, Literal, Optional, Type, Union
@@ -227,12 +226,6 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "whether to use sequential sampling for curriculum learning"},
     )
-    alternate_optimizer: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "workaround to pass an alternate optimizer to the HF trainer"
-        },
-    )


 @dataclass
@@ -291,72 +284,26 @@ class AxolotlTrainer(Trainer):
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

-    def _wrap_model(self, model, training=True, dataloader=None):
-        if self.args.torch_compile:
-            torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
-                256
-            )
-            model = torch.compile(
-                model,
-                backend=self.args.torch_compile_backend,
-                mode=self.args.torch_compile_mode,
-            )
-        return super()._wrap_model(model, training=training, dataloader=dataloader)
-
     def create_optimizer(self):
-        if (
-            self.args.loraplus_lr_ratio is None
-            and self.args.alternate_optimizer != "optimi_adamw"
-        ):
+        if self.args.loraplus_lr_ratio is None:
             return super().create_optimizer()

         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.optimizer is None:  # pylint: disable=access-member-before-definition
-            decay_parameters = self.get_decay_parameter_names(opt_model)
-            optimizer_grouped_parameters = [
-                {
-                    "params": [
-                        p
-                        for n, p in opt_model.named_parameters()
-                        if (n in decay_parameters and p.requires_grad)
-                    ],
-                    "weight_decay": self.args.weight_decay,
-                },
-                {
-                    "params": [
-                        p
-                        for n, p in opt_model.named_parameters()
-                        if (n not in decay_parameters and p.requires_grad)
-                    ],
-                    "weight_decay": 0.0,
-                },
-            ]
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args,
                 opt_model,
             )
-
-            if self.args.loraplus_lr_ratio is not None:
-                loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
-                loraplus_lr_embedding = getattr(
-                    self.args, "loraplus_lr_embedding", None
-                )
-                self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
-                    opt_model,
-                    optimizer_cls,
-                    optimizer_kwargs,
-                    loraplus_lr_ratio,
-                    loraplus_lr_embedding,
-                )
-            elif self.args.alternate_optimizer == "optimi_adamw":
-                from optimi import AdamW
-
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    AdamW(
-                        optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
-                    )
-                )
+            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
+            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                opt_model,
+                optimizer_cls,
+                optimizer_kwargs,
+                loraplus_lr_ratio,
+                loraplus_lr_embedding,
+            )

         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
@@ -1449,11 +1396,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs = {}

-        if self.cfg.optimizer == "optimi_adamw":
-            # Set default so transformers doesn't throw
-            training_arguments_kwargs["optim"] = "adamw_hf"
-            training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
 if self.cfg.optimizer == "lion_pytorch":
             from lion_pytorch import Lion
@@ -1771,9 +1713,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
             dpo_trainer.add_callback(callback)
-
-        # prevents multiprocessing issues for datasets on multiple GPUs
-        set_start_method("spawn")
-
         return dpo_trainer

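Note: the surviving create_optimizer() branch above only builds a custom optimizer when a LoRA+ learning-rate ratio is configured. A minimal sketch of that pattern, assuming create_loraplus_optimizer is importable from axolotl.loraplus with the positional signature shown in the diff (illustrative, not the literal source):

from transformers import Trainer

from axolotl.loraplus import create_loraplus_optimizer  # assumed import path


class LoraPlusTrainer(Trainer):
    """Sketch: delegate to LoRA+ only when loraplus_lr_ratio is set."""

    def create_optimizer(self):
        if getattr(self.args, "loraplus_lr_ratio", None) is None:
            return super().create_optimizer()
        if self.optimizer is None:
            # same resolution step transformers uses internally for optim config
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args, self.model
            )
            self.optimizer = create_loraplus_optimizer(
                self.model,
                optimizer_cls,
                optimizer_kwargs,
                self.args.loraplus_lr_ratio,
                getattr(self.args, "loraplus_lr_embedding", None),
            )
        return self.optimizer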
View File

@@ -78,33 +78,6 @@ def replace_llama_qkv_with_fused(model):
             set_module_name(model, name, qkv)


-def patch_llama_cross_entropy():
-    from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-    LOG.info("patching with flash_attn.losses.cross_entropy")
-    transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-        CrossEntropyLoss, inplace_backward=True
-    )
-
-
-def patch_llama_rms_norm():
-    try:
-        from flash_attn.ops.rms_norm import RMSNorm
-
-        class LlamaRMSNorm(RMSNorm):
-            """Patched LLamaRMSNorm"""
-
-            def __init__(self, hidden_size, eps=1e-6):
-                super().__init__(hidden_size, eps=eps)
-
-        LOG.info("patching with flash_attn.ops.rms_norm")
-        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
-    except ImportError:
-        LOG.warning(
-            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
-        )
-
-
 def replace_llama_attn_with_flash_attn(
     packed: Optional[bool] = False,
     cross_entropy: Optional[bool] = False,
@@ -131,11 +104,35 @@ def replace_llama_attn_with_flash_attn(
     # skip only if explicitly disabled
     if cross_entropy:
-        patch_llama_cross_entropy()
+        try:
+            from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+            LOG.info("patching with flash_attn.losses.cross_entropy")
+            transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+                CrossEntropyLoss, inplace_backward=True
+            )
+        except ImportError:
+            LOG.warning(
+                "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
+            )

     # skip only if explicitly disabled
     if rms_norm:
-        patch_llama_rms_norm()
+        try:
+            from flash_attn.ops.rms_norm import RMSNorm
+
+            class LlamaRMSNorm(RMSNorm):
+                """Patched LLamaRMSNorm"""
+
+                def __init__(self, hidden_size, eps=1e-6):
+                    super().__init__(hidden_size, eps=eps)
+
+            LOG.info("patching with flash_attn.ops.rms_norm")
+            transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+        except ImportError:
+            LOG.warning(
+                "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+            )


 class FusedAttention(LlamaAttention):

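Note: per a later hunk in this compare, these patches are applied from config flags at model load time. A usage sketch grounded in that call site (flag values illustrative); the call must run before the Llama model is constructed so transformers resolves the patched symbols:

from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

# Patch module-level symbols before model instantiation.
replace_llama_attn_with_flash_attn(
    packed=False,
    cross_entropy=True,  # opt into flash-attn's fused cross entropy
    rms_norm=True,       # opt into flash-attn's fused RMSNorm
)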
View File

@@ -10,7 +10,6 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
 from axolotl.monkeypatch.utils import get_unpad_data

 SUPPORTED_MULTIPACK_MODEL_TYPES = [
-    "llama",
     "mixtral",
     "qwen2",
     "qwen2_moe",
@@ -31,10 +30,6 @@ def patch_for_multipack(model_type, model_name=None):
         )
         if is_deepspeed_zero3_enabled():
             patch_mixtral_moe_forward_zero3()
-    elif model_type == "llama":
-        transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
     elif model_type == "qwen2":
         transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data

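Note: the multipack patch amounts to swapping transformers' module-level _get_unpad_data for axolotl's packing-aware version, dispatched on model type. A usage sketch built only from the names in this diff (module path assumed to be axolotl.monkeypatch.multipack):

from axolotl.monkeypatch.multipack import (  # assumed import path
    SUPPORTED_MULTIPACK_MODEL_TYPES,
    patch_for_multipack,
)

model_type = "qwen2"  # illustrative; any entry in the supported list works
if model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
    patch_for_multipack(model_type)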
View File

@@ -52,13 +52,6 @@ class TrainDatasetMeta:
 def train(
     *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
 ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
-    # enable expandable segments for cuda allocation to improve VRAM usage
-    # torch_version = torch.__version__.split(".")
-    # torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
-    # if torch_major == 2 and torch_minor >= 2:
-    #     if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
-    #         os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-
     # load the tokenizer first
     LOG.debug(
         f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",

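Note: the block removed above was already commented out. As a runnable sketch, it would have opted into PyTorch's expandable CUDA allocator segments on torch >= 2.2 unless the user had set PYTORCH_CUDA_ALLOC_CONF themselves (the env var only takes effect if set before CUDA is initialized):

import os

import torch

major, minor = (int(part) for part in torch.__version__.split(".")[:2])
if (major, minor) >= (2, 2) and os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"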
View File

@@ -341,7 +341,7 @@ class HyperparametersConfig(BaseModel):
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = 0.0
     optimizer: Optional[
-        Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
+        Union[OptimizerNames, Literal["lion_pytorch"]]
     ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}
@@ -1112,31 +1112,6 @@ class AxolotlInputConfig(
             raise ValueError("either datasets or pretraining_dataset is required")
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_xentropy_patch_conflicts(cls, data):
-        if data.get("flash_attn_cross_entropy") and data.get(
-            "unsloth_cross_entropy_loss"
-        ):
-            raise ValueError(
-                "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
-            )
-        return data
-
-    @model_validator(mode="before")
-    @classmethod
-    def check_qlora_unsloth(cls, data):
-        if (
-            data.get("unsloth_lora_mlp")
-            or data.get("unsloth_lora_qkv")
-            or data.get("unsloth_lora_o")
-        ):
-            if data.get("adapter") == "lora" or data.get("load_in_8bit"):
-                raise ValueError(
-                    "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
-                )
-        return data

 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""
@@ -1188,18 +1163,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         if data.get("deepspeed") and data.get("fsdp"):
             raise ValueError("deepspeed and fsdp cannot be used together.")
         return data
-
-    @model_validator(mode="before")
-    @classmethod
-    def check_multigpu_unsloth(cls, data):
-        if (
-            data.get("unsloth_lora_mlp")
-            or data.get("unsloth_lora_qkv")
-            or data.get("unsloth_lora_o")
-        ):
-            capabilities = data.get("capabilities")
-            if capabilities and capabilities.get("num_gpus") > 1:
-                raise ValueError(
-                    "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
-                )
-        return data

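Note: the removed checks all follow pydantic v2's "before" model-validator pattern, which receives the raw input mapping prior to field validation. A self-contained sketch of the same shape (model and field defaults illustrative):

from typing import Optional

from pydantic import BaseModel, model_validator


class ExampleConfig(BaseModel):
    """Sketch of the validator pattern used by the removed checks."""

    flash_attn_cross_entropy: Optional[bool] = None
    unsloth_cross_entropy_loss: Optional[bool] = None

    @model_validator(mode="before")
    @classmethod
    def check_xentropy_patch_conflicts(cls, data):
        # `data` is the unvalidated input dict, same as in the diff above
        if data.get("flash_attn_cross_entropy") and data.get(
            "unsloth_cross_entropy_loss"
        ):
            raise ValueError(
                "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
            )
        return data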
View File

@@ -347,27 +347,6 @@ def load_model(
         and cfg.sample_packing
     ):
         patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
-
-        if cfg.is_llama_derived_model:
-            from axolotl.monkeypatch.llama_attn_hijack_flash import (
-                patch_llama_cross_entropy,
-                patch_llama_rms_norm,
-            )
-
-            if cfg.flash_attn_cross_entropy:
-                patch_llama_cross_entropy()
-            if cfg.flash_attn_rms_norm:
-                patch_llama_rms_norm()
-            if cfg.unsloth_cross_entropy_loss:
-                from axolotl.monkeypatch.unsloth_ import (
-                    integrate_cross_entropy_loss_patch,
-                )
-
-                integrate_cross_entropy_loss_patch()
-            if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
-                from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
-
-                patch_self_attn_lora()
     elif cfg.is_llama_derived_model:
         # Modify all llama derived models in one block
@@ -392,12 +371,6 @@ def load_model(
                 rms_norm=cfg.flash_attn_rms_norm,
                 use_shifted_sparse_attn=True,
             )
-        elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
-            replace_llama_attn_with_flash_attn(
-                packed=False,
-                cross_entropy=cfg.flash_attn_cross_entropy,
-                rms_norm=cfg.flash_attn_rms_norm,
-            )
         elif cfg.xformers_attention:
             from axolotl.monkeypatch.llama_attn_hijack_xformers import (
                 hijack_llama_attention,

View File

@@ -1,87 +0,0 @@
-"""
-E2E tests for lora llama
-"""
-
-import logging
-import os
-import unittest
-from importlib import reload
-from pathlib import Path
-
-import pytest
-from transformers.utils import is_torch_bf16_gpu_available
-
-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
-from axolotl.train import train
-from axolotl.utils.config import normalize_config
-from axolotl.utils.dict import DictDefault
-
-from ..utils import with_temp_dir
-
-LOG = logging.getLogger("axolotl.tests.e2e")
-os.environ["WANDB_DISABLED"] = "true"
-
-
-@pytest.fixture(autouse=True)
-def reload_transformers():
-    import transformers.models.llama.modeling_llama
-
-    yield
-    reload(transformers.models.llama.modeling_llama)
-
-
-class TestFAXentropyLlama(unittest.TestCase):
-    """
-    Test case for Llama models using LoRA w multipack
-    """
-
-    @with_temp_dir
-    def test_lora_packing_fa_cross_entropy(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
-                "sequence_len": 1024,
-                "sample_packing": True,
-                "flash_attention": True,
-                "flash_attn_cross_entropy": True,
-                "load_in_8bit": True,
-                "adapter": "lora",
-                "lora_r": 32,
-                "lora_alpha": 64,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0.2,
-                "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
-                },
-                "datasets": [
-                    {
-                        "path": "mhenrichsen/alpaca_2k_test",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 1,
-                "micro_batch_size": 8,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
-                "lr_scheduler": "cosine",
-            }
-        )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase):
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",
-                "lora_r": 8,
-                "lora_alpha": 16,
+                "lora_r": 32,
+                "lora_alpha": 64,
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
                 "val_set_size": 0.1,
@@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase):
                         "type": "alpaca",
                     },
                 ],
-                "num_epochs": 1,
+                "num_epochs": 2,
                 "micro_batch_size": 8,
                 "gradient_accumulation_steps": 1,
                 "output_dir": temp_dir,

View File

@@ -1,67 +0,0 @@
-"""
-E2E tests for custom optimizers using Llama
-"""
-
-import logging
-import os
-import unittest
-from pathlib import Path
-
-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
-from axolotl.train import train
-from axolotl.utils.config import normalize_config
-from axolotl.utils.dict import DictDefault
-
-from .utils import with_temp_dir
-
-LOG = logging.getLogger("axolotl.tests.e2e")
-os.environ["WANDB_DISABLED"] = "true"
-
-
-class TestCustomOptimizers(unittest.TestCase):
-    """
-    Test case for Llama models using LoRA
-    """
-
-    @with_temp_dir
-    def test_optimi_adamw(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
-                "sequence_len": 1024,
-                "load_in_8bit": True,
-                "adapter": "lora",
-                "lora_r": 8,
-                "lora_alpha": 16,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0.1,
-                "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
-                },
-                "datasets": [
-                    {
-                        "path": "mhenrichsen/alpaca_2k_test",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 1,
-                "micro_batch_size": 8,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "optimi_adamw",
-                "lr_scheduler": "cosine",
-            }
-        )
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.bin").exists()