Compare commits

...

2 Commits

Author SHA1 Message Date
Wing Lian
20d0427ac9 update llama3 example base models to use nous 2024-07-15 17:19:00 -04:00
Wing Lian
78e12f8ca5 add basic support for the optimi adamw optimizer (#1727)
* add support for optimi_adamw optimizer w kahan summation

* pydantic validator for optimi_adamw

* workaround for setting optimizer for fsdp

* make sure to install optimizer packages

* make sure to have parity for model parameters passed to optimizer

* add smoke test for optimi_adamw optimizer

* don't use foreach optimi by default
2024-07-14 19:12:57 -04:00
11 changed files with 141 additions and 23 deletions
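
For reference, enabling the new optimizer is purely a config-level change once these commits are applied. A minimal YAML sketch (the learning_rate and lr_scheduler values are taken from the smoke test added below and are illustrative only):

    optimizer: optimi_adamw
    learning_rate: 0.00001
    lr_scheduler: cosine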

View File

@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
     fi
 # So we can test the Docker image

View File

@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
     fi
 # So we can test the Docker image

View File

@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
 model_type: LlamaForCausalLM
 tokenizer_type: AutoTokenizer

View File

@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
 model_type: LlamaForCausalLM
 tokenizer_type: AutoTokenizer

View File

@@ -1,4 +1,4 @@
-base_model: casperhansen/llama-3-70b-fp16
+base_model: NousResearch/Meta-Llama-3-70B
 model_type: LlamaForCausalLM
 tokenizer_type: AutoTokenizer # PreTrainedTokenizerFast

View File

@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
@@ -7,7 +7,7 @@ load_in_4bit: true
 strict: false
 datasets:
-  - path: aaditya/alpaca_subset_1
+  - path: tatsu-lab/alpaca
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0

View File

@@ -104,5 +104,11 @@ setup(
         "galore": [
             "galore_torch",
         ],
+        "optimizers": [
+            "galore_torch",
+            "lion-pytorch==0.1.2",
+            "lomo-optim==0.1.1",
+            "torch-optimi==0.2.1",
+        ],
     },
 )
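
The new optimizers extra groups the optional optimizer backends (galore_torch, lion-pytorch, lomo-optim, torch-optimi) so the Docker images above can install them in one shot. Outside Docker, a local editable install would pull in the same packages; a sketch mirroring the Dockerfile change:

    pip install -e .[optimizers]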

View File

@@ -226,6 +226,12 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "whether to use sequential sampling for curriculum learning"},
     )
+    alternate_optimizer: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate optimizer to the HF trainer"
+        },
+    )
 @dataclass
@@ -285,25 +291,59 @@ class AxolotlTrainer(Trainer):
         self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
     def create_optimizer(self):
-        if self.args.loraplus_lr_ratio is None:
+        if (
+            self.args.loraplus_lr_ratio is None
+            and self.args.alternate_optimizer != "optimi_adamw"
+        ):
             return super().create_optimizer()
         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            decay_parameters = self.get_decay_parameter_names(opt_model)
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n not in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args,
                 opt_model,
             )
-            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
-            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
-            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
-                opt_model,
-                optimizer_cls,
-                optimizer_kwargs,
-                loraplus_lr_ratio,
-                loraplus_lr_embedding,
-            )
+            if self.args.loraplus_lr_ratio is not None:
+                loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+                loraplus_lr_embedding = getattr(
+                    self.args, "loraplus_lr_embedding", None
+                )
+                self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                    opt_model,
+                    optimizer_cls,
+                    optimizer_kwargs,
+                    loraplus_lr_ratio,
+                    loraplus_lr_embedding,
+                )
+            elif self.args.alternate_optimizer == "optimi_adamw":
+                from optimi import AdamW
+                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                    AdamW(
+                        optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
+                    )
+                )
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
@@ -1396,6 +1436,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs = {}
+        if self.cfg.optimizer == "optimi_adamw":
+            # Set default so transformers doesn't throw
+            training_arguments_kwargs["optim"] = "adamw_hf"
+            training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
         if self.cfg.optimizer == "lion_pytorch":
             from lion_pytorch import Lion

View File

@@ -341,7 +341,7 @@ class HyperparametersConfig(BaseModel):
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = 0.0
     optimizer: Optional[
-        Union[OptimizerNames, Literal["lion_pytorch"]]
+        Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
     ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}

View File

@@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase):
             "sequence_len": 1024,
             "load_in_8bit": True,
             "adapter": "lora",
-            "lora_r": 32,
-            "lora_alpha": 64,
+            "lora_r": 8,
+            "lora_alpha": 16,
             "lora_dropout": 0.05,
             "lora_target_linear": True,
             "val_set_size": 0.1,
@@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase):
                     "type": "alpaca",
                 },
             ],
-            "num_epochs": 2,
+            "num_epochs": 1,
             "micro_batch_size": 8,
             "gradient_accumulation_steps": 1,
             "output_dir": temp_dir,

View File

@@ -0,0 +1,67 @@
+"""
+E2E tests for custom optimizers using Llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestCustomOptimizers(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_optimi_adamw(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "optimi_adamw",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()