add q-galore optimizer

This commit is contained in:
Wing Lian
2024-07-14 19:28:13 -04:00
parent 78e12f8ca5
commit 105c65390e
4 changed files with 56 additions and 3 deletions

View File

@@ -108,6 +108,7 @@ setup(
"galore_torch",
"lion-pytorch==0.1.2",
"lomo-optim==0.1.1",
"q-galore-torch==1.0",
"torch-optimi==0.2.1",
],
},

View File

@@ -293,7 +293,8 @@ class AxolotlTrainer(Trainer):
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer != "optimi_adamw"
and self.args.alternate_optimizer
not in ["optimi_adamw", "q_galore_adamw8bit"]
):
return super().create_optimizer()
@@ -344,6 +345,12 @@ class AxolotlTrainer(Trainer):
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "q_galore_adamw8bit":
from q_galore_torch import QGaLoreAdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
QGaLoreAdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -1436,7 +1443,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.optimizer == "optimi_adamw":
if self.cfg.optimizer in ["optimi_adamw", "q_galore_adamw8bit"]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer

View File

@@ -341,7 +341,10 @@ class HyperparametersConfig(BaseModel):
learning_rate: Union[str, float]
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
Union[
OptimizerNames,
Literal["lion_pytorch", "optimi_adamw", "q_galore_adamw8bit"],
]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}

View File

@@ -65,3 +65,45 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_q_galore_adamw8bit(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "q_galore_adamw8bit",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()