bump transformers and set roundup_power2_divisions for more VRAM improvements, low bit ao optimizers (#1769)
* bump transformers and set roundup_power2_divisions for more VRAM improvements
* support for low bit optimizers from torch ao
* fix check for alternate optimizers and use nous models on hf for llama3
* add missing check for ao_adamw_fp8
* fix check when using custom optimizers with adamw
This commit is contained in:
19
docs/torchao.qmd
Normal file
19
docs/torchao.qmd
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
---
|
||||||
|
title: "PyTorch ao"
|
||||||
|
description: "Custom data types and layouts for training and inference"
|
||||||
|
---
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
Stable Release from the PyTorch index
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Nightly release
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
|
||||||
|
```
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: meta-llama/Meta-Llama-3-8B
|
base_model: NousResearch/Meta-Llama-3-8B
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
base_model: NousResearch/Meta-Llama-3-8B-Instruct
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: meta-llama/Meta-Llama-3-8B
|
base_model: NousResearch/Meta-Llama-3-8B
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: meta-llama/Meta-Llama-3-8B
|
base_model: NousResearch/Meta-Llama-3-8B
|
||||||
model_type: AutoModelForCausalLM
|
model_type: AutoModelForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.11.1
|
peft==0.11.1
|
||||||
transformers==4.42.3
|
transformers==4.42.4
|
||||||
tokenizers==0.19.1
|
tokenizers==0.19.1
|
||||||
bitsandbytes==0.43.1
|
bitsandbytes==0.43.1
|
||||||
accelerate==0.32.0
|
accelerate==0.32.0
|
||||||
|
|||||||
@@ -305,7 +305,8 @@ class AxolotlTrainer(Trainer):
|
|||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
and self.args.alternate_optimizer != "optimi_adamw"
|
and self.args.alternate_optimizer
|
||||||
|
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
|
|
||||||
@@ -356,6 +357,24 @@ class AxolotlTrainer(Trainer):
|
|||||||
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
|
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
elif self.args.alternate_optimizer == "ao_adamw_4bit":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||||
|
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
)
|
||||||
|
elif self.args.alternate_optimizer == "ao_adamw_8bit":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||||
|
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
)
|
||||||
|
elif self.args.alternate_optimizer == "ao_adamw_fp8":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||||
|
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
@@ -1452,7 +1471,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
if self.cfg.optimizer == "optimi_adamw":
|
if self.cfg.optimizer in [
|
||||||
|
"optimi_adamw",
|
||||||
|
"ao_adamw_4bit",
|
||||||
|
"ao_adamw_8bit",
|
||||||
|
"ao_adamw_fp8",
|
||||||
|
]:
|
||||||
# Set default so transformers doesn't throw
|
# Set default so transformers doesn't throw
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||||
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
|
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
|
||||||
|
|||||||
@@ -57,7 +57,9 @@ def train(
|
|||||||
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
||||||
if torch_major == 2 and torch_minor >= 2:
|
if torch_major == 2 and torch_minor >= 2:
|
||||||
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
|
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
|
||||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
os.environ[
|
||||||
|
"PYTORCH_CUDA_ALLOC_CONF"
|
||||||
|
] = "expandable_segments:True,roundup_power2_divisions:16"
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
|
|||||||
@@ -346,7 +346,16 @@ class HyperparametersConfig(BaseModel):
|
|||||||
learning_rate: Union[str, float]
|
learning_rate: Union[str, float]
|
||||||
weight_decay: Optional[float] = 0.0
|
weight_decay: Optional[float] = 0.0
|
||||||
optimizer: Optional[
|
optimizer: Optional[
|
||||||
Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
|
Union[
|
||||||
|
OptimizerNames,
|
||||||
|
Literal[
|
||||||
|
"lion_pytorch",
|
||||||
|
"optimi_adamw",
|
||||||
|
"ao_adamw_4bit",
|
||||||
|
"ao_adamw_8bit",
|
||||||
|
"ao_adamw_fp8",
|
||||||
|
],
|
||||||
|
]
|
||||||
] = OptimizerNames.ADAMW_HF.value
|
] = OptimizerNames.ADAMW_HF.value
|
||||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||||
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
|
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
|
||||||
@@ -850,7 +859,7 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_adamw_optimizer_params(self):
|
def check_adamw_optimizer_params(self):
|
||||||
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
|
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
|
||||||
not self.optimizer or "adamw" not in self.optimizer.value
|
not self.optimizer or "adamw" not in str(self.optimizer).lower()
|
||||||
):
|
):
|
||||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||||
return self
|
return self
|
||||||
|
|||||||
Reference in New Issue
Block a user