support seperate lr for embeddings, similar to loraplus (#1910) [skip ci]
* support seperate lr for embeddings, similar to loraplus * add test case for train w lr embedding scale * use kwarg for optimizer * make sure to handle the optimizer creation * make sure to handle for embedding_lr too * use smollm for e2e, check for embeddings lr first before wdecay
This commit is contained in:
@@ -220,6 +220,14 @@ class AxolotlTrainingMixins:
|
|||||||
default=1e-6,
|
default=1e-6,
|
||||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||||
)
|
)
|
||||||
|
embedding_lr_scale: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||||
|
)
|
||||||
|
embedding_lr: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||||
|
)
|
||||||
qlora: bool = field(
|
qlora: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "whether this is a qlora training"},
|
metadata={"help": "whether this is a qlora training"},
|
||||||
@@ -386,7 +394,7 @@ class SchedulerMixin(Trainer):
|
|||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return super().create_scheduler(num_training_steps, optimizer)
|
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
else:
|
else:
|
||||||
if use_cosine_quadratic:
|
if use_cosine_quadratic:
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
@@ -435,6 +443,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
|
and self.args.embedding_lr_scale is None
|
||||||
|
and self.args.embedding_lr is None
|
||||||
and self.args.alternate_optimizer
|
and self.args.alternate_optimizer
|
||||||
not in [
|
not in [
|
||||||
"optimi_adamw",
|
"optimi_adamw",
|
||||||
@@ -449,30 +459,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
optimizer_grouped_parameters = [
|
params = {
|
||||||
{
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
"params": [
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
p
|
"no_weight_decay": {},
|
||||||
for n, p in opt_model.named_parameters()
|
}
|
||||||
if (n in decay_parameters and p.requires_grad)
|
|
||||||
],
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [
|
|
||||||
p
|
|
||||||
for n, p in opt_model.named_parameters()
|
|
||||||
if (n not in decay_parameters and p.requires_grad)
|
|
||||||
],
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||||
self.args,
|
self.args,
|
||||||
opt_model,
|
opt_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for name, param in opt_model.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
if name.endswith("modules_to_save.default.weight") or any(
|
||||||
|
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||||
|
):
|
||||||
|
params["embeddings"][name] = param
|
||||||
|
elif name in decay_parameters:
|
||||||
|
params["to_weight_decay"][name] = param
|
||||||
|
else:
|
||||||
|
params["no_weight_decay"][name] = param
|
||||||
|
optimizer_grouped_parameters = []
|
||||||
|
if params["to_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["to_weight_decay"].values()),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["embeddings"]:
|
||||||
|
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||||
|
if self.args.embedding_lr_scale:
|
||||||
|
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||||
|
elif self.args.embedding_lr:
|
||||||
|
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["embeddings"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["no_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["no_weight_decay"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.loraplus_lr_ratio is not None:
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
loraplus_lr_embedding = getattr(
|
loraplus_lr_embedding = getattr(
|
||||||
@@ -485,6 +524,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
**optimizer_kwargs,
|
**optimizer_kwargs,
|
||||||
)
|
)
|
||||||
|
elif (
|
||||||
|
self.args.embedding_lr_scale is not None
|
||||||
|
or self.args.embedding_lr is not None
|
||||||
|
):
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
)
|
||||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
|
|
||||||
@@ -1571,6 +1617,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"loraplus_lr_embedding"
|
"loraplus_lr_embedding"
|
||||||
] = self.cfg.loraplus_lr_embedding
|
] = self.cfg.loraplus_lr_embedding
|
||||||
|
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||||
|
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||||
|
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
|
|||||||
@@ -431,6 +431,8 @@ class HyperparametersConfig(BaseModel):
|
|||||||
group_by_length: Optional[bool] = None
|
group_by_length: Optional[bool] = None
|
||||||
|
|
||||||
learning_rate: Union[str, float]
|
learning_rate: Union[str, float]
|
||||||
|
embedding_lr: Optional[float] = None
|
||||||
|
embedding_lr_scale: Optional[float] = None
|
||||||
weight_decay: Optional[float] = 0.0
|
weight_decay: Optional[float] = 0.0
|
||||||
optimizer: Optional[
|
optimizer: Optional[
|
||||||
Union[
|
Union[
|
||||||
|
|||||||
121
tests/e2e/test_embeddings_lr.py
Normal file
121
tests/e2e/test_embeddings_lr.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for llama pretrain
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from tbparse import SummaryReader
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import most_recent_subdir, with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingsLrScale(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for embedding_lr*
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_train_w_embedding_lr_scale(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": True,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"max_steps": 5,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"embedding_lr_scale": 0.5,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
|
|
||||||
|
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||||
|
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||||
|
reader = SummaryReader(event_file)
|
||||||
|
df = reader.scalars # pylint: disable=invalid-name
|
||||||
|
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||||
|
assert df.value.values[-1] < 2.0, "Loss is too high"
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_train_w_embedding_lr(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": True,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"max_steps": 5,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"embedding_lr": 0.000005,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
|
|
||||||
|
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||||
|
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||||
|
reader = SummaryReader(event_file)
|
||||||
|
df = reader.scalars # pylint: disable=invalid-name
|
||||||
|
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||||
|
assert df.value.values[-1] < 2.0, "Loss is too high"
|
||||||
Reference in New Issue
Block a user