@@ -18,12 +18,14 @@ def close(x, y):
 class TestMAMLAlgorithm(unittest.TestCase):

     def setUp(self):
-        self.model = torch.nn.Sequential(torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
-                                         torch.nn.ReLU(),
-                                         torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
-                                         torch.nn.Sigmoid(),
-                                         torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
-                                         torch.nn.Softmax())
+        self.model = torch.nn.Sequential(
+            torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
+            torch.nn.ReLU(),
+            torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
+            torch.nn.Sigmoid(),
+            torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
+            torch.nn.Softmax(),
+        )

         self.model.register_buffer('dummy_buf', torch.zeros(1, 2, 3, 4))

@@ -101,7 +103,11 @@ def test_allow_nograd(self):
         try:
             # Check that without allow_nograd, adaptation fails
             clone.adapt(loss)
-            self.assertTrue(False, 'adaptation successful despite requires_grad=False')  # Check that execution never gets here
+            # Check that execution never gets here
+            self.assertTrue(
+                False,
+                'adaptation successful despite requires_grad=False',
+            )
         except:
             # Check that with allow_nograd, adaptation succeeds
             clone.adapt(loss, allow_nograd=True)
@@ -112,17 +118,50 @@ def test_allow_nograd(self):
             if p.requires_grad:
                 self.assertTrue(p.grad is not None)

-        maml = l2l.algorithms.MAML(self.model,
-                                   lr=INNER_LR,
-                                   first_order=False,
-                                   allow_nograd=True)
+        maml = l2l.algorithms.MAML(
+            self.model,
+            lr=INNER_LR,
+            first_order=False,
+            allow_nograd=True,
+        )
         clone = maml.clone()
         loss = sum([p.norm(p=2) for p in clone.parameters()])
         # Check that without allow_nograd, adaptation succeeds thanks to init.
         orig_weight = self.model[2].weight.clone().detach()
         clone.adapt(loss)
         self.assertTrue(close(orig_weight, self.model[2].weight))

+    def test_module_shared_params(self):
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                cnn = [
+                    torch.nn.Conv2d(3, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                ]
+                self.seq = torch.nn.Sequential(*cnn)
+                self.head = torch.nn.Sequential(
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 100, 3, 2, 1),
+                )
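+                # `seq` and `head` are reused inside `net` below, so the same
+                # submodules (and their parameters) are reachable through
+                # multiple attributes of the module.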
+                self.net = torch.nn.Sequential(self.seq, self.head)
+
+            def forward(self, x):
+                return self.net(x)
+
+        module = TestModule()
+        maml = l2l.algorithms.MAML(module, lr=0.1)
+        clone = maml.clone()
+        loss = sum(p.norm(p=2) for p in clone.parameters())
+        clone.adapt(loss)
+        loss = sum(p.norm(p=2) for p in clone.parameters())
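+        # Backward through the adapted clone should succeed even though
+        # submodules are shared.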
+        loss.backward()


 if __name__ == '__main__':