
Commit a352cc9

Authored by ValerianRey, AustenMan, and PierreQuinton
Fix _get_descendant_accumulate_grads (#217)
* Prevent nodes from being traversed twice in _get_descendant_accumulate_grads
* Use queue for nodes_to_traverse
* Rename variables in _get_descendant_accumulate_grads
* Add changelog entry

Co-authored-by: austen260 <137848202+austen260@users.noreply.github.com>
Co-authored-by: Pierre Quinton <pierre.quinton@epfl.ch>
1 parent 34deb43 commit a352cc9

2 files changed: 19 additions & 14 deletions


CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,12 @@ changes that do not affect the user.
 
 ## [Unreleased]
 
+### Changed
+
+- Improved the performance of the graph traversal function called by `backward` and `mtl_backward`
+  to find the tensors with respect to which differentiation should be done. It now visits every node
+  at most once.
+
 ## [0.3.0] - 2024-12-10
 
 ### Added

src/torchjd/autojac/_utils.py

Lines changed: 13 additions & 14 deletions
@@ -1,3 +1,4 @@
+from collections import deque
 from typing import Iterable, Sequence
 
 from torch import Tensor
@@ -67,24 +68,22 @@ def _get_descendant_accumulate_grads(roots: set[Node], excluded_nodes: set[Node]
     :param excluded_nodes: Nodes excluded from the graph traversal.
     """
 
+    excluded_nodes = set(excluded_nodes)  # Re-instantiate set to avoid modifying input
     result = set()
-    nodes_to_traverse = [node for node in roots if node not in excluded_nodes]
+    nodes_to_traverse = deque(roots - excluded_nodes)
 
-    # This implementation more or less follows what is advised
-    # [here](https://discuss.pytorch.org/t/how-to-access-the-computational-graph/112887), but it is
-    # not necessarily robust to future changes, and it's not guaranteed to work.
-    # See [this](https://discuss.pytorch.org/t/autograd-graph-traversal/213658) for another question
-    # about how to implement this.
+    # This implementation more or less follows what is advised in
+    # https://discuss.pytorch.org/t/autograd-graph-traversal/213658 and what was suggested in
+    # https://github.com/TorchJD/torchjd/issues/216.
     while nodes_to_traverse:
-        current_node = nodes_to_traverse.pop()
+        node = nodes_to_traverse.popleft()  # Breadth-first
 
-        if current_node.__class__.__name__ == "AccumulateGrad":
-            result.add(current_node)
+        if node.__class__.__name__ == "AccumulateGrad":
+            result.add(node)
 
-        nodes_to_traverse += [
-            child[0]
-            for child in current_node.next_functions
-            if child[0] is not None and child[0] not in excluded_nodes
-        ]
+        for child, _ in node.next_functions:
+            if child is not None and child not in excluded_nodes:
+                nodes_to_traverse.append(child)  # Append to the right
+                excluded_nodes.add(child)
 
     return result
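The core of the fix is marking each child as excluded at the moment it is *enqueued*, rather than when it is dequeued, so no node can enter the queue twice. The sketch below illustrates why this matters, using a hypothetical `Node` mock (not torchjd's or PyTorch's API, chosen here so the example runs without torch) on a chain of diamond-shaped graphs, where a traversal without a visited set revisits nodes exponentially often:

```python
from collections import deque


class Node:
    """Hypothetical stand-in for an autograd graph node, for illustration only."""

    def __init__(self):
        # Mirrors the (child, input_index) pairs of torch's `next_functions`.
        self.next_functions: tuple = ()


def count_visits_old(root: Node) -> int:
    """Old-style traversal: no visited set, so nodes can be re-enqueued."""
    stack, visits = [root], 0
    while stack:
        node = stack.pop()
        visits += 1
        stack += [child for child, _ in node.next_functions if child is not None]
    return visits


def count_visits_new(root: Node) -> int:
    """Fixed traversal: mark nodes on enqueue, so each is visited at most once."""
    seen = {root}
    queue, visits = deque([root]), 0
    while queue:
        node = queue.popleft()  # Breadth-first
        visits += 1
        for child, _ in node.next_functions:
            if child is not None and child not in seen:
                seen.add(child)  # Mark on enqueue: never queued twice
                queue.append(child)
    return visits


# Build a chain of 10 "diamonds": at each level, two parallel paths re-merge
# into the same node, doubling the number of root-to-node paths.
root = Node()
for _ in range(10):
    merge = root
    left, right = Node(), Node()
    left.next_functions = ((merge, 0),)
    right.next_functions = ((merge, 0),)
    root = Node()
    root.next_functions = ((left, 0), (right, 0))

print(count_visits_old(root))  # 4093 visits: exponential in the depth
print(count_visits_new(root))  # 31 visits: each of the 31 nodes seen once
```

The graph has only 31 nodes, but the old traversal visits one node per root-to-node path, which doubles at every diamond; the fixed traversal is linear in the number of nodes regardless of how paths re-merge.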
