Commit 34deb43

Maintain autojac unit tests (#213)
* Shorten some tests
* Improve clarity of some tests
* Improve comments and docstrings
* Uniformise and improve test names
* Uniformise variable names
* Remove unused variable b in test_composition_of_jacs_is_jac
1 parent 56f319f commit 34deb43

13 files changed

Lines changed: 349 additions & 376 deletions

tests/unit/autojac/_transform/test_accumulate.py

Lines changed: 10 additions & 11 deletions
@@ -64,34 +64,33 @@ def test_multiple_accumulation(iterations: int):
     assert_tensor_dicts_are_close(grads, expected_grads)
 
 
-def test_accumulate_fails_on_no_requires_grad():
+def test_no_requires_grad_fails():
     """
     Tests that the Accumulate transform raises an error when it tries to populate a .grad of a
     tensor that does not require grad.
     """
 
-    key1 = torch.zeros([1], requires_grad=False, device=DEVICE)
-    value1 = torch.ones([1], device=DEVICE)
-    input = Gradients({key1: value1})
+    key = torch.zeros([1], requires_grad=False, device=DEVICE)
+    value = torch.ones([1], device=DEVICE)
+    input = Gradients({key: value})
 
-    accumulate = Accumulate([key1])
+    accumulate = Accumulate([key])
 
     with raises(ValueError):
         accumulate(input)
 
 
-def test_accumulate_fails_on_no_leaf_and_no_retains_grad():
+def test_no_leaf_and_no_retains_grad_fails():
     """
     Tests that the Accumulate transform raises an error when it tries to populate a .grad of a
     tensor that is not a leaf and that does not retain grad.
     """
 
-    a = torch.tensor([1.0], requires_grad=True, device=DEVICE)
-    key1 = 2 * a  # requires_grad=True, but is_leaf=False and retains_grad=False
-    value1 = torch.ones([1], device=DEVICE)
-    input = Gradients({key1: value1})
+    key = torch.tensor([1.0], requires_grad=True, device=DEVICE) * 2
+    value = torch.ones([1], device=DEVICE)
+    input = Gradients({key: value})
 
-    accumulate = Accumulate([key1])
+    accumulate = Accumulate([key])
 
     with raises(ValueError):
         accumulate(input)
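Note (not part of the commit): the two failure modes these tests pin down come straight from PyTorch's autograd semantics: .grad is only populated on leaf tensors that require grad, or on non-leaf tensors that explicitly retain grad. A minimal plain-PyTorch sketch of those rules, independent of the Accumulate transform:

import torch

# Leaf tensor that requires grad: backward() populates .grad.
a = torch.tensor([1.0], requires_grad=True)
(2 * a).sum().backward()
print(a.grad)  # tensor([2.])

# Non-leaf tensor: .grad stays None unless retain_grad() is called first.
b = torch.tensor([1.0], requires_grad=True)
y = 2 * b  # y.is_leaf is False, y.retains_grad is False
y.retain_grad()  # without this call, y.grad would remain None
y.sum().backward()
print(y.grad)  # tensor([1.])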

tests/unit/autojac/_transform/test_aggregate.py

Lines changed: 1 addition & 3 deletions
@@ -55,9 +55,7 @@ def test_aggregate_matrices_output_structure(jacobian_matrices: JacobianMatrices
     assert set(jacobian_matrices.keys()) == set(gradient_vectors.keys())
 
     for key in jacobian_matrices.keys():
-        jacobian_matrix = jacobian_matrices[key]
-        gradient_vector = gradient_vectors[key]
-        assert gradient_vector.numel() == jacobian_matrix[0].numel()
+        assert gradient_vectors[key].numel() == jacobian_matrices[key][0].numel()
 
 
 def test_aggregate_matrices_empty_dict():
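Note (not part of the commit): the simplified assertion checks a shape invariant: each aggregated gradient vector has as many elements as one row of the corresponding Jacobian matrix. A sketch with a stand-in mean aggregator (hypothetical; the aggregator actually used by the fixture is not shown in this diff):

import torch

# Stand-in aggregation: reduce an m x n Jacobian matrix to an n-element
# gradient vector by averaging its rows (illustrative only).
jacobian_matrix = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
gradient_vector = jacobian_matrix.mean(dim=0)
assert gradient_vector.numel() == jacobian_matrix[0].numel()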

tests/unit/autojac/_transform/test_base.py

Lines changed: 38 additions & 38 deletions
@@ -23,8 +23,8 @@ def __str__(self):
         return "T"
 
     def _compute(self, input: _B) -> _C:
-        # ignore the input, create a dictionary with the right keys as an output.
-        # cast the type for the purpose of type-checking.
+        # Ignore the input, create a dictionary with the right keys as an output.
+        # Cast the type for the purpose of type-checking.
         output_dict = {key: torch.empty(0, device=DEVICE) for key in self._output_keys}
         return typing.cast(_C, output_dict)
 
@@ -37,90 +37,90 @@ def output_keys(self) -> set[Tensor]:
         return self._output_keys
 
 
-def test_apply_keys():
+def test_call_checks_keys():
     """
-    Tests that a ``Transform`` checks that the provided dictionary to the `__apply__` function
+    Tests that a ``Transform`` checks that the provided dictionary to the `__call__` function
     contains keys that correspond exactly to `required_keys`.
     """
 
-    t1 = torch.randn([2], device=DEVICE)
-    t2 = torch.randn([3], device=DEVICE)
-    transform = FakeTransform({t1}, {t1, t2})
+    a1 = torch.randn([2], device=DEVICE)
+    a2 = torch.randn([3], device=DEVICE)
+    t = FakeTransform(required_keys={a1}, output_keys={a1, a2})
 
-    transform(TensorDict({t1: t2}))
+    t(TensorDict({a1: a2}))
 
     with raises(ValueError):
-        transform(TensorDict({t2: t1}))
+        t(TensorDict({a2: a1}))
 
     with raises(ValueError):
-        transform(TensorDict({}))
+        t(TensorDict({}))
 
     with raises(ValueError):
-        transform(TensorDict({t1: t2, t2: t1}))
+        t(TensorDict({a1: a2, a2: a1}))
 
 
-def test_compose_keys_match():
+def test_compose_checks_keys():
     """
     Tests that the composition of ``Transform``s checks that the inner transform's `output_keys`
     match with the outer transform's `required_keys`.
     """
 
-    t1 = torch.randn([2], device=DEVICE)
-    t2 = torch.randn([3], device=DEVICE)
-    transform1 = FakeTransform({t1}, {t1, t2})
-    transform2 = FakeTransform({t2}, {t1})
+    a1 = torch.randn([2], device=DEVICE)
+    a2 = torch.randn([3], device=DEVICE)
+    t1 = FakeTransform(required_keys={a1}, output_keys={a1, a2})
+    t2 = FakeTransform(required_keys={a2}, output_keys={a1})
 
-    transform1 << transform2
+    t1 << t2
 
     with raises(ValueError):
-        transform2 << transform1
+        t2 << t1
 
 
-def test_conjunct_required_keys():
+def test_conjunct_checks_required_keys():
     """
     Tests that the conjunction of ``Transform``s checks that the provided transforms all have the
     same `required_keys`.
     """
 
-    t1 = torch.randn([2], device=DEVICE)
-    t2 = torch.randn([3], device=DEVICE)
+    a1 = torch.randn([2], device=DEVICE)
+    a2 = torch.randn([3], device=DEVICE)
 
-    transform1 = FakeTransform({t1}, set())
-    transform2 = FakeTransform({t1}, set())
-    transform3 = FakeTransform({t2}, set())
+    t1 = FakeTransform(required_keys={a1}, output_keys=set())
+    t2 = FakeTransform(required_keys={a1}, output_keys=set())
+    t3 = FakeTransform(required_keys={a2}, output_keys=set())
 
-    transform1 | transform2
+    t1 | t2
 
     with raises(ValueError):
-        transform2 | transform3
+        t2 | t3
 
     with raises(ValueError):
-        transform1 | transform2 | transform3
+        t1 | t2 | t3
 
 
-def test_conjunct_wrong_output_keys():
+def test_conjunct_checks_output_keys():
     """
     Tests that the conjunction of ``Transform``s checks that the transforms `output_keys` are
     disjoint.
     """
 
-    t1 = torch.randn([2], device=DEVICE)
-    t2 = torch.randn([3], device=DEVICE)
+    a1 = torch.randn([2], device=DEVICE)
+    a2 = torch.randn([3], device=DEVICE)
 
-    transform1 = FakeTransform(set(), {t1, t2})
-    transform2 = FakeTransform(set(), {t1})
-    transform3 = FakeTransform(set(), {t2})
+    t1 = FakeTransform(required_keys=set(), output_keys={a1, a2})
+    t2 = FakeTransform(required_keys=set(), output_keys={a1})
+    t3 = FakeTransform(required_keys=set(), output_keys={a2})
 
-    transform2 | transform3
+    t2 | t3
 
     with raises(ValueError):
-        transform1 | transform3
+        t1 | t3
 
     with raises(ValueError):
-        transform1 | transform2 | transform3
+        t1 | t2 | t3
 
 
-def test_conjunction_empty_transforms():
+def test_empty_conjunction():
     """
     Tests that it is possible to take the conjunction of no transform. This should return an empty
     dictionary.
@@ -137,7 +137,7 @@ def test_str():
     conjunctions.
     """
 
-    t = FakeTransform(set(), set())
+    t = FakeTransform(required_keys=set(), output_keys=set())
     transform = (t | t << t << t | t) << t << (t | t)
 
     assert str(transform) == "(T | T ∘ T ∘ T | T) ∘ T ∘ (T | T)"
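Note (not part of the commit): the renamed tests all exercise the key checks performed by composition (<<) and conjunction (|). A rough sketch of those checks, using hypothetical helper names rather than the real Transform internals:

def check_composition(outer_required: set, inner_output: set) -> None:
    # outer << inner: the inner transform must output exactly the keys
    # that the outer transform requires.
    if outer_required != inner_output:
        raise ValueError("inner output_keys must equal outer required_keys")


def check_conjunction(required_key_sets: list[set], output_key_sets: list[set]) -> None:
    # t1 | t2 | ...: all transforms must share the same required_keys,
    # and their output_keys must be pairwise disjoint.
    if any(keys != required_key_sets[0] for keys in required_key_sets):
        raise ValueError("conjuncted transforms must share required_keys")
    seen: set = set()
    for keys in output_key_sets:
        if seen & keys:
            raise ValueError("conjuncted transforms must have disjoint output_keys")
        seen |= keys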

tests/unit/autojac/_transform/test_diagonalize.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@
 from ._dict_assertions import assert_tensor_dicts_are_close
 
 
-def test_diagonalize_single_input():
+def test_single_input():
     """Tests that the Diagonalize transform works when given a single input."""
 
     key = torch.tensor([1.0, 2.0, 3.0], device=DEVICE)
@@ -23,7 +23,7 @@ def test_diagonalize_single_input():
     assert_tensor_dicts_are_close(output, expected_output)
 
 
-def test_diagonalize_multiple_inputs():
+def test_multiple_inputs():
     """Tests that the Diagonalize transform works when given multiple inputs."""
 
     key1 = torch.tensor([[1.0, 2.0], [4.0, 5.0]], device=DEVICE)
@@ -82,7 +82,7 @@ def test_diagonalize_multiple_inputs():
     assert_tensor_dicts_are_close(output, expected_output)
 
 
-def test_diagonalize_permute_order():
+def test_permute_order():
     """
     Tests that the Diagonalize transform outputs a permuted mapping when its keys are permuted.
     """

tests/unit/autojac/_transform/test_grad.py

Lines changed: 5 additions & 7 deletions
@@ -85,7 +85,7 @@ def test_retain_graph():
 
 def test_single_input_two_levels():
     """
-    Tests that the Grad transform works correctly for a very simple example of differentiation.
+    Tests that the Grad transform works correctly when composed with another Grad transform.
     Here, the function considered is: `z = a * x1 * x2`, which is computed in 2 parts: `y = a * x1`
     and `z = y * x2`. We want to compute the derivative of `z` with respect to the parameter `a`, by
     using chain rule. This derivative should be equal to `x1 * x2`.
@@ -238,9 +238,7 @@ def test_conjunction_of_grads_is_grad():
     x2 = torch.tensor(6.0, device=DEVICE)
     a1 = torch.tensor(2.0, requires_grad=True, device=DEVICE)
     a2 = torch.tensor(3.0, requires_grad=True, device=DEVICE)
-    y1 = a1 * x1
-    y2 = a2 * x2
-    y = torch.stack([y1, y2])
+    y = torch.stack([a1 * x1, a2 * x2])
     input = Gradients({y: torch.ones_like(y)})
 
     grad1 = Grad(outputs=[y], inputs=[a1], retain_graph=True)
@@ -258,10 +256,10 @@ def test_create_graph():
     """Tests that the Grad transform behaves correctly when `create_graph` is set to `True`."""
 
     a = torch.tensor(2.0, requires_grad=True, device=DEVICE)
-    b = a * a
-    input = Gradients({b: torch.ones_like(b)})
+    y = a * a
+    input = Gradients({y: torch.ones_like(y)})
 
-    grad = Grad(outputs=[b], inputs=[a], create_graph=True)
+    grad = Grad(outputs=[y], inputs=[a], create_graph=True)
 
     gradients = grad(input)
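Note (not part of the commit): the two-level setup in test_single_input_two_levels is plain chain rule. A minimal sketch with raw torch.autograd.grad calls, outside the Grad transform:

import torch

x1, x2 = torch.tensor(3.0), torch.tensor(4.0)
a = torch.tensor(2.0, requires_grad=True)
y = a * x1  # first part
z = y * x2  # second part

# Outer step computes dz/dy; inner step is seeded with it (chain rule).
dz_dy, = torch.autograd.grad(z, y, retain_graph=True)
dz_da, = torch.autograd.grad(y, a, grad_outputs=dz_dy)
assert torch.allclose(dz_da, x1 * x2)  # 3.0 * 4.0 = 12.0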

tests/unit/autojac/_transform/test_init.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 from ._dict_assertions import assert_tensor_dicts_are_close
 
 
-def test_init_single_input():
+def test_single_input():
     """
     Tests that when there is a single key to initialize, the Init transform creates a TensorDict
     whose value is a tensor full of ones, of the same shape as its key.
@@ -23,7 +23,7 @@ def test_init_single_input():
     assert_tensor_dicts_are_close(output, expected_output)
 
 
-def test_init_multiple_input():
+def test_multiple_inputs():
     """
     Tests that when there are several keys to initialize, the Init transform creates a TensorDict
     whose values are tensors full of ones, of the same shape as their corresponding keys.
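Note (not part of the commit): per these docstrings, Init seeds differentiation by mapping each key to a ones tensor of the same shape; in plain PyTorch terms, roughly:

import torch

keys = [torch.tensor([1.0, 2.0]), torch.tensor(5.0)]
seeds = {key: torch.ones_like(key) for key in keys}  # one all-ones seed per key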

tests/unit/autojac/_transform/test_interactions.py

Lines changed: 27 additions & 28 deletions
@@ -33,14 +33,17 @@ def test_jac_is_stack_of_grads():
     y2 = a2 * x
     input = Gradients({y1: torch.ones_like(y1), y2: torch.ones_like(y2)})
 
-    jac = Jac(outputs=[y1, y2], inputs=[a1, a2], chunk_size=None, retain_graph=True) << Diagonalize(
-        [y1, y2]
-    )
-    grad1 = Grad(outputs=[y1], inputs=[a1, a2]) << Select([y1], [y1, y2])
-    grad2 = Grad(outputs=[y2], inputs=[a1, a2]) << Select([y2], [y1, y2])
-    stack_of_grads = Stack([grad1, grad2])
+    jac = Jac(outputs=[y1, y2], inputs=[a1, a2], chunk_size=None, retain_graph=True)
+    diag = Diagonalize([y1, y2])
+    jac_diag = jac << diag
+
+    grad1 = Grad(outputs=[y1], inputs=[a1, a2])
+    grad2 = Grad(outputs=[y2], inputs=[a1, a2])
+    select1 = Select([y1], [y1, y2])
+    select2 = Select([y2], [y1, y2])
+    stack_of_grads = Stack([grad1 << select1, grad2 << select2])
 
-    jacobians = jac(input)
+    jacobians = jac_diag(input)
     expected_jacobians = stack_of_grads(input)
 
     assert_tensor_dicts_are_close(jacobians, expected_jacobians)
@@ -66,7 +69,7 @@ def test_single_differentiation():
     assert_tensor_dicts_are_close(output, expected_output)
 
 
-def test_multiple_differentiation_with_grad():
+def test_multiple_differentiations():
     """
     Tests that we can perform multiple scalar differentiations with the conjunction of multiple Grad
     transforms, composed with an Init transform.
@@ -78,10 +81,12 @@ def test_multiple_differentiation_with_grad():
     y2 = a2 * 3.0
     input = EmptyTensorDict()
 
-    grad1 = Grad([y1], [a1]) << Select([y1], [y1, y2])
-    grad2 = Grad([y2], [a2]) << Select([y2], [y1, y2])
+    grad1 = Grad([y1], [a1])
+    grad2 = Grad([y2], [a2])
+    select1 = Select([y1], [y1, y2])
+    select2 = Select([y2], [y1, y2])
     init = Init([y1, y2])
-    transform = (grad1 | grad2) << init
+    transform = ((grad1 << select1) | (grad2 << select2)) << init
 
     output = transform(input)
     expected_output = {
@@ -182,7 +187,8 @@ def test_conjunction_accumulate_select():
     """
     Tests that it is possible to conjunct an Accumulate and a Select in this order.
     It is not trivial since the type of the TensorDict returned by the first transform (Accumulate)
-    is EmptyDict, which is not the type that the conjunction should return (Gradients).
+    is EmptyDict, which is not the type that the conjunction should return (Gradients), but a
+    subclass of it.
     """
 
     key = torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=DEVICE)
@@ -199,7 +205,7 @@ def test_conjunction_accumulate_select():
     assert_tensor_dicts_are_close(output, expected_output)
 
 
-def test_equivalence_jac_grad():
+def test_equivalence_jac_grads():
     """
     Tests that differentiation in parallel using `_jac` is equivalent to sequential differentiation
     using several calls to `_grad` and stacking the resulting gradients.
@@ -219,18 +225,12 @@ def test_equivalence_jac_grad():
     outputs = [y1, y2]
     grad_outputs = [torch.ones_like(output) for output in outputs]
 
-    grad_dict_1 = Grad(
-        outputs=[outputs[0]],
-        inputs=inputs,
-        retain_graph=True,
-    )(Gradients({outputs[0]: grad_outputs[0]}))
+    grad1 = Grad(outputs=[outputs[0]], inputs=inputs, retain_graph=True)
+    grad_dict_1 = grad1(Gradients({outputs[0]: grad_outputs[0]}))
     grad_1_A, grad_1_b, grad_1_c = grad_dict_1[A], grad_dict_1[b], grad_dict_1[c]
 
-    grad_dict_2 = Grad(
-        outputs=[outputs[1]],
-        inputs=inputs,
-        retain_graph=True,
-    )(Gradients({outputs[1]: grad_outputs[1]}))
+    grad2 = Grad(outputs=[outputs[1]], inputs=inputs, retain_graph=True)
+    grad_dict_2 = grad2(Gradients({outputs[1]: grad_outputs[1]}))
     grad_2_A, grad_2_b, grad_2_c = grad_dict_2[A], grad_dict_2[b], grad_dict_2[c]
 
     n_outputs = len(outputs)
@@ -240,11 +240,10 @@ def test_equivalence_jac_grad():
     for i, grad_output in enumerate(grad_outputs):
         batched_grad_outputs[i][i] = grad_output
 
-    jac_dict = Jac(
-        outputs=outputs,
-        inputs=inputs,
-        chunk_size=None,
-    )(Jacobians({outputs[0]: batched_grad_outputs[0], outputs[1]: batched_grad_outputs[1]}))
+    jac = Jac(outputs=outputs, inputs=inputs, chunk_size=None)
+    jac_dict = jac(
+        Jacobians({outputs[0]: batched_grad_outputs[0], outputs[1]: batched_grad_outputs[1]})
+    )
     jac_A, jac_b, jac_c = jac_dict[A], jac_dict[b], jac_dict[c]
 
     assert_close(jac_A, torch.stack([grad_1_A, grad_2_A]))
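Note (not part of the commit): the jac/grad equivalence tested above mirrors a property of torch.autograd.grad itself, assuming is_grads_batched=True behaves as documented: one batched call matches stacking several per-seed calls. A minimal sketch:

import torch

x = torch.tensor([1.0, 2.0])
a = torch.tensor([3.0, 4.0], requires_grad=True)
y = a * x  # two outputs in one tensor

# Sequential: one vector-Jacobian product per basis seed, then stack.
rows = [
    torch.autograd.grad(y, a, grad_outputs=seed, retain_graph=True)[0]
    for seed in torch.eye(2)
]
jac_sequential = torch.stack(rows)

# Parallel: a single call with a batch of seeds.
jac_batched, = torch.autograd.grad(
    y, a, grad_outputs=torch.eye(2), is_grads_batched=True, retain_graph=True
)
assert torch.allclose(jac_sequential, jac_batched)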
