Train parameters exclusively in specific ranges (#1390)
* Train parameters exclusively in specific ranges * Fix the style and update docs * Update yaml example
This commit is contained in:
285
tests/test_freeze.py
Normal file
285
tests/test_freeze.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
This module contains unit tests for the `freeze_layers_except` function.
|
||||
|
||||
The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.
|
||||
The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
|
||||
ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
||||
|
||||
|
||||
class TestFreezeLayersExcept(unittest.TestCase):
|
||||
"""
|
||||
A test case class for the `freeze_layers_except` function.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.model = _TestModel()
|
||||
|
||||
def test_freeze_layers_with_dots_in_name(self):
|
||||
freeze_layers_except(self.model, ["features.layer"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_freeze_layers_without_dots_in_name(self):
|
||||
freeze_layers_except(self.model, ["classifier"])
|
||||
self.assertFalse(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertTrue(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_freeze_layers_regex_patterns(self):
|
||||
# The second pattern cannot match because only characters 'a' to 'c' are allowed after the word 'class', whereas it should be matching the character 'i'.
|
||||
freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_all_layers_frozen(self):
|
||||
freeze_layers_except(self.model, [])
|
||||
self.assertFalse(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be frozen.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_all_layers_unfrozen(self):
|
||||
freeze_layers_except(self.model, ["features.layer", "classifier"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertTrue(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be trainable.",
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_start_end(self):
|
||||
freeze_layers_except(self.model, ["features.layer[1:5]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_single_index(self):
|
||||
freeze_layers_except(self.model, ["features.layer[5]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_start_omitted(self):
|
||||
freeze_layers_except(self.model, ["features.layer[:5]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_end_omitted(self):
|
||||
freeze_layers_except(self.model, ["features.layer[4:]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_merge_included(self):
|
||||
freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_merge_intersect(self):
|
||||
freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_merge_separate(self):
|
||||
freeze_layers_except(
|
||||
self.model,
|
||||
["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"],
|
||||
)
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def _assert_gradient_output(self, expected):
|
||||
input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)
|
||||
|
||||
self.model.features.layer.weight.grad = None # Reset gradients
|
||||
output = self.model.features.layer(input_tensor)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
expected_grads = torch.tensor(expected)
|
||||
torch.testing.assert_close(
|
||||
self.model.features.layer.weight.grad, expected_grads
|
||||
)
|
||||
|
||||
|
||||
class _SubLayerModule(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layer = nn.Linear(10, 10)
|
||||
|
||||
|
||||
class _TestModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.features = _SubLayerModule()
|
||||
self.classifier = nn.Linear(10, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user