chore: lint/formatting

This commit is contained in:
Wing Lian
2025-01-20 12:20:20 -05:00
parent de771fcb05
commit 474ba1a1b8
2 changed files with 96 additions and 38 deletions

View File

@@ -8,8 +8,11 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rrt.modeling import register_rrt_model
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig, \
RelaxedRecursiveLlamaModel, RelaxedRecursiveLlamaForCausalLM
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
RelaxedRecursiveLlamaConfig,
RelaxedRecursiveLlamaForCausalLM,
RelaxedRecursiveLlamaModel,
)
LOG = logging.getLogger(__name__)
@@ -39,4 +42,6 @@ def register_rrt_model():
# Register models
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
AutoModelForCausalLM.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
AutoModelForCausalLM.register(
RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
)

View File

@@ -12,15 +12,18 @@ from huggingface_hub import snapshot_download, split_torch_state_dict_into_shard
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig
from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
RelaxedRecursiveLlamaConfig,
)
logger = logging.getLogger(__name__)
def extract_layer_number(key):
"""Extract layer number from parameter key."""
match = re.search(r'layers\.(\d+)\.', key)
match = re.search(r"layers\.(\d+)\.", key)
return int(match.group(1)) if match else None
@@ -32,18 +35,21 @@ def iter_parameter_weights(model_path, device="mps"):
:param device: Computing device
:return: generator yielding (parameter key, parameter weight, layer index) tuples
"""
shards = list(model_path.glob('model*.safetensors'))
shards = list(model_path.glob("model*.safetensors"))
if not shards:
raise ValueError(f"No model shards found in {model_path}")
for shard in tqdm(shards, desc="Processing shards"):
with safetensors.safe_open(shard, framework='pt', device=device) as f:
for key in f.keys():
layer_idx = extract_layer_number(key)
weight = f.get_tensor(key)
yield key, weight, layer_idx
with safetensors.safe_open(shard, framework="pt", device=device) as f:
for key in f.keys():
layer_idx = extract_layer_number(key)
weight = f.get_tensor(key)
yield key, weight, layer_idx
def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12):
def iter_recursive_parameter_weights(
model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12
):
# setup placeholder state_dict for recursive weights, need to keep in float32 precision
# to avoid precision loss when averaging weights across layers
rrt_avg_model_state_dict = {}
@@ -52,8 +58,7 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
# get the matching module name in modules_to_recurse for the current parameter key
matched_module_name = next(
(module for module in modules_to_recurse if module in key),
None
(module for module in modules_to_recurse if module in key), None
)
if matched_module_name is None:
continue
@@ -64,7 +69,9 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
# setup as storage for suffix with torch.stack
rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()]
else:
rrt_avg_model_state_dict[suffix].append(weight.to(torch.float32).detach().cpu())
rrt_avg_model_state_dict[suffix].append(
weight.to(torch.float32).detach().cpu()
)
for module_name in modules_to_recurse:
for recurse_idx in range(recurse_layers):
@@ -75,8 +82,9 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
# compute the decomposed lora diff from the weight base to the actual weight for each module
def low_rank_decomposition(
weight: torch.Tensor, max_rank: int
weight: torch.Tensor, max_rank: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Decompose a 2D matrix into low-rank matrices L and R using SVD.
@@ -86,10 +94,10 @@ def low_rank_decomposition(
:return: A tuple of tensors (L, R)
"""
assert (
weight.dim() == 2
weight.dim() == 2
), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
assert (
max_rank >= 1
max_rank >= 1
), f"Maximum rank must be a positive integer, but input max_rank={max_rank}."
dtype = weight.dtype
@@ -138,22 +146,31 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12):
def iter_dora_parameter_weights(
model_path,
avg_recursive_weights,
modules_to_recurse: list[str],
alpha,
rank,
device="mps",
recurse_layers=12,
):
rrt_avg_model_state_dict = {}
# iterate over all parameter weights in the model shards
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
# get the matching module name in modules_to_recurse for the current parameter key
matched_module_name = next(
(module for module in modules_to_recurse if module in key),
None
(module for module in modules_to_recurse if module in key), None
)
if matched_module_name is None:
if "input_layernorm" in key:
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
loop_idx = layer_idx // recurse_layers
layer_idx = layer_idx % recurse_layers
layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
layernorm_key = (
f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
)
yield layernorm_key, weight
elif "post_attention_layernorm" in key:
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
@@ -171,19 +188,26 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re
suffix = f"{layer_idx}.{matched_module_name}"
prefix = f"model.layers.{suffix}.weight_base"
avg_weight = avg_recursive_weights[prefix]
lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
lora_magnitude_key = f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
lora_a, lora_b, lora_magnitude = decompose_delta_weight(weight, avg_weight, alpha, rank)
lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
lora_magnitude_key = (
f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
)
lora_a, lora_b, lora_magnitude = decompose_delta_weight(
weight, avg_weight, alpha, rank
)
yield lora_a_key, lora_a
yield lora_b_key, lora_b
yield lora_magnitude_key, lora_magnitude
def save_state_dict_to_safetensors(state_dict, save_directory):
os.makedirs(save_directory, exist_ok=True)
weights_name = SAFE_WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
)
@@ -207,10 +231,10 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
if (
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in state_dict_split.filename_to_tensors.keys()
and reg.fullmatch(filename_no_suffix) is not None
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in state_dict_split.filename_to_tensors.keys()
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)
@@ -221,7 +245,9 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
shard[tensor] = state_dict[tensor].contiguous()
del state_dict[tensor]
save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
save_file(
shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}
)
del state_dict
@@ -236,7 +262,10 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"):
def convert_llama_to_rrt(
model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"
):
modules_to_recurse = [
"self_attn.q_proj",
"self_attn.k_proj",
@@ -255,7 +284,14 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
f"divisible by the recurse layers ({recurse_layers})"
)
config = RelaxedRecursiveLlamaConfig.from_dict({**config.to_dict(), "recurse_layers": recurse_layers, "rank": rank, "alpha": alpha})
config = RelaxedRecursiveLlamaConfig.from_dict(
{
**config.to_dict(),
"recurse_layers": recurse_layers,
"rank": rank,
"alpha": alpha,
}
)
config.save_pretrained(output_dir)
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
@@ -263,13 +299,23 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
rrt_model_state_dict = {}
logger.info(f"Calculating average recursive weights...")
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers):
for key, weight in iter_recursive_parameter_weights(
model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
):
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
logger.info(f"Calculating decomposed lora diff...")
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
rrt_lora_state_dict = {}
for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers):
for key, weight in iter_dora_parameter_weights(
model_path,
rrt_model_state_dict,
modules_to_recurse,
alpha=32,
rank=rank,
device=device,
recurse_layers=recurse_layers,
):
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
# combine state dicts into a single state_dict
@@ -287,4 +333,11 @@ if __name__ == "__main__":
device = "cuda"
else:
device = "cpu"
convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512, device=device)
convert_llama_to_rrt(
"meta-llama/Llama-3.2-1B",
"/tmp/rrt_model",
recurse_layers=4,
rank=256,
alpha=512,
device=device,
)