chore: lint/formatting
@@ -8,8 +8,11 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

 from axolotl.integrations.base import BasePlugin
 from axolotl.integrations.rrt.modeling import register_rrt_model
-from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig, \
-    RelaxedRecursiveLlamaModel, RelaxedRecursiveLlamaForCausalLM
+from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
+    RelaxedRecursiveLlamaConfig,
+    RelaxedRecursiveLlamaForCausalLM,
+    RelaxedRecursiveLlamaModel,
+)

 LOG = logging.getLogger(__name__)

@@ -39,4 +42,6 @@ def register_rrt_model():

     # Register models
     AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
-    AutoModelForCausalLM.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
+    AutoModelForCausalLM.register(
+        RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
+    )
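These `register` calls are what let the stock `Auto*` factories resolve the custom architecture. A minimal sketch of how the registration is typically exercised; the `AutoConfig.register` step is an assumption here, since this hunk only shows the model registrations:

from transformers import AutoConfig, AutoModelForCausalLM

from axolotl.integrations.rrt.modeling import register_rrt_model

register_rrt_model()

# Assumes register_rrt_model() also maps the config class (e.g. via
# AutoConfig.register); only the model registrations are visible in this diff.
config = AutoConfig.from_pretrained("/tmp/rrt_model")  # -> RelaxedRecursiveLlamaConfig
model = AutoModelForCausalLM.from_config(config)       # -> RelaxedRecursiveLlamaForCausalLM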
@@ -12,15 +12,18 @@ from huggingface_hub import snapshot_download, split_torch_state_dict_into_shard
 from safetensors.torch import save_file
 from tqdm import tqdm
 from transformers import AutoConfig
-from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME

-from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig
+from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
+    RelaxedRecursiveLlamaConfig,
+)

 logger = logging.getLogger(__name__)


 def extract_layer_number(key):
     """Extract layer number from parameter key."""
-    match = re.search(r'layers\.(\d+)\.', key)
+    match = re.search(r"layers\.(\d+)\.", key)
     return int(match.group(1)) if match else None
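The regex keys off the `layers.<n>.` segment of a parameter name, so embedding and head weights fall through to `None`:

extract_layer_number("model.layers.17.self_attn.q_proj.weight")  # -> 17
extract_layer_number("model.embed_tokens.weight")                # -> None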
@@ -32,18 +35,21 @@ def iter_parameter_weights(model_path, device="mps"):
     :param device: Computing device
     :return: generator yielding (parameter key, parameter weight, layer index) tuples
     """
-    shards = list(model_path.glob('model*.safetensors'))
+    shards = list(model_path.glob("model*.safetensors"))
     if not shards:
         raise ValueError(f"No model shards found in {model_path}")

     for shard in tqdm(shards, desc="Processing shards"):
-        with safetensors.safe_open(shard, framework='pt', device=device) as f:
-            for key in f.keys():
-                layer_idx = extract_layer_number(key)
-                weight = f.get_tensor(key)
-                yield key, weight, layer_idx
+        with safetensors.safe_open(shard, framework="pt", device=device) as f:
+            for key in f.keys():
+                layer_idx = extract_layer_number(key)
+                weight = f.get_tensor(key)
+                yield key, weight, layer_idx

-def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12):
+
+def iter_recursive_parameter_weights(
+    model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12
+):
     # setup placeholder state_dict for recursive weights, need to keep in float32 precision
     # to avoid precision loss when averaging weights across layers
     rrt_avg_model_state_dict = {}
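`iter_parameter_weights` streams tensors shard by shard through `safetensors.safe_open`, so only one shard's tensors need to be resident at a time. A minimal usage sketch; the snapshot path is hypothetical:

from pathlib import Path

for key, weight, layer_idx in iter_parameter_weights(Path("/tmp/llama_snapshot"), device="cpu"):
    print(key, tuple(weight.shape), layer_idx)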
@@ -52,8 +58,7 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
     for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
         # get the matching module name in modules_to_recurse for the current parameter key
         matched_module_name = next(
-            (module for module in modules_to_recurse if module in key),
-            None
+            (module for module in modules_to_recurse if module in key), None
         )
         if matched_module_name is None:
             continue
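The `next(generator, default)` idiom picks the first entry of `modules_to_recurse` that occurs as a substring of the key, or `None` for parameters that are not recursed:

modules = ["self_attn.q_proj", "mlp.up_proj"]
next((m for m in modules if m in "model.layers.3.self_attn.q_proj.weight"), None)  # -> "self_attn.q_proj"
next((m for m in modules if m in "model.norm.weight"), None)                       # -> None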
@@ -64,7 +69,9 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
             # setup as storage for suffix with torch.stack
             rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()]
         else:
-            rrt_avg_model_state_dict[suffix].append(weight.to(torch.float32).detach().cpu())
+            rrt_avg_model_state_dict[suffix].append(
+                weight.to(torch.float32).detach().cpu()
+            )

     for module_name in modules_to_recurse:
         for recurse_idx in range(recurse_layers):
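Each recursed module thus accumulates one float32 copy of its weight per source layer; the loop that follows presumably reduces each list to the shared `weight_base`. A toy sketch of that reduction, assuming a plain mean across the layers that share a slot:

import torch

# Three per-layer float32 copies of one module's weight, reduced to a single
# shared base by stacking and averaging (shapes are illustrative only).
per_layer = [torch.randn(8, 8) for _ in range(3)]
weight_base = torch.stack(per_layer, dim=0).mean(dim=0)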
@@ -75,8 +82,9 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],

 # compute the decomposed lora diff from the weight base to the actual weight for each module


 def low_rank_decomposition(
-    weight: torch.Tensor, max_rank: int
+    weight: torch.Tensor, max_rank: int
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Decompose a 2D matrix into low-rank matrices L and R using SVD.
@@ -86,10 +94,10 @@ def low_rank_decomposition(
     :return: A tuple of tensors (L, R)
     """
     assert (
-        weight.dim() == 2
+        weight.dim() == 2
     ), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
     assert (
-        max_rank >= 1
+        max_rank >= 1
     ), f"Maximum rank must be a positive integer, but input max_rank={max_rank}."

     dtype = weight.dtype
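The asserts guard a standard truncated SVD; the remainder of the function body sits outside this hunk. A minimal sketch of the usual construction, which is an assumption rather than the file's actual code:

import torch

def low_rank_sketch(weight: torch.Tensor, max_rank: int):
    # SVD in float32 for numerical stability, truncated to max_rank, with the
    # singular values split evenly between the two factors.
    U, S, Vh = torch.linalg.svd(weight.to(torch.float32), full_matrices=False)
    rank = min(max_rank, S.numel())
    sqrt_s = S[:rank].sqrt()
    L = U[:, :rank] * sqrt_s          # (out, rank)
    R = sqrt_s[:, None] * Vh[:rank]   # (rank, in)
    return L, R                       # L @ R approximates weight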
@@ -138,22 +146,31 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
     return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()


-def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12):
+def iter_dora_parameter_weights(
+    model_path,
+    avg_recursive_weights,
+    modules_to_recurse: list[str],
+    alpha,
+    rank,
+    device="mps",
+    recurse_layers=12,
+):
     rrt_avg_model_state_dict = {}

     # iterate over all parameter weights in the model shards
     for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
         # get the matching module name in modules_to_recurse for the current parameter key
         matched_module_name = next(
-            (module for module in modules_to_recurse if module in key),
-            None
+            (module for module in modules_to_recurse if module in key), None
         )
         if matched_module_name is None:
             if "input_layernorm" in key:
                 # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
                 loop_idx = layer_idx // recurse_layers
                 layer_idx = layer_idx % recurse_layers
-                layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
+                layernorm_key = (
+                    f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
+                )
                 yield layernorm_key, weight
             elif "post_attention_layernorm" in key:
                 # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
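The divmod-style remapping folds the flat layer index into a (shared layer, loop iteration) pair. With `recurse_layers=4`, original layer 10 maps to shared layer `10 % 4 = 2` on loop pass `10 // 4 = 2`:

recurse_layers = 4
layer_idx = 10
loop_idx = layer_idx // recurse_layers   # 2: which pass through the shared block
slot_idx = layer_idx % recurse_layers    # 2: which layer inside the shared block
# -> "model.layers.2.input_layernorm_list.2.weight"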
@@ -171,19 +188,26 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re
             suffix = f"{layer_idx}.{matched_module_name}"
             prefix = f"model.layers.{suffix}.weight_base"
             avg_weight = avg_recursive_weights[prefix]
-            lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
-            lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
-            lora_magnitude_key = f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
-            lora_a, lora_b, lora_magnitude = decompose_delta_weight(weight, avg_weight, alpha, rank)
+            lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
+            lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
+            lora_magnitude_key = (
+                f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
+            )
+            lora_a, lora_b, lora_magnitude = decompose_delta_weight(
+                weight, avg_weight, alpha, rank
+            )
             yield lora_a_key, lora_a
             yield lora_b_key, lora_b
             yield lora_magnitude_key, lora_magnitude

+
 def save_state_dict_to_safetensors(state_dict, save_directory):
     os.makedirs(save_directory, exist_ok=True)
     weights_name = SAFE_WEIGHTS_NAME

-    filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+    filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
+        ".safetensors", "{suffix}.safetensors"
+    )
     state_dict_split = split_torch_state_dict_into_shards(
         state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
     )
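`decompose_delta_weight` itself is outside this hunk; judging by the keys it feeds (`lora_A_list`, `lora_B_list`, `lora_magnitude_vector_list`), it produces a DoRA-style factorization of each layer's deviation from the shared base. The sketch below shows one standard way to build those three tensors and is an assumption, not the file's actual body:

import torch

def decompose_delta_sketch(layer_weight, avg_weight, alpha, rank):
    # Assumed construction: factor the residual against the shared base into
    # rank-limited A/B matrices, pre-dividing by the LoRA scaling (alpha/rank)
    # that is re-applied at load time, and keep a per-row magnitude vector.
    scaling = alpha / rank
    delta = (layer_weight.to(torch.float32) - avg_weight.to(torch.float32)) / scaling
    U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
    lora_B = U[:, :rank] * S[:rank]   # (out, rank)
    lora_A = Vh[:rank]                # (rank, in)
    weight_norm = layer_weight.to(torch.float32).norm(p=2, dim=1)  # (out,)
    return lora_A, lora_B, weight_norm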
@@ -207,10 +231,10 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
         reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")

         if (
-            filename.startswith(weights_no_suffix)
-            and os.path.isfile(full_filename)
-            and filename not in state_dict_split.filename_to_tensors.keys()
-            and reg.fullmatch(filename_no_suffix) is not None
+            filename.startswith(weights_no_suffix)
+            and os.path.isfile(full_filename)
+            and filename not in state_dict_split.filename_to_tensors.keys()
+            and reg.fullmatch(filename_no_suffix) is not None
         ):
             os.remove(full_filename)
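The cleanup pass removes previously written shards that the new split no longer references; the regex matches the standard `-NNNNN-of-NNNNN` shard naming:

import re

reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
bool(reg.fullmatch("model-00002-of-00003"))  # True: a stale shard, deleted unless reused
bool(reg.fullmatch("model"))                 # False: the unsharded filename is left alone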
@@ -221,7 +245,9 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
             shard[tensor] = state_dict[tensor].contiguous()
             del state_dict[tensor]

-        save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
+        save_file(
+            shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}
+        )

     del state_dict
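`split_torch_state_dict_into_shards` only plans the split; shards are materialized and written one at a time, and tensors are deleted from the source dict as they are copied, keeping peak memory near one shard. A compact, self-contained sketch of the same pattern with a hypothetical output directory:

import os

from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

def save_sharded_sketch(state_dict, out_dir, pattern="model{suffix}.safetensors"):
    split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern=pattern, max_shard_size="1GB"
    )
    for shard_file, tensors in split.filename_to_tensors.items():
        shard = {name: state_dict[name].contiguous() for name in tensors}
        save_file(shard, os.path.join(out_dir, shard_file), metadata={"format": "pt"})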
@@ -236,7 +262,10 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
         content = json.dumps(index, indent=2, sort_keys=True) + "\n"
         f.write(content)

-def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"):
+
+def convert_llama_to_rrt(
+    model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"
+):
     modules_to_recurse = [
         "self_attn.q_proj",
         "self_attn.k_proj",
@@ -255,7 +284,14 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
             f"divisible by the recurse layers ({recurse_layers})"
         )

-    config = RelaxedRecursiveLlamaConfig.from_dict({**config.to_dict(), "recurse_layers": recurse_layers, "rank": rank, "alpha": alpha})
+    config = RelaxedRecursiveLlamaConfig.from_dict(
+        {
+            **config.to_dict(),
+            "recurse_layers": recurse_layers,
+            "rank": rank,
+            "alpha": alpha,
+        }
+    )
     config.save_pretrained(output_dir)
     model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
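`from_dict` copies the base Llama config and layers the recursion hyper-parameters on top, and `save_pretrained` writes it next to the converted weights. A hedged round-trip check, assuming the custom config class is registered with `AutoConfig`, which this diff does not show:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("/tmp/rrt_model")
assert (cfg.recurse_layers, cfg.rank, cfg.alpha) == (4, 256, 512)  # values from the __main__ call below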
@@ -263,13 +299,23 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
     rrt_model_state_dict = {}

     logger.info(f"Calculating average recursive weights...")
-    for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers):
+    for key, weight in iter_recursive_parameter_weights(
+        model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
+    ):
         rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()

     logger.info(f"Calculating decomposed lora diff...")
     # now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
     rrt_lora_state_dict = {}
-    for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers):
+    for key, weight in iter_dora_parameter_weights(
+        model_path,
+        rrt_model_state_dict,
+        modules_to_recurse,
+        alpha=32,
+        rank=rank,
+        device=device,
+        recurse_layers=recurse_layers,
+    ):
         rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()

     # combine state dicts into a single state_dict
@@ -287,4 +333,11 @@ if __name__ == "__main__":
         device = "cuda"
     else:
         device = "cpu"
-    convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512, device=device)
+    convert_llama_to_rrt(
+        "meta-llama/Llama-3.2-1B",
+        "/tmp/rrt_model",
+        recurse_layers=4,
+        rank=256,
+        alpha=512,
+        device=device,
+    )
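Once the conversion has written the config and sharded weights, loading the result back is the usual `from_pretrained` flow. A minimal sketch, assuming `register_rrt_model()` performs the Auto-class registrations shown at the top of this commit:

from transformers import AutoModelForCausalLM

from axolotl.integrations.rrt.modeling import register_rrt_model

register_rrt_model()
model = AutoModelForCausalLM.from_pretrained("/tmp/rrt_model")  # dispatches on RelaxedRecursiveLlamaConfig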