chore: lint/formatting
This commit is contained in:
@@ -8,8 +8,11 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
|||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.integrations.rrt.modeling import register_rrt_model
|
from axolotl.integrations.rrt.modeling import register_rrt_model
|
||||||
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig, \
|
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
|
||||||
RelaxedRecursiveLlamaModel, RelaxedRecursiveLlamaForCausalLM
|
RelaxedRecursiveLlamaConfig,
|
||||||
|
RelaxedRecursiveLlamaForCausalLM,
|
||||||
|
RelaxedRecursiveLlamaModel,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -39,4 +42,6 @@ def register_rrt_model():
|
|||||||
|
|
||||||
# Register models
|
# Register models
|
||||||
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
|
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
|
||||||
AutoModelForCausalLM.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
|
AutoModelForCausalLM.register(
|
||||||
|
RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
|
||||||
|
)
|
||||||
|
|||||||
@@ -12,15 +12,18 @@ from huggingface_hub import snapshot_download, split_torch_state_dict_into_shard
|
|||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||||
|
|
||||||
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig
|
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
|
||||||
|
RelaxedRecursiveLlamaConfig,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def extract_layer_number(key):
|
def extract_layer_number(key):
|
||||||
"""Extract layer number from parameter key."""
|
"""Extract layer number from parameter key."""
|
||||||
match = re.search(r'layers\.(\d+)\.', key)
|
match = re.search(r"layers\.(\d+)\.", key)
|
||||||
return int(match.group(1)) if match else None
|
return int(match.group(1)) if match else None
|
||||||
|
|
||||||
|
|
||||||
@@ -32,18 +35,21 @@ def iter_parameter_weights(model_path, device="mps"):
|
|||||||
:param device: Computing device
|
:param device: Computing device
|
||||||
:return: generator yielding (parameter key, parameter weight, layer index) tuples
|
:return: generator yielding (parameter key, parameter weight, layer index) tuples
|
||||||
"""
|
"""
|
||||||
shards = list(model_path.glob('model*.safetensors'))
|
shards = list(model_path.glob("model*.safetensors"))
|
||||||
if not shards:
|
if not shards:
|
||||||
raise ValueError(f"No model shards found in {model_path}")
|
raise ValueError(f"No model shards found in {model_path}")
|
||||||
|
|
||||||
for shard in tqdm(shards, desc="Processing shards"):
|
for shard in tqdm(shards, desc="Processing shards"):
|
||||||
with safetensors.safe_open(shard, framework='pt', device=device) as f:
|
with safetensors.safe_open(shard, framework="pt", device=device) as f:
|
||||||
for key in f.keys():
|
for key in f.keys():
|
||||||
layer_idx = extract_layer_number(key)
|
layer_idx = extract_layer_number(key)
|
||||||
weight = f.get_tensor(key)
|
weight = f.get_tensor(key)
|
||||||
yield key, weight, layer_idx
|
yield key, weight, layer_idx
|
||||||
|
|
||||||
def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12):
|
|
||||||
|
def iter_recursive_parameter_weights(
|
||||||
|
model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12
|
||||||
|
):
|
||||||
# setup placeholder state_dict for recursive weights, need to keep in float32 precision
|
# setup placeholder state_dict for recursive weights, need to keep in float32 precision
|
||||||
# to avoid precision loss when averaging weights across layers
|
# to avoid precision loss when averaging weights across layers
|
||||||
rrt_avg_model_state_dict = {}
|
rrt_avg_model_state_dict = {}
|
||||||
@@ -52,8 +58,7 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
|
|||||||
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
||||||
# get the matching module name in modules_to_recurse for the current parameter key
|
# get the matching module name in modules_to_recurse for the current parameter key
|
||||||
matched_module_name = next(
|
matched_module_name = next(
|
||||||
(module for module in modules_to_recurse if module in key),
|
(module for module in modules_to_recurse if module in key), None
|
||||||
None
|
|
||||||
)
|
)
|
||||||
if matched_module_name is None:
|
if matched_module_name is None:
|
||||||
continue
|
continue
|
||||||
@@ -64,7 +69,9 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
|
|||||||
# setup as storage for suffix with torch.stack
|
# setup as storage for suffix with torch.stack
|
||||||
rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()]
|
rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()]
|
||||||
else:
|
else:
|
||||||
rrt_avg_model_state_dict[suffix].append(weight.to(torch.float32).detach().cpu())
|
rrt_avg_model_state_dict[suffix].append(
|
||||||
|
weight.to(torch.float32).detach().cpu()
|
||||||
|
)
|
||||||
|
|
||||||
for module_name in modules_to_recurse:
|
for module_name in modules_to_recurse:
|
||||||
for recurse_idx in range(recurse_layers):
|
for recurse_idx in range(recurse_layers):
|
||||||
@@ -75,8 +82,9 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
|
|||||||
|
|
||||||
# compute the decomposed lora diff from the weight base to the actual weight for each module
|
# compute the decomposed lora diff from the weight base to the actual weight for each module
|
||||||
|
|
||||||
|
|
||||||
def low_rank_decomposition(
|
def low_rank_decomposition(
|
||||||
weight: torch.Tensor, max_rank: int
|
weight: torch.Tensor, max_rank: int
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Decompose a 2D matrix into low-rank matrices L and R using SVD.
|
Decompose a 2D matrix into low-rank matrices L and R using SVD.
|
||||||
@@ -86,10 +94,10 @@ def low_rank_decomposition(
|
|||||||
:return: A tuple of tensors (L, R)
|
:return: A tuple of tensors (L, R)
|
||||||
"""
|
"""
|
||||||
assert (
|
assert (
|
||||||
weight.dim() == 2
|
weight.dim() == 2
|
||||||
), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
|
), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
|
||||||
assert (
|
assert (
|
||||||
max_rank >= 1
|
max_rank >= 1
|
||||||
), f"Maximum rank must be a positive integer, but input max_rank={max_rank}."
|
), f"Maximum rank must be a positive integer, but input max_rank={max_rank}."
|
||||||
|
|
||||||
dtype = weight.dtype
|
dtype = weight.dtype
|
||||||
@@ -138,22 +146,31 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
|
|||||||
return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
|
return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
|
||||||
|
|
||||||
|
|
||||||
def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12):
|
def iter_dora_parameter_weights(
|
||||||
|
model_path,
|
||||||
|
avg_recursive_weights,
|
||||||
|
modules_to_recurse: list[str],
|
||||||
|
alpha,
|
||||||
|
rank,
|
||||||
|
device="mps",
|
||||||
|
recurse_layers=12,
|
||||||
|
):
|
||||||
rrt_avg_model_state_dict = {}
|
rrt_avg_model_state_dict = {}
|
||||||
|
|
||||||
# iterate over all parameter weights in the model shards
|
# iterate over all parameter weights in the model shards
|
||||||
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
||||||
# get the matching module name in modules_to_recurse for the current parameter key
|
# get the matching module name in modules_to_recurse for the current parameter key
|
||||||
matched_module_name = next(
|
matched_module_name = next(
|
||||||
(module for module in modules_to_recurse if module in key),
|
(module for module in modules_to_recurse if module in key), None
|
||||||
None
|
|
||||||
)
|
)
|
||||||
if matched_module_name is None:
|
if matched_module_name is None:
|
||||||
if "input_layernorm" in key:
|
if "input_layernorm" in key:
|
||||||
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
|
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
|
||||||
loop_idx = layer_idx // recurse_layers
|
loop_idx = layer_idx // recurse_layers
|
||||||
layer_idx = layer_idx % recurse_layers
|
layer_idx = layer_idx % recurse_layers
|
||||||
layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
|
layernorm_key = (
|
||||||
|
f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
|
||||||
|
)
|
||||||
yield layernorm_key, weight
|
yield layernorm_key, weight
|
||||||
elif "post_attention_layernorm" in key:
|
elif "post_attention_layernorm" in key:
|
||||||
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
|
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
|
||||||
@@ -171,19 +188,26 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re
|
|||||||
suffix = f"{layer_idx}.{matched_module_name}"
|
suffix = f"{layer_idx}.{matched_module_name}"
|
||||||
prefix = f"model.layers.{suffix}.weight_base"
|
prefix = f"model.layers.{suffix}.weight_base"
|
||||||
avg_weight = avg_recursive_weights[prefix]
|
avg_weight = avg_recursive_weights[prefix]
|
||||||
lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
|
lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
|
||||||
lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
|
lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
|
||||||
lora_magnitude_key = f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
|
lora_magnitude_key = (
|
||||||
lora_a, lora_b, lora_magnitude = decompose_delta_weight(weight, avg_weight, alpha, rank)
|
f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
|
||||||
|
)
|
||||||
|
lora_a, lora_b, lora_magnitude = decompose_delta_weight(
|
||||||
|
weight, avg_weight, alpha, rank
|
||||||
|
)
|
||||||
yield lora_a_key, lora_a
|
yield lora_a_key, lora_a
|
||||||
yield lora_b_key, lora_b
|
yield lora_b_key, lora_b
|
||||||
yield lora_magnitude_key, lora_magnitude
|
yield lora_magnitude_key, lora_magnitude
|
||||||
|
|
||||||
|
|
||||||
def save_state_dict_to_safetensors(state_dict, save_directory):
|
def save_state_dict_to_safetensors(state_dict, save_directory):
|
||||||
os.makedirs(save_directory, exist_ok=True)
|
os.makedirs(save_directory, exist_ok=True)
|
||||||
weights_name = SAFE_WEIGHTS_NAME
|
weights_name = SAFE_WEIGHTS_NAME
|
||||||
|
|
||||||
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
|
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
|
||||||
|
".safetensors", "{suffix}.safetensors"
|
||||||
|
)
|
||||||
state_dict_split = split_torch_state_dict_into_shards(
|
state_dict_split = split_torch_state_dict_into_shards(
|
||||||
state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
|
state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
|
||||||
)
|
)
|
||||||
@@ -207,10 +231,10 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
|
|||||||
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
|
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
|
||||||
|
|
||||||
if (
|
if (
|
||||||
filename.startswith(weights_no_suffix)
|
filename.startswith(weights_no_suffix)
|
||||||
and os.path.isfile(full_filename)
|
and os.path.isfile(full_filename)
|
||||||
and filename not in state_dict_split.filename_to_tensors.keys()
|
and filename not in state_dict_split.filename_to_tensors.keys()
|
||||||
and reg.fullmatch(filename_no_suffix) is not None
|
and reg.fullmatch(filename_no_suffix) is not None
|
||||||
):
|
):
|
||||||
os.remove(full_filename)
|
os.remove(full_filename)
|
||||||
|
|
||||||
@@ -221,7 +245,9 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
|
|||||||
shard[tensor] = state_dict[tensor].contiguous()
|
shard[tensor] = state_dict[tensor].contiguous()
|
||||||
del state_dict[tensor]
|
del state_dict[tensor]
|
||||||
|
|
||||||
save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
|
save_file(
|
||||||
|
shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}
|
||||||
|
)
|
||||||
|
|
||||||
del state_dict
|
del state_dict
|
||||||
|
|
||||||
@@ -236,7 +262,10 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
|
|||||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||||
f.write(content)
|
f.write(content)
|
||||||
|
|
||||||
def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"):
|
|
||||||
|
def convert_llama_to_rrt(
|
||||||
|
model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"
|
||||||
|
):
|
||||||
modules_to_recurse = [
|
modules_to_recurse = [
|
||||||
"self_attn.q_proj",
|
"self_attn.q_proj",
|
||||||
"self_attn.k_proj",
|
"self_attn.k_proj",
|
||||||
@@ -255,7 +284,14 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
|
|||||||
f"divisible by the recurse layers ({recurse_layers})"
|
f"divisible by the recurse layers ({recurse_layers})"
|
||||||
)
|
)
|
||||||
|
|
||||||
config = RelaxedRecursiveLlamaConfig.from_dict({**config.to_dict(), "recurse_layers": recurse_layers, "rank": rank, "alpha": alpha})
|
config = RelaxedRecursiveLlamaConfig.from_dict(
|
||||||
|
{
|
||||||
|
**config.to_dict(),
|
||||||
|
"recurse_layers": recurse_layers,
|
||||||
|
"rank": rank,
|
||||||
|
"alpha": alpha,
|
||||||
|
}
|
||||||
|
)
|
||||||
config.save_pretrained(output_dir)
|
config.save_pretrained(output_dir)
|
||||||
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
|
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
|
||||||
|
|
||||||
@@ -263,13 +299,23 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
|
|||||||
rrt_model_state_dict = {}
|
rrt_model_state_dict = {}
|
||||||
|
|
||||||
logger.info(f"Calculating average recursive weights...")
|
logger.info(f"Calculating average recursive weights...")
|
||||||
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers):
|
for key, weight in iter_recursive_parameter_weights(
|
||||||
|
model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
|
||||||
|
):
|
||||||
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||||
|
|
||||||
logger.info(f"Calculating decomposed lora diff...")
|
logger.info(f"Calculating decomposed lora diff...")
|
||||||
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
|
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
|
||||||
rrt_lora_state_dict = {}
|
rrt_lora_state_dict = {}
|
||||||
for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers):
|
for key, weight in iter_dora_parameter_weights(
|
||||||
|
model_path,
|
||||||
|
rrt_model_state_dict,
|
||||||
|
modules_to_recurse,
|
||||||
|
alpha=32,
|
||||||
|
rank=rank,
|
||||||
|
device=device,
|
||||||
|
recurse_layers=recurse_layers,
|
||||||
|
):
|
||||||
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||||
|
|
||||||
# combine state dicts into a single state_dict
|
# combine state dicts into a single state_dict
|
||||||
@@ -287,4 +333,11 @@ if __name__ == "__main__":
|
|||||||
device = "cuda"
|
device = "cuda"
|
||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512, device=device)
|
convert_llama_to_rrt(
|
||||||
|
"meta-llama/Llama-3.2-1B",
|
||||||
|
"/tmp/rrt_model",
|
||||||
|
recurse_layers=4,
|
||||||
|
rank=256,
|
||||||
|
alpha=512,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user