Compare commits

...

1 Commit

Author      SHA1         Message                  Date
Wing Lian   105c65390e   add q-galore optimizer   2024-07-14 19:28:13 -04:00
4 changed files with 56 additions and 3 deletions

View File

@@ -108,6 +108,7 @@ setup(
             "galore_torch",
             "lion-pytorch==0.1.2",
             "lomo-optim==0.1.1",
+            "q-galore-torch==1.0",
             "torch-optimi==0.2.1",
         ],
     },
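The only thing the packaging change guarantees is that the import used by the trainer change below resolves. A minimal sanity check, assuming the newly pinned dependency q-galore-torch==1.0 is installed:

# Confirms the import path that create_optimizer() relies on; nothing is
# trained here, we only check that the optimizer class can be imported.
from q_galore_torch import QGaLoreAdamW8bit

print(QGaLoreAdamW8bit)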

View File

@@ -293,7 +293,8 @@ class AxolotlTrainer(Trainer):
     def create_optimizer(self):
         if (
             self.args.loraplus_lr_ratio is None
-            and self.args.alternate_optimizer != "optimi_adamw"
+            and self.args.alternate_optimizer
+            not in ["optimi_adamw", "q_galore_adamw8bit"]
         ):
             return super().create_optimizer()
@@ -344,6 +345,12 @@ class AxolotlTrainer(Trainer):
                     optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
                 )
             )
+        elif self.args.alternate_optimizer == "q_galore_adamw8bit":
+            from q_galore_torch import QGaLoreAdamW8bit
+
+            self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                QGaLoreAdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
+            )
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
@@ -1436,7 +1443,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs = {}
-        if self.cfg.optimizer == "optimi_adamw":
+        if self.cfg.optimizer in ["optimi_adamw", "q_galore_adamw8bit"]:
             # Set default so transformers doesn't throw
             training_arguments_kwargs["optim"] = "adamw_hf"
             training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer

View File

@@ -341,7 +341,10 @@ class HyperparametersConfig(BaseModel):
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = 0.0
     optimizer: Optional[
-        Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
+        Union[
+            OptimizerNames,
+            Literal["lion_pytorch", "optimi_adamw", "q_galore_adamw8bit"],
+        ]
     ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}
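On the config side, the change only whitelists "q_galore_adamw8bit" as a value for the optimizer field. A minimal illustrative fragment in the same DictDefault style the e2e test below uses; the DictDefault import path is assumed to be axolotl.utils.dict, and the test shows the full set of keys actually required for a run:

from axolotl.utils.dict import DictDefault  # assumed import path

cfg = DictDefault(
    {
        "learning_rate": 0.00001,
        "weight_decay": 0.0,
        "optimizer": "q_galore_adamw8bit",  # newly accepted literal
    }
)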

View File

@@ -65,3 +65,45 @@ class TestCustomOptimizers(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_q_galore_adamw8bit(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "q_galore_adamw8bit",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()