
Commit 6a32de2

Fix clone_module when submodules share parameters. (#176)
* Fix clone_module with shared parameters.
* Add _notravis for benchmarks too.
* Update CHANGELOG.
1 parent 69558e0 commit 6a32de2

5 files changed

Lines changed: 123 additions & 17 deletions
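
For context, here is what "submodules share parameters" means in practice. The sketch below is illustrative only (the `Shared` and `encoder` names are made up for this note, not taken from the commit): the same `nn.Linear` is registered both as a direct child and inside a wrapping `nn.Sequential`, so one weight tensor is reachable through two module paths. The new tests in this commit build a convolutional version of the same pattern.

import torch

class Shared(torch.nn.Module):
    # Toy stand-in (hypothetical) for the TestModule used in the new tests:
    # `encoder` is registered directly and again inside `net`, so its weight
    # is reachable both as `encoder.weight` and as `net[0].weight`.
    def __init__(self):
        super(Shared, self).__init__()
        self.encoder = torch.nn.Linear(4, 4)
        self.net = torch.nn.Sequential(self.encoder, torch.nn.Linear(4, 2))

    def forward(self, x):
        return self.net(x)

model = Shared()
assert model.encoder.weight is model.net[0].weight  # one Parameter, two paths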

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+* Fix `clone_module` for Modules whose submodules share parameters.
+
 
 ## v0.1.2
 

learn2learn/utils.py

Lines changed: 27 additions & 5 deletions
@@ -48,7 +48,7 @@ def clone_parameters(param_list):
     return [p.clone() for p in param_list]
 
 
-def clone_module(module):
+def clone_module(module, memo=None):
     """
 
     [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/utils.py)
@@ -91,6 +91,12 @@ def clone_module(module):
     # clone = recursive_shallow_copy(model)
     # clone._apply(lambda t: t.clone())
 
+    if memo is None:
+        # Maps original data_ptr to the cloned tensor.
+        # Useful when a Module uses parameters from another Module; see:
+        # https://github.com/learnables/learn2learn/issues/174
+        memo = {}
+
     # First, create a copy of the module.
     # Adapted from:
     # https://github.com/pytorch/pytorch/blob/65bad41cbec096aa767b3752843eddebf845726f/torch/nn/modules/module.py#L1171
@@ -106,20 +112,36 @@ def clone_module(module):
     if hasattr(clone, '_parameters'):
         for param_key in module._parameters:
             if module._parameters[param_key] is not None:
-                cloned = module._parameters[param_key].clone()
-                clone._parameters[param_key] = cloned
+                param = module._parameters[param_key]
+                param_ptr = param.data_ptr
+                if param_ptr in memo:
+                    clone._parameters[param_key] = memo[param_ptr]
+                else:
+                    cloned = param.clone()
+                    clone._parameters[param_key] = cloned
+                    memo[param_ptr] = cloned
 
     # Third, handle the buffers if necessary
     if hasattr(clone, '_buffers'):
         for buffer_key in module._buffers:
             if clone._buffers[buffer_key] is not None and \
                     clone._buffers[buffer_key].requires_grad:
-                clone._buffers[buffer_key] = module._buffers[buffer_key].clone()
+                buff = module._buffers[buffer_key]
+                buff_ptr = buff.data_ptr
+                if buff_ptr in memo:
+                    clone._buffers[buffer_key] = memo[buff_ptr]
+                else:
+                    cloned = buff.clone()
+                    clone._buffers[buffer_key] = cloned
+                    memo[buff_ptr] = cloned
 
     # Then, recurse for each submodule
     if hasattr(clone, '_modules'):
         for module_key in clone._modules:
-            clone._modules[module_key] = clone_module(module._modules[module_key])
+            clone._modules[module_key] = clone_module(
+                module._modules[module_key],
+                memo=memo,
+            )
 
     # Finally, rebuild the flattened parameters for RNNs
     # See this issue for more details:
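
A minimal sketch of what the memo buys (illustration only, not part of the diff; the toy module is hypothetical, and `clone_module` is used from the top-level package exactly as the new unit test in tests/unit/utils_test.py does): cloning a module that registers the same `nn.Linear` under two paths should keep both paths pointing at a single cloned tensor, while that tensor gets storage separate from the original's.

import torch
import learn2learn as l2l

# Hypothetical example: `encoder` is registered under two module paths.
encoder = torch.nn.Linear(4, 4)
model = torch.nn.Sequential(encoder, torch.nn.Sequential(encoder))

clone = l2l.clone_module(model)

# Sharing is preserved inside the clone: both paths resolve to one cloned tensor.
assert clone[0].weight is clone[1][0].weight

# The cloned tensor does not alias the original's storage.
assert clone[0].weight.data_ptr() != model[0].weight.data_ptr()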

tests/unit/algorithms/maml_test.py

Lines changed: 50 additions & 11 deletions
@@ -18,12 +18,14 @@ def close(x, y):
 class TestMAMLAlgorithm(unittest.TestCase):
 
     def setUp(self):
-        self.model = torch.nn.Sequential(torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
-                                         torch.nn.ReLU(),
-                                         torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
-                                         torch.nn.Sigmoid(),
-                                         torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
-                                         torch.nn.Softmax())
+        self.model = torch.nn.Sequential(
+            torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
+            torch.nn.ReLU(),
+            torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
+            torch.nn.Sigmoid(),
+            torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
+            torch.nn.Softmax(),
+        )
 
         self.model.register_buffer('dummy_buf', torch.zeros(1, 2, 3, 4))
 
@@ -101,7 +103,11 @@ def test_allow_nograd(self):
         try:
             # Check that without allow_nograd, adaptation fails
             clone.adapt(loss)
-            self.assertTrue(False, 'adaptation successful despite requires_grad=False')  # Check that execution never gets here
+            # Check that execution never gets here
+            self.assertTrue(
+                False,
+                'adaptation successful despite requires_grad=False',
+            )
         except:
             # Check that with allow_nograd, adaptation succeeds
             clone.adapt(loss, allow_nograd=True)
@@ -112,17 +118,50 @@ def test_allow_nograd(self):
             if p.requires_grad:
                 self.assertTrue(p.grad is not None)
 
-        maml = l2l.algorithms.MAML(self.model,
-                                   lr=INNER_LR,
-                                   first_order=False,
-                                   allow_nograd=True)
+        maml = l2l.algorithms.MAML(
+            self.model,
+            lr=INNER_LR,
+            first_order=False,
+            allow_nograd=True,
+        )
         clone = maml.clone()
         loss = sum([p.norm(p=2) for p in clone.parameters()])
         # Check that without allow_nograd, adaptation succeeds thanks to init.
         orig_weight = self.model[2].weight.clone().detach()
         clone.adapt(loss)
         self.assertTrue(close(orig_weight, self.model[2].weight))
 
+    def test_module_shared_params(self):
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                cnn = [
+                    torch.nn.Conv2d(3, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                ]
+                self.seq = torch.nn.Sequential(*cnn)
+                self.head = torch.nn.Sequential(*[
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 100, 3, 2, 1)]
+                )
+                self.net = torch.nn.Sequential(self.seq, self.head)
+
+            def forward(self, x):
+                return self.net(x)
+
+        module = TestModule()
+        maml = l2l.algorithms.MAML(module, lr=0.1)
+        clone = maml.clone()
+        loss = sum(p.norm(p=2) for p in clone.parameters())
+        clone.adapt(loss)
+        loss = sum(p.norm(p=2) for p in clone.parameters())
+        loss.backward()
 
 
 if __name__ == '__main__':
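
The same property can be checked through MAML directly. The snippet below is an illustration rather than part of the commit; the toy module is hypothetical, and it assumes the wrapped network is exposed as `clone.module` (the attribute the MAML learner stores its model under). With the fix, a weight registered under two paths stays a single tensor after `clone()`, so `adapt` updates one set of weights instead of two diverging copies.

import torch
import learn2learn as l2l

# Hypothetical check: a Conv2d registered under two module paths, wrapped in MAML.
encoder = torch.nn.Conv2d(3, 32, 3, 2, 1)
net = torch.nn.Sequential(encoder, torch.nn.Sequential(encoder))

maml = l2l.algorithms.MAML(net, lr=0.1)
clone = maml.clone()

# After cloning, both paths should still point at the same underlying storage.
assert clone.module[0].weight.data_ptr() == clone.module[1][0].weight.data_ptr()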

tests/unit/utils_test.py

Lines changed: 44 additions & 1 deletion
@@ -13,6 +13,11 @@ def ref_clone_module(module):
     each forward call.
     See this issue for more details:
     https://github.com/learnables/learn2learn/issues/139
+
+    Note: This implementation also does not work for Modules that re-use
+    parameters from another Module.
+    See this issue for more details:
+    https://github.com/learnables/learn2learn/issues/174
     """
     # First, create a copy of the module.
     clone = copy.deepcopy(module)
@@ -191,10 +196,48 @@ def test_rnn_clone(self):
         # Ensure we did better
         self.assertTrue(first_loss > second_loss)
 
+    def test_module_clone_shared_params(self):
+        # Tests proper use of memo parameter
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                cnn = [
+                    torch.nn.Conv2d(3, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                ]
+                self.seq = torch.nn.Sequential(*cnn)
+                self.head = torch.nn.Sequential(*[
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 100, 3, 2, 1)]
+                )
+                self.net = torch.nn.Sequential(self.seq, self.head)
+
+            def forward(self, x):
+                return self.net(x)
+
+        original = TestModule()
+        clone = l2l.clone_module(original)
+        self.assertTrue(
+            len(list(clone.parameters())) == len(list(original.parameters())),
+            'clone and original do not have same number of parameters.',
+        )
+
+        orig_params = [p.data_ptr() for p in original.parameters()]
+        duplicates = [p.data_ptr() in orig_params for p in clone.parameters()]
+        self.assertTrue(not any(duplicates), 'clone() forgot some parameters.')
 
     def test_module_detach(self):
         original_output = self.model(self.input)
-        original_loss = self.loss_func(original_output, torch.tensor([[0., 0.]]))
+        original_loss = self.loss_func(
+            original_output,
+            torch.tensor([[0., 0.]])
+        )
 
         original_gradients = torch.autograd.grad(original_loss,
                                                  self.model.parameters(),
File renamed without changes.
