bump transformers and set roundup_power2_divisions for more VRAM improvements, low bit ao optimizers (#1769)

* bump transformers and set roundup_power2_divisions for more VRAM improvements

* support for low-bit optimizers from torchao

* fix check for alternate optimizers and use nous models on hf for llama3

* add missing check for ao_adamw_fp8

* fix check when using custom optimizers with AdamW
This commit is contained in:
Wing Lian
2024-07-19 00:47:07 -04:00
committed by GitHub
parent 7830fe04b5
commit e4063d60a7
9 changed files with 64 additions and 10 deletions

19
docs/torchao.qmd Normal file
View File

@@ -0,0 +1,19 @@
---
title: "PyTorch ao"
description: "Custom data types and layouts for training and inference"
---
### Installation
Stable release from the PyTorch index
```bash
pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
```
Nightly release
```bash
pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
```

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B
base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B-Instruct
base_model: NousResearch/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B
base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B
base_model: NousResearch/Meta-Llama-3-8B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.11.1
transformers==4.42.3
transformers==4.42.4
tokenizers==0.19.1
bitsandbytes==0.43.1
accelerate==0.32.0

View File

@@ -305,7 +305,8 @@ class AxolotlTrainer(Trainer):
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer != "optimi_adamw"
and self.args.alternate_optimizer
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
):
return super().create_optimizer()
@@ -356,6 +357,24 @@ class AxolotlTrainer(Trainer):
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -1452,7 +1471,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.optimizer == "optimi_adamw":
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
]:
# Set a default `optim` that transformers accepts so it doesn't throw; the actual optimizer is chosen via `alternate_optimizer`
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer

View File

@@ -57,7 +57,9 @@ def train(
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"
# load the tokenizer first
LOG.debug(

View File

@@ -346,7 +346,16 @@ class HyperparametersConfig(BaseModel):
learning_rate: Union[str, float]
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
Union[
OptimizerNames,
Literal[
"lion_pytorch",
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
],
]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
@@ -850,7 +859,7 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def check_adamw_optimizer_params(self):
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
not self.optimizer or "adamw" not in self.optimizer.value
not self.optimizer or "adamw" not in str(self.optimizer).lower()
):
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self