@@ -33,14 +33,17 @@ def test_jac_is_stack_of_grads():
3333 y2 = a2 * x
3434 input = Gradients ({y1 : torch .ones_like (y1 ), y2 : torch .ones_like (y2 )})
3535
36- jac = Jac (outputs = [y1 , y2 ], inputs = [a1 , a2 ], chunk_size = None , retain_graph = True ) << Diagonalize (
37- [y1 , y2 ]
38- )
39- grad1 = Grad (outputs = [y1 ], inputs = [a1 , a2 ]) << Select ([y1 ], [y1 , y2 ])
40- grad2 = Grad (outputs = [y2 ], inputs = [a1 , a2 ]) << Select ([y2 ], [y1 , y2 ])
41- stack_of_grads = Stack ([grad1 , grad2 ])
36+ jac = Jac (outputs = [y1 , y2 ], inputs = [a1 , a2 ], chunk_size = None , retain_graph = True )
37+ diag = Diagonalize ([y1 , y2 ])
38+ jac_diag = jac << diag
39+
40+ grad1 = Grad (outputs = [y1 ], inputs = [a1 , a2 ])
41+ grad2 = Grad (outputs = [y2 ], inputs = [a1 , a2 ])
42+ select1 = Select ([y1 ], [y1 , y2 ])
43+ select2 = Select ([y2 ], [y1 , y2 ])
44+ stack_of_grads = Stack ([grad1 << select1 , grad2 << select2 ])
4245
43- jacobians = jac (input )
46+ jacobians = jac_diag (input )
4447 expected_jacobians = stack_of_grads (input )
4548
4649 assert_tensor_dicts_are_close (jacobians , expected_jacobians )
@@ -66,7 +69,7 @@ def test_single_differentiation():
6669 assert_tensor_dicts_are_close (output , expected_output )
6770
6871
69- def test_multiple_differentiation_with_grad ():
72+ def test_multiple_differentiations ():
7073 """
7174 Tests that we can perform multiple scalar differentiations with the conjunction of multiple Grad
7275 transforms, composed with an Init transform.
@@ -78,10 +81,12 @@ def test_multiple_differentiation_with_grad():
7881 y2 = a2 * 3.0
7982 input = EmptyTensorDict ()
8083
81- grad1 = Grad ([y1 ], [a1 ]) << Select ([y1 ], [y1 , y2 ])
82- grad2 = Grad ([y2 ], [a2 ]) << Select ([y2 ], [y1 , y2 ])
84+ grad1 = Grad ([y1 ], [a1 ])
85+ grad2 = Grad ([y2 ], [a2 ])
86+ select1 = Select ([y1 ], [y1 , y2 ])
87+ select2 = Select ([y2 ], [y1 , y2 ])
8388 init = Init ([y1 , y2 ])
84- transform = (grad1 | grad2 ) << init
89+ transform = (( grad1 << select1 ) | ( grad2 << select2 ) ) << init
8590
8691 output = transform (input )
8792 expected_output = {
@@ -182,7 +187,8 @@ def test_conjunction_accumulate_select():
182187 """
183188 Tests that it is possible to conjunct an Accumulate and a Select in this order.
184189 It is not trivial since the type of the TensorDict returned by the first transform (Accumulate)
185- is EmptyDict, which is not the type that the conjunction should return (Gradients).
190+ is EmptyDict, which is not the type that the conjunction should return (Gradients), but a
191+ subclass of it.
186192 """
187193
188194 key = torch .tensor ([1.0 , 2.0 , 3.0 ], requires_grad = True , device = DEVICE )
@@ -199,7 +205,7 @@ def test_conjunction_accumulate_select():
199205 assert_tensor_dicts_are_close (output , expected_output )
200206
201207
202- def test_equivalence_jac_grad ():
208+ def test_equivalence_jac_grads ():
203209 """
204210 Tests that differentiation in parallel using `_jac` is equivalent to sequential differentiation
205211 using several calls to `_grad` and stacking the resulting gradients.
@@ -219,18 +225,12 @@ def test_equivalence_jac_grad():
219225 outputs = [y1 , y2 ]
220226 grad_outputs = [torch .ones_like (output ) for output in outputs ]
221227
222- grad_dict_1 = Grad (
223- outputs = [outputs [0 ]],
224- inputs = inputs ,
225- retain_graph = True ,
226- )(Gradients ({outputs [0 ]: grad_outputs [0 ]}))
228+ grad1 = Grad (outputs = [outputs [0 ]], inputs = inputs , retain_graph = True )
229+ grad_dict_1 = grad1 (Gradients ({outputs [0 ]: grad_outputs [0 ]}))
227230 grad_1_A , grad_1_b , grad_1_c = grad_dict_1 [A ], grad_dict_1 [b ], grad_dict_1 [c ]
228231
229- grad_dict_2 = Grad (
230- outputs = [outputs [1 ]],
231- inputs = inputs ,
232- retain_graph = True ,
233- )(Gradients ({outputs [1 ]: grad_outputs [1 ]}))
232+ grad2 = Grad (outputs = [outputs [1 ]], inputs = inputs , retain_graph = True )
233+ grad_dict_2 = grad2 (Gradients ({outputs [1 ]: grad_outputs [1 ]}))
234234 grad_2_A , grad_2_b , grad_2_c = grad_dict_2 [A ], grad_dict_2 [b ], grad_dict_2 [c ]
235235
236236 n_outputs = len (outputs )
@@ -240,11 +240,10 @@ def test_equivalence_jac_grad():
240240 for i , grad_output in enumerate (grad_outputs ):
241241 batched_grad_outputs [i ][i ] = grad_output
242242
243- jac_dict = Jac (
244- outputs = outputs ,
245- inputs = inputs ,
246- chunk_size = None ,
247- )(Jacobians ({outputs [0 ]: batched_grad_outputs [0 ], outputs [1 ]: batched_grad_outputs [1 ]}))
243+ jac = Jac (outputs = outputs , inputs = inputs , chunk_size = None )
244+ jac_dict = jac (
245+ Jacobians ({outputs [0 ]: batched_grad_outputs [0 ], outputs [1 ]: batched_grad_outputs [1 ]})
246+ )
248247 jac_A , jac_b , jac_c = jac_dict [A ], jac_dict [b ], jac_dict [c ]
249248
250249 assert_close (jac_A , torch .stack ([grad_1_A , grad_2_A ]))