From 474ba1a1b82791bbc3d90a41051f0b9261fbf48c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 20 Jan 2025 12:20:20 -0500 Subject: [PATCH] chore: lint/formatting --- src/axolotl/integrations/rrt/__init__.py | 11 +- src/axolotl/integrations/rrt/cli/convert.py | 123 ++++++++++++++------ 2 files changed, 96 insertions(+), 38 deletions(-) diff --git a/src/axolotl/integrations/rrt/__init__.py b/src/axolotl/integrations/rrt/__init__.py index 35d4106b7..04372409e 100644 --- a/src/axolotl/integrations/rrt/__init__.py +++ b/src/axolotl/integrations/rrt/__init__.py @@ -8,8 +8,11 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from axolotl.integrations.base import BasePlugin from axolotl.integrations.rrt.modeling import register_rrt_model -from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig, \ - RelaxedRecursiveLlamaModel, RelaxedRecursiveLlamaForCausalLM +from axolotl.integrations.rrt.modeling.modeling_rrt_llama import ( + RelaxedRecursiveLlamaConfig, + RelaxedRecursiveLlamaForCausalLM, + RelaxedRecursiveLlamaModel, +) LOG = logging.getLogger(__name__) @@ -39,4 +42,6 @@ def register_rrt_model(): # Register models AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel) - AutoModelForCausalLM.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM) + AutoModelForCausalLM.register( + RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM + ) diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py index 3418e8d4a..e408dc66d 100644 --- a/src/axolotl/integrations/rrt/cli/convert.py +++ b/src/axolotl/integrations/rrt/cli/convert.py @@ -12,15 +12,18 @@ from huggingface_hub import snapshot_download, split_torch_state_dict_into_shard from safetensors.torch import save_file from tqdm import tqdm from transformers import AutoConfig -from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME -from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig +from axolotl.integrations.rrt.modeling.modeling_rrt_llama import ( + RelaxedRecursiveLlamaConfig, +) logger = logging.getLogger(__name__) + def extract_layer_number(key): """Extract layer number from parameter key.""" - match = re.search(r'layers\.(\d+)\.', key) + match = re.search(r"layers\.(\d+)\.", key) return int(match.group(1)) if match else None @@ -32,18 +35,21 @@ def iter_parameter_weights(model_path, device="mps"): :param device: Computing device :return: generator yielding (parameter key, parameter weight, layer index) tuples """ - shards = list(model_path.glob('model*.safetensors')) + shards = list(model_path.glob("model*.safetensors")) if not shards: raise ValueError(f"No model shards found in {model_path}") for shard in tqdm(shards, desc="Processing shards"): - with safetensors.safe_open(shard, framework='pt', device=device) as f: - for key in f.keys(): - layer_idx = extract_layer_number(key) - weight = f.get_tensor(key) - yield key, weight, layer_idx + with safetensors.safe_open(shard, framework="pt", device=device) as f: + for key in f.keys(): + layer_idx = extract_layer_number(key) + weight = f.get_tensor(key) + yield key, weight, layer_idx -def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12): + +def iter_recursive_parameter_weights( + model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12 +): # setup placeholder state_dict for recursive weights, need to keep in float32 precision # to avoid precision loss when averaging weights across layers rrt_avg_model_state_dict = {} @@ -52,8 +58,7 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], for key, weight, layer_idx in iter_parameter_weights(model_path, device=device): # get the matching module name in modules_to_recurse for the current parameter key matched_module_name = next( - (module for module in modules_to_recurse if module in key), - None + (module for module in modules_to_recurse if module in key), None ) if matched_module_name is None: continue @@ -64,7 +69,9 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], # setup as storage for suffix with torch.stack rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()] else: - rrt_avg_model_state_dict[suffix].append(weight.to(torch.float32).detach().cpu()) + rrt_avg_model_state_dict[suffix].append( + weight.to(torch.float32).detach().cpu() + ) for module_name in modules_to_recurse: for recurse_idx in range(recurse_layers): @@ -75,8 +82,9 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], # compute the decomposed lora diff from the weight base to the actual weight for each module + def low_rank_decomposition( - weight: torch.Tensor, max_rank: int + weight: torch.Tensor, max_rank: int ) -> Tuple[torch.Tensor, torch.Tensor]: """ Decompose a 2D matrix into low-rank matrices L and R using SVD. @@ -86,10 +94,10 @@ def low_rank_decomposition( :return: A tuple of tensors (L, R) """ assert ( - weight.dim() == 2 + weight.dim() == 2 ), f"Only support 2D matrix, but input has {weight.dim()} dimensions." assert ( - max_rank >= 1 + max_rank >= 1 ), f"Maximum rank must be a positive integer, but input max_rank={max_rank}." dtype = weight.dtype @@ -138,22 +146,31 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank): return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu() -def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12): +def iter_dora_parameter_weights( + model_path, + avg_recursive_weights, + modules_to_recurse: list[str], + alpha, + rank, + device="mps", + recurse_layers=12, +): rrt_avg_model_state_dict = {} # iterate over all parameter weights in the model shards for key, weight, layer_idx in iter_parameter_weights(model_path, device=device): # get the matching module name in modules_to_recurse for the current parameter key matched_module_name = next( - (module for module in modules_to_recurse if module in key), - None + (module for module in modules_to_recurse if module in key), None ) if matched_module_name is None: if "input_layernorm" in key: # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx loop_idx = layer_idx // recurse_layers layer_idx = layer_idx % recurse_layers - layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight" + layernorm_key = ( + f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight" + ) yield layernorm_key, weight elif "post_attention_layernorm" in key: # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx @@ -171,19 +188,26 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re suffix = f"{layer_idx}.{matched_module_name}" prefix = f"model.layers.{suffix}.weight_base" avg_weight = avg_recursive_weights[prefix] - lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}" - lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}" - lora_magnitude_key = f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}" - lora_a, lora_b, lora_magnitude = decompose_delta_weight(weight, avg_weight, alpha, rank) + lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}" + lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}" + lora_magnitude_key = ( + f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}" + ) + lora_a, lora_b, lora_magnitude = decompose_delta_weight( + weight, avg_weight, alpha, rank + ) yield lora_a_key, lora_a yield lora_b_key, lora_b yield lora_magnitude_key, lora_magnitude + def save_state_dict_to_safetensors(state_dict, save_directory): os.makedirs(save_directory, exist_ok=True) weights_name = SAFE_WEIGHTS_NAME - filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors") + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) state_dict_split = split_torch_state_dict_into_shards( state_dict, filename_pattern=filename_pattern, max_shard_size="1GB" ) @@ -207,10 +231,10 @@ def save_state_dict_to_safetensors(state_dict, save_directory): reg = re.compile(r"(.*?)-\d{5}-of-\d{5}") if ( - filename.startswith(weights_no_suffix) - and os.path.isfile(full_filename) - and filename not in state_dict_split.filename_to_tensors.keys() - and reg.fullmatch(filename_no_suffix) is not None + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in state_dict_split.filename_to_tensors.keys() + and reg.fullmatch(filename_no_suffix) is not None ): os.remove(full_filename) @@ -221,7 +245,9 @@ def save_state_dict_to_safetensors(state_dict, save_directory): shard[tensor] = state_dict[tensor].contiguous() del state_dict[tensor] - save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}) + save_file( + shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"} + ) del state_dict @@ -236,7 +262,10 @@ def save_state_dict_to_safetensors(state_dict, save_directory): content = json.dumps(index, indent=2, sort_keys=True) + "\n" f.write(content) -def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"): + +def convert_llama_to_rrt( + model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps" +): modules_to_recurse = [ "self_attn.q_proj", "self_attn.k_proj", @@ -255,7 +284,14 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank= f"divisible by the recurse layers ({recurse_layers})" ) - config = RelaxedRecursiveLlamaConfig.from_dict({**config.to_dict(), "recurse_layers": recurse_layers, "rank": rank, "alpha": alpha}) + config = RelaxedRecursiveLlamaConfig.from_dict( + { + **config.to_dict(), + "recurse_layers": recurse_layers, + "rank": rank, + "alpha": alpha, + } + ) config.save_pretrained(output_dir) model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth")) @@ -263,13 +299,23 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank= rrt_model_state_dict = {} logger.info(f"Calculating average recursive weights...") - for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers): + for key, weight in iter_recursive_parameter_weights( + model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers + ): rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu() logger.info(f"Calculating decomposed lora diff...") # now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff rrt_lora_state_dict = {} - for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers): + for key, weight in iter_dora_parameter_weights( + model_path, + rrt_model_state_dict, + modules_to_recurse, + alpha=32, + rank=rank, + device=device, + recurse_layers=recurse_layers, + ): rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu() # combine state dicts into a single state_dict @@ -287,4 +333,11 @@ if __name__ == "__main__": device = "cuda" else: device = "cpu" - convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512, device=device) + convert_llama_to_rrt( + "meta-llama/Llama-3.2-1B", + "/tmp/rrt_model", + recurse_layers=4, + rank=256, + alpha=512, + device=device, + )