Train parameters exclusively in specific ranges (#1390)
* Train parameters exclusively in specific ranges * Fix the style and update docs * Update yaml example
This commit is contained in:
@@ -16,12 +16,12 @@ output_dir: ./qlora-out
|
|||||||
|
|
||||||
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||||
unfrozen_parameters:
|
unfrozen_parameters:
|
||||||
# - lm_head.*
|
# - ^lm_head.weight$
|
||||||
# - model.embed_tokens.*
|
# - ^model.embed_tokens.weight$[:32000]
|
||||||
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
|
# - model.layers.2[0-9]+.block_sparse_moe.gate
|
||||||
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
|
# - model.layers.2[0-9]+.block_sparse_moe.experts
|
||||||
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
# - model.layers.3[0-9]+.block_sparse_moe.gate
|
||||||
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
# - model.layers.3[0-9]+.block_sparse_moe.experts
|
||||||
|
|
||||||
model_config:
|
model_config:
|
||||||
output_router_logits: true
|
output_router_logits: true
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
|||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.freeze import freeze_parameters_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
@@ -99,7 +99,7 @@ def train(
|
|||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
if cfg.unfrozen_parameters:
|
if cfg.unfrozen_parameters:
|
||||||
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||||
|
|
||||||
trainer = setup_trainer(
|
trainer = setup_trainer(
|
||||||
cfg,
|
cfg,
|
||||||
|
|||||||
@@ -3,13 +3,14 @@ module to freeze/unfreeze parameters by name
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from typing import Callable, List, Tuple
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.freeze")
|
LOG = logging.getLogger("axolotl.utils.freeze")
|
||||||
|
|
||||||
|
|
||||||
def freeze_parameters_except(model, regex_patterns):
|
def freeze_layers_except(model, regex_patterns):
|
||||||
"""
|
"""
|
||||||
Freezes all layers of the given model except for the layers that match given regex patterns.
|
Freezes all layers of the given model except for the layers that match given regex patterns.
|
||||||
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
||||||
@@ -17,22 +18,209 @@ def freeze_parameters_except(model, regex_patterns):
|
|||||||
Parameters:
|
Parameters:
|
||||||
- model (nn.Module): The PyTorch model to be modified.
|
- model (nn.Module): The PyTorch model to be modified.
|
||||||
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
||||||
|
Note that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.
|
||||||
|
Also, to match the entire layer name, the pattern should start with "^" and end with "$", otherwise it will match any part of the layer name.
|
||||||
|
The range pattern part is optional and it is not compiled as a regex pattern which means you must put "$" before the range pattern if you want to match the entire layer name.
|
||||||
|
E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None; the model is modified in place.
|
None; the model is modified in place.
|
||||||
"""
|
"""
|
||||||
# Escape periods and compile the regex patterns
|
if isinstance(regex_patterns, str):
|
||||||
compiled_patterns = [
|
regex_patterns = [regex_patterns]
|
||||||
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
|
|
||||||
]
|
|
||||||
|
|
||||||
# First, freeze all parameters in the model
|
patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]
|
||||||
for param in model.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
# Unfreeze layers that match the regex patterns
|
# Unfreeze layers that match the regex patterns
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if any(pattern.match(name) for pattern in compiled_patterns):
|
param.requires_grad = False
|
||||||
if is_main_process():
|
unfrozen_ranges = []
|
||||||
LOG.debug(f"unfreezing {name}")
|
for pattern in patterns:
|
||||||
|
if not pattern.match(name):
|
||||||
|
continue
|
||||||
|
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
|
|
||||||
|
if pattern.range is not None:
|
||||||
|
unfrozen_ranges.append(pattern.range)
|
||||||
|
|
||||||
|
merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))
|
||||||
|
|
||||||
|
if param.requires_grad and is_main_process():
|
||||||
|
unfrozen_ranges = (
|
||||||
|
f" with ranges {merged_unfrozen_ranges}"
|
||||||
|
if merged_unfrozen_ranges
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
LOG.debug(f"Unfrozen {name}{unfrozen_ranges}")
|
||||||
|
|
||||||
|
if not merged_unfrozen_ranges:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# The range list we need is actually the inverted of the merged ranges
|
||||||
|
ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))
|
||||||
|
|
||||||
|
param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))
|
||||||
|
|
||||||
|
if is_main_process() and all(
|
||||||
|
not param.requires_grad for param in model.parameters()
|
||||||
|
):
|
||||||
|
LOG.warning("All parameters are frozen. Model will not be trained.")
|
||||||
|
|
||||||
|
|
||||||
|
def _invert_ranges(
|
||||||
|
given_ranges: List[Tuple[int, int]], layer_size: int
|
||||||
|
) -> List[Tuple[int, int]]:
|
||||||
|
"""
|
||||||
|
Inverts a list of ranges to obtain the ranges not covered by the given ranges.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
|
||||||
|
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
|
||||||
|
Returns:
|
||||||
|
- List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
|
||||||
|
"""
|
||||||
|
if not given_ranges:
|
||||||
|
return [(0, layer_size)]
|
||||||
|
|
||||||
|
inverted_ranges = []
|
||||||
|
current_start = 0
|
||||||
|
|
||||||
|
for start, end in sorted(given_ranges):
|
||||||
|
if start > current_start:
|
||||||
|
inverted_ranges.append((current_start, start))
|
||||||
|
current_start = max(current_start, end)
|
||||||
|
|
||||||
|
# Handle the case where the last given range does not reach the end of the total_size
|
||||||
|
if current_start < layer_size:
|
||||||
|
inverted_ranges.append((current_start, layer_size))
|
||||||
|
|
||||||
|
return inverted_ranges
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_ranges(
|
||||||
|
given_ranges: List[Tuple[int, int | None]], layer_size: int
|
||||||
|
) -> List[Tuple[int, int]]:
|
||||||
|
"""
|
||||||
|
Merges overlapping ranges and sorts the given ranges.
|
||||||
|
|
||||||
|
This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
|
||||||
|
as tuples, where the first element is the start index (inclusive) and the second element is the end
|
||||||
|
index (exclusive). The end index can be None, indicating that the range extends to the end of the
|
||||||
|
sequence.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
|
||||||
|
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
|
||||||
|
"""
|
||||||
|
# End of each range can be determined now since we have the total size
|
||||||
|
processed_ranges = [
|
||||||
|
(start, end if end is not None else layer_size) for start, end in given_ranges
|
||||||
|
]
|
||||||
|
|
||||||
|
# No need to merge if there's only one or no ranges
|
||||||
|
if len(processed_ranges) <= 1:
|
||||||
|
return processed_ranges
|
||||||
|
|
||||||
|
sorted_ranges = sorted(processed_ranges)
|
||||||
|
|
||||||
|
merged_ranges = [sorted_ranges[0]]
|
||||||
|
for start, end in sorted_ranges[1:]:
|
||||||
|
prev_start, prev_end = merged_ranges[-1]
|
||||||
|
if start <= prev_end:
|
||||||
|
merged_ranges[-1] = (prev_start, max(prev_end, end))
|
||||||
|
else:
|
||||||
|
merged_ranges.append((start, end))
|
||||||
|
|
||||||
|
return merged_ranges
|
||||||
|
|
||||||
|
|
||||||
|
def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
|
||||||
|
"""
|
||||||
|
Create a hook to freeze parameters in specified ranges by setting their gradients to zero.
|
||||||
|
|
||||||
|
This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
|
||||||
|
two integers representing the start and end indices of the range.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- Callable: A hook function to be used with `register_hook` on parameters.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```
|
||||||
|
ranges_to_freeze = [(0, 10), (20, 30)]
|
||||||
|
hook = _create_freeze_parameters_hook(ranges_to_freeze)
|
||||||
|
model.register_hook(hook)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def freeze_parameters_hook(gradients):
|
||||||
|
for start, end in ranges_to_freeze:
|
||||||
|
gradients[start:end].zero_()
|
||||||
|
|
||||||
|
return freeze_parameters_hook
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNamePattern:
|
||||||
|
"""
|
||||||
|
Represents a regex pattern for layer names, potentially including a parameter index range.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pattern: str):
|
||||||
|
"""
|
||||||
|
Initializes a new instance of the LayerNamePattern class.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- pattern (str): The regex pattern for layer names, potentially including a parameter index range.
|
||||||
|
"""
|
||||||
|
self.raw_pattern = pattern
|
||||||
|
name_pattern, self.range = self._parse_pattern(pattern)
|
||||||
|
self.name_regex = re.compile(name_pattern.replace(".", "\\."))
|
||||||
|
|
||||||
|
def match(self, name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if the given layer name matches the regex pattern.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- name (str): The layer name to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- bool: True if the layer name matches the pattern, False otherwise.
|
||||||
|
"""
|
||||||
|
return self.name_regex.match(name) is not None
|
||||||
|
|
||||||
|
def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
|
||||||
|
"""
|
||||||
|
Extracts the range pattern from the given pattern.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- pattern (str): The pattern to extract the range from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
|
||||||
|
"""
|
||||||
|
match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
|
||||||
|
if not match:
|
||||||
|
return pattern, None
|
||||||
|
|
||||||
|
base_pattern, start_part, end_part = match.groups()
|
||||||
|
|
||||||
|
if end_part is None and start_part.isdecimal():
|
||||||
|
index = int(start_part)
|
||||||
|
return base_pattern, (index, index + 1)
|
||||||
|
|
||||||
|
# [:end] or [start:] or [start:end]
|
||||||
|
start = int(start_part) if start_part else 0
|
||||||
|
end = int(end_part) if end_part else None
|
||||||
|
|
||||||
|
if end is not None and start >= end:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid range in layer name pattern: {pattern}."
|
||||||
|
"End of range must be greater than start."
|
||||||
|
)
|
||||||
|
return base_pattern, (start, end)
|
||||||
|
|||||||
285
tests/test_freeze.py
Normal file
285
tests/test_freeze.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
"""
|
||||||
|
This module contains unit tests for the `freeze_layers_except` function.
|
||||||
|
|
||||||
|
The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.
|
||||||
|
The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
|
|
||||||
|
ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
||||||
|
|
||||||
|
|
||||||
|
class TestFreezeLayersExcept(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
A test case class for the `freeze_layers_except` function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model = _TestModel()
|
||||||
|
|
||||||
|
def test_freeze_layers_with_dots_in_name(self):
|
||||||
|
freeze_layers_except(self.model, ["features.layer"])
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be trainable.",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be frozen.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_freeze_layers_without_dots_in_name(self):
|
||||||
|
freeze_layers_except(self.model, ["classifier"])
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be trainable.",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be frozen.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_freeze_layers_regex_patterns(self):
|
||||||
|
# The second pattern cannot match because only characters 'a' to 'c' are allowed after the word 'class', whereas it should be matching the character 'i'.
|
||||||
|
freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"])
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be trainable.",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be frozen.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_all_layers_frozen(self):
|
||||||
|
freeze_layers_except(self.model, [])
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be frozen.",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be frozen.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_all_layers_unfrozen(self):
|
||||||
|
freeze_layers_except(self.model, ["features.layer", "classifier"])
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be trainable.",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be trainable.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_freeze_layers_with_range_pattern_start_end(self):
|
||||||
|
freeze_layers_except(self.model, ["features.layer[1:5]"])
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be trainable.",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be frozen.",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._assert_gradient_output(
|
||||||
|
[
|
||||||
|
ZERO,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_freeze_layers_with_range_pattern_single_index(self):
|
||||||
|
freeze_layers_except(self.model, ["features.layer[5]"])
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be trainable.",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be frozen.",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._assert_gradient_output(
|
||||||
|
[ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_freeze_layers_with_range_pattern_start_omitted(self):
|
||||||
|
freeze_layers_except(self.model, ["features.layer[:5]"])
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be trainable.",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be frozen.",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._assert_gradient_output(
|
||||||
|
[
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_freeze_layers_with_range_pattern_end_omitted(self):
|
||||||
|
freeze_layers_except(self.model, ["features.layer[4:]"])
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be trainable.",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be frozen.",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._assert_gradient_output(
|
||||||
|
[
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_freeze_layers_with_range_pattern_merge_included(self):
|
||||||
|
freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"])
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be trainable.",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be frozen.",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._assert_gradient_output(
|
||||||
|
[
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_freeze_layers_with_range_pattern_merge_intersect(self):
|
||||||
|
freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"])
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be trainable.",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be frozen.",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._assert_gradient_output(
|
||||||
|
[
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_freeze_layers_with_range_pattern_merge_separate(self):
|
||||||
|
freeze_layers_except(
|
||||||
|
self.model,
|
||||||
|
["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"],
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
self.model.features.layer.weight.requires_grad,
|
||||||
|
"model.features.layer should be trainable.",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
self.model.classifier.weight.requires_grad,
|
||||||
|
"model.classifier should be frozen.",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._assert_gradient_output(
|
||||||
|
[
|
||||||
|
ZERO,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ZERO,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ZERO,
|
||||||
|
ONE_TO_TEN,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
ZERO,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _assert_gradient_output(self, expected):
|
||||||
|
input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)
|
||||||
|
|
||||||
|
self.model.features.layer.weight.grad = None # Reset gradients
|
||||||
|
output = self.model.features.layer(input_tensor)
|
||||||
|
loss = output.sum()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
expected_grads = torch.tensor(expected)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
self.model.features.layer.weight.grad, expected_grads
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _SubLayerModule(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.layer = nn.Linear(10, 10)
|
||||||
|
|
||||||
|
|
||||||
|
class _TestModel(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.features = _SubLayerModule()
|
||||||
|
self.classifier = nn.Linear(10, 2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user