fix optimizer reset for relora sft (#1414)
* fix optimizer reset
* set states to reset for 8bit optimizers and handle quantile runtime error for embeddings
* fix relora test to check grad_norm
* use flash attn for relora and tweak hyperparams for test
* fix messages field for test dataset
This commit is contained in:
@@ -2,4 +2,3 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
tbparse
|
|
||||||
|
|||||||
@@ -2,3 +2,4 @@ pytest
|
|||||||
pytest-xdist
|
pytest-xdist
|
||||||
pytest-retry
|
pytest-retry
|
||||||
pytest-sugar
|
pytest-sugar
|
||||||
|
tbparse
|
||||||
|
|||||||
@@ -46,9 +46,10 @@ def reset_optimizer(
|
|||||||
*,
|
*,
|
||||||
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
||||||
optimizer_state_keys: List[str],
|
optimizer_state_keys: List[str],
|
||||||
prune_ratio: float = 0.9,
|
optimizer_magnitude_pruning: float = 0.9,
|
||||||
):
|
):
|
||||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
# pylint:disable=unused-argument
|
||||||
|
pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
|
||||||
n_zeros = 0
|
n_zeros = 0
|
||||||
n_total = 0
|
n_total = 0
|
||||||
|
|
||||||
@@ -56,16 +57,22 @@ def reset_optimizer(
|
|||||||
if isinstance(optimizer, ZeroRedundancyOptimizer):
|
if isinstance(optimizer, ZeroRedundancyOptimizer):
|
||||||
optimizer_state = optimizer.optim.state
|
optimizer_state = optimizer.optim.state
|
||||||
|
|
||||||
for param in reset_params:
|
for group in optimizer.param_groups:
|
||||||
param_state = optimizer_state[param]
|
for param in group["params"]:
|
||||||
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
|
state = optimizer_state[param]
|
||||||
continue
|
for key, value in state.items():
|
||||||
for key in optimizer_state_keys:
|
if key not in optimizer_state_keys:
|
||||||
pruning_fn(
|
continue
|
||||||
param_state[key]
|
if torch.is_tensor(value):
|
||||||
) # pruning fn has to be inplace to keep the same keys in the dict
|
try:
|
||||||
n_total += param_state[key].numel()
|
pruning_fn(value)
|
||||||
n_zeros += torch.sum(param_state[key] == 0).item()
|
n_total += value.numel()
|
||||||
|
n_zeros += torch.sum(value == 0).item()
|
||||||
|
except RuntimeError as exc:
|
||||||
|
if "quantile() input tensor is too large" in str(exc):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise exc
|
||||||
|
|
||||||
_zeroed = n_zeros / (1e-7 + n_total) * 100
|
_zeroed = n_zeros / (1e-7 + n_total) * 100
|
||||||
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
|
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
|
||||||
@@ -129,6 +136,9 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
|
|
||||||
if "adam" in args.optim.lower():
|
if "adam" in args.optim.lower():
|
||||||
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
|
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
|
||||||
|
if "8bit" in args.optim.lower():
|
||||||
|
optimizer_state_keys.append("state1")
|
||||||
|
optimizer_state_keys.append("state2")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
||||||
|
|
||||||
@@ -160,7 +170,7 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
optimizer,
|
optimizer,
|
||||||
reset_params=lora_params,
|
reset_params=lora_params,
|
||||||
optimizer_state_keys=optimizer_state_keys,
|
optimizer_state_keys=optimizer_state_keys,
|
||||||
prune_ratio=args.relora_prune_ratio,
|
optimizer_magnitude_pruning=args.relora_prune_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.quantized:
|
if self.quantized:
|
||||||
|
|||||||
@@ -7,13 +7,15 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from tbparse import SummaryReader
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import with_temp_dir
|
from .utils import most_recent_subdir, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -29,36 +31,48 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "JackFram/llama-68m",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
"sequence_len": 2048,
|
||||||
"sequence_len": 1024,
|
"sample_packing": True,
|
||||||
|
"pad_to_sequence_len": True,
|
||||||
|
"flash_attention": True,
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 32,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_modules": ["q_proj", "v_proj"],
|
"lora_target_modules": ["q_proj", "v_proj"],
|
||||||
"relora_steps": 25,
|
"relora_steps": 100,
|
||||||
"relora_warmup_steps": 5,
|
"relora_warmup_steps": 20,
|
||||||
"relora_anneal_steps": 5,
|
"relora_anneal_steps": 10,
|
||||||
|
"relora_prune_ratio": 0.9,
|
||||||
"relora_cpu_offload": True,
|
"relora_cpu_offload": True,
|
||||||
"val_set_size": 0.0,
|
"val_set_size": 0.0,
|
||||||
"special_tokens": {},
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"chat_template": "chatml",
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
"path": "mlabonne/FineTome-100k",
|
||||||
"type": "alpaca",
|
"type": "chat_template",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"warmup_steps": 15,
|
"warmup_steps": 20,
|
||||||
"num_epochs": 2,
|
"num_epochs": 2,
|
||||||
"max_steps": 51, # at least 2x relora_steps
|
"max_steps": 205, # at least 2x relora_steps
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
@@ -66,4 +80,14 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
assert (
|
||||||
|
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
|
||||||
|
).exists()
|
||||||
|
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
|
||||||
|
|
||||||
|
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||||
|
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||||
|
reader = SummaryReader(event_file)
|
||||||
|
df = reader.scalars # pylint: disable=invalid-name
|
||||||
|
df = df[(df.tag == "train/grad_norm")] # pylint: disable=invalid-name
|
||||||
|
assert df.value.values[-1] < 0.2, "grad_norm is too high"
|
||||||
|
|||||||
Reference in New Issue
Block a user