
Commit 66c99db

Refactor GradVac as a GramianWeightedAggregator with GradVacWeighting
GradVac only needs gradient norms and dot products, which are fully determined by the Gramian. This makes GradVac compatible with the autogram path. Also removes the grouping parameters (group_type, encoder, shared_params) from GradVac, and exports GradVacWeighting publicly.
1 parent 8d1f6e7 commit 66c99db
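The observation driving the refactor can be checked in a few lines. A minimal sketch (the Jacobian shape and values below are illustrative, not part of the commit):

    import torch

    jacobian = torch.randn(3, 10)       # one row per task, one column per shared parameter
    gramian = jacobian @ jacobian.T     # (3, 3) matrix of all pairwise dot products

    # Everything GradVac needs is readable off the Gramian:
    norm_0 = gramian[0, 0].sqrt()                                     # ||g_0||
    dot_01 = gramian[0, 1]                                            # <g_0, g_1>
    cos_01 = dot_01 / (gramian[0, 0].sqrt() * gramian[1, 1].sqrt())   # cosine similarity

Since the vaccine correction only adds scaled copies of existing rows, the modified gradients stay in the row space of the Jacobian, so their norms and dot products remain expressible through the Gramian as well.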

5 files changed

Lines changed: 185 additions & 277 deletions


docs/source/docs/aggregation/gradvac.rst

Lines changed: 5 additions & 0 deletions
@@ -7,3 +7,8 @@ GradVac
     :members:
     :undoc-members:
     :exclude-members: forward
+
+.. autoclass:: torchjd.aggregation.GradVacWeighting
+    :members:
+    :undoc-members:
+    :exclude-members: forward

src/torchjd/aggregation/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -66,7 +66,7 @@
 from ._dualproj import DualProj, DualProjWeighting
 from ._flattening import Flattening
 from ._graddrop import GradDrop
-from ._gradvac import GradVac
+from ._gradvac import GradVac, GradVacWeighting
 from ._imtl_g import IMTLG, IMTLGWeighting
 from ._krum import Krum, KrumWeighting
 from ._mean import Mean, MeanWeighting
@@ -94,6 +94,7 @@
     "GeneralizedWeighting",
     "GradDrop",
     "GradVac",
+    "GradVacWeighting",
     "IMTLG",
     "IMTLGWeighting",
     "Krum",
src/torchjd/aggregation/_gradvac.py

Lines changed: 132 additions & 192 deletions
@@ -1,158 +1,124 @@
 from __future__ import annotations
 
-from collections.abc import Iterable
-from typing import Literal, cast
+from typing import cast
 
 import torch
-import torch.nn as nn
 from torch import Tensor
 
-from torchjd._linalg import Matrix
+from torchjd._linalg import PSDMatrix
 
-from ._aggregator_bases import Aggregator
+from ._aggregator_bases import GramianWeightedAggregator
 from ._utils.non_differentiable import raise_non_differentiable_error
+from ._weighting_bases import Weighting
 
 
-def _all_layer_group_sizes(encoder: nn.Module) -> tuple[int, ...]:
-    """
-    Block sizes per leaf submodule with parameters, matching the ``all_layer`` grouping: iterate
-    ``encoder.modules()`` and append the total number of elements in each module that has no child
-    submodules and registers at least one parameter.
-    """
-
-    return tuple(
-        sum(w.numel() for w in module.parameters())
-        for module in encoder.modules()
-        if len(list(module.children())) == 0 and next(module.parameters(), None) is not None
-    )
-
-
-def _all_matrix_group_sizes(shared_params: Iterable[Tensor]) -> tuple[int, ...]:
-    """One block per tensor in ``shared_params`` order (``all_matrix`` / shared-parameter layout)."""
-
-    return tuple(p.numel() for p in shared_params)
-
-
-class GradVac(Aggregator):
+class GradVac(GramianWeightedAggregator):
     r"""
-    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing Gradient Vaccine
-    (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in
-    Massively Multilingual Models (ICLR 2021 Spotlight)
+    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
+    Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
+    Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
     <https://openreview.net/forum?id=F1vEjWK-lH_>`_.
 
-    The input matrix is a Jacobian :math:`J \in \mathbb{R}^{m \times n}` whose rows are per-task
-    gradients. For each task :math:`i` and each parameter block :math:`k`, the order in which other
-    tasks :math:`j` are visited is drawn at random (independently for each :math:`k`); for each pair
-    :math:`(i, j)` on block :math:`k`, the cosine correlation :math:`\phi_{ijk}` between the
+    For each task :math:`i`, the order in which other tasks :math:`j` are visited is drawn at
+    random. For each pair :math:`(i, j)`, the cosine similarity :math:`\phi_{ij}` between the
     (possibly already modified) gradient of task :math:`i` and the original gradient of task
-    :math:`j` on that block is compared to an EMA target :math:`\hat{\phi}_{ijk}`. When
-    :math:`\phi_{ijk} < \hat{\phi}_{ijk}`, a closed-form correction adds a scaled copy of
-    :math:`g_j` to the block of :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with
-    :math:`\hat{\phi}_{ijk} \leftarrow (1-\beta)\hat{\phi}_{ijk} + \beta \phi_{ijk}`. The aggregated
+    :math:`j` is compared to an EMA target :math:`\hat{\phi}_{ij}`. When
+    :math:`\phi_{ij} < \hat{\phi}_{ij}`, a closed-form correction adds a scaled copy of
+    :math:`g_j` to :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with
+    :math:`\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}`. The aggregated
     vector is the sum of the modified rows.
 
     This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
-    the number of tasks, parameter dimension, grouping, device, or dtype changes.
-
-    **Parameter granularity** is selected by ``group_type`` (default ``"whole_model"``). It defines
-    how each task gradient row is partitioned into blocks :math:`k` so that cosines and EMA targets
-    :math:`\hat{\phi}_{ijk}` are computed **per block** rather than only globally:
-
-    * ``"whole_model"``: the full row of length :math:`n` is a single block. Cosine similarity is
-      taken between entire task gradients. Do not pass ``encoder`` or ``shared_params``.
-    * ``"all_layer"``: one block per leaf ``nn.Module`` under ``encoder`` that holds parameters
-      (same rule as iterating ``encoder.modules()`` and selecting leaves with parameters). Pass
-      ``encoder``; ``shared_params`` must be omitted.
-    * ``"all_matrix"``: one block per tensor in ``shared_params``, in iteration order. That order
-      must match how Jacobian columns are laid out for those shared parameters. Pass
-      ``shared_params``; ``encoder`` must be omitted.
-
-    :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign the
-        :attr:`beta` attribute between steps to tune the EMA update.
-    :param group_type: Granularity of parameter grouping; see **Parameter granularity** above.
-    :param encoder: Module whose subtree defines ``all_layer`` blocks when
-        ``group_type == "all_layer"``.
-    :param shared_params: Iterable of parameter tensors defining ``all_matrix`` block sizes and
-        order when ``group_type == "all_matrix"``. It is materialized once at construction.
+    the number of tasks or dtype changes.
+
+    :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign
+        the :attr:`beta` attribute between steps to tune the EMA update.
     :param eps: Small non-negative constant added to denominators when computing cosines and the
         vaccine weight (default ``1e-8``); set to ``0`` to omit this stabilization. You may read or
         assign the :attr:`eps` attribute between steps to tune numerical behavior.
 
     .. note::
-        GradVac is not compatible with autogram: it needs full Jacobian rows and per-block inner
-        products, not only a Gram matrix. Only the autojac path is supported.
+        For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
+        using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
+        you need reproducibility.
 
     .. note::
-        For each task :math:`i` and block :math:`k`, the order of other tasks :math:`j` is shuffled
-        independently using the global PyTorch RNG (``torch.randperm``). Seed it with
-        ``torch.manual_seed`` if you need reproducibility.
+        To apply GradVac with per-layer or per-parameter-group granularity, first aggregate the
+        Jacobian into groups, apply GradVac per group, and sum the results. See the grouping usage
+        example for details.
+    """
+
+    def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
+        weighting = GradVacWeighting(beta=beta, eps=eps)
+        super().__init__(weighting)
+        self._gradvac_weighting = weighting
+        self.register_full_backward_pre_hook(raise_non_differentiable_error)
+
+    @property
+    def beta(self) -> float:
+        """EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``)."""
+
+        return self._gradvac_weighting.beta
+
+    @beta.setter
+    def beta(self, value: float) -> None:
+        self._gradvac_weighting.beta = value
+
+    @property
+    def eps(self) -> float:
+        """Small non-negative constant added to denominators for numerical stability."""
+
+        return self._gradvac_weighting.eps
+
+    @eps.setter
+    def eps(self, value: float) -> None:
+        self._gradvac_weighting.eps = value
+
+    def reset(self) -> None:
+        """Clears EMA state so the next forward starts from zero targets."""
+
+        self._gradvac_weighting.reset()
+
+    def __repr__(self) -> str:
+        return f"GradVac(beta={self.beta!r}, eps={self.eps!r})"
+
+
+class GradVacWeighting(Weighting[PSDMatrix]):
+    r"""
+    :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
+    :class:`~torchjd.aggregation.GradVac`.
+
+    All required quantities (gradient norms, cosine similarities, and their updates after the
+    vaccine correction) are derived purely from the Gramian, without needing the full Jacobian.
+    If :math:`g_i^{(\mathrm{PC})} = \sum_k c_{ik} g_k`, then:
+
+    .. math::
+
+        \|g_i^{(\mathrm{PC})}\|^2 = \mathbf{c}_i G \mathbf{c}_i^\top,\qquad
+        g_i^{(\mathrm{PC})} \cdot g_j = \mathbf{c}_i G_{:,j}
+
+    where :math:`G` is the Gramian. The correction :math:`g_i^{(\mathrm{PC})} \mathrel{+}= w
+    g_j` then becomes :math:`c_{ij} \mathrel{+}= w`, and the updated dot products follow
+    immediately.
+
+    This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
+    the number of tasks or dtype changes.
+
+    :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``).
+    :param eps: Small non-negative constant added to denominators (default ``1e-8``).
     """
 
-    def __init__(
-        self,
-        beta: float = 0.5,
-        group_type: Literal["whole_model", "all_layer", "all_matrix"] = "whole_model",
-        encoder: nn.Module | None = None,
-        shared_params: Iterable[Tensor] | None = None,
-        eps: float = 1e-8,
-    ) -> None:
+    def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
         super().__init__()
         if not (0.0 <= beta <= 1.0):
             raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
-        params_tuple: tuple[Tensor, ...] = ()
-        fixed_block_sizes: tuple[int, ...] | None
-        if group_type == "whole_model":
-            if encoder is not None:
-                raise ValueError(
-                    'Parameter `encoder` must be None when `group_type == "whole_model"`.'
-                )
-            if shared_params is not None:
-                raise ValueError(
-                    'Parameter `shared_params` must be None when `group_type == "whole_model"`.'
-                )
-            fixed_block_sizes = None
-        elif group_type == "all_layer":
-            if encoder is None:
-                raise ValueError(
-                    'Parameter `encoder` is required when `group_type == "all_layer"`.'
-                )
-            if shared_params is not None:
-                raise ValueError(
-                    'Parameter `shared_params` must be None when `group_type == "all_layer"`.'
-                )
-            fixed_block_sizes = _all_layer_group_sizes(encoder)
-            if sum(fixed_block_sizes) == 0:
-                raise ValueError("Parameter `encoder` has no parameters in any leaf module.")
-        else:
-            if shared_params is None:
-                raise ValueError(
-                    'Parameter `shared_params` is required when `group_type == "all_matrix"`.'
-                )
-            if encoder is not None:
-                raise ValueError(
-                    'Parameter `encoder` must be None when `group_type == "all_matrix"`.'
-                )
-            params_tuple = tuple(shared_params)
-            if len(params_tuple) == 0:
-                raise ValueError(
-                    'Parameter `shared_params` must be non-empty when `group_type == "all_matrix"`.'
-                )
-            fixed_block_sizes = _all_matrix_group_sizes(params_tuple)
-
         if eps < 0.0:
             raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.")
 
         self._beta = beta
-        self._group_type = group_type
-        self._encoder = encoder
-        self._shared_params_len = len(params_tuple)
-        self._fixed_block_sizes = fixed_block_sizes
         self._eps = eps
-
         self._phi_t: Tensor | None = None
-        self._state_key: tuple[int, int, tuple[int, ...], torch.device, torch.dtype] | None = None
-
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
+        self._state_key: tuple[int, torch.dtype] | None = None
 
     @property
     def beta(self) -> float:
@@ -184,82 +150,56 @@ def reset(self) -> None:
         self._phi_t = None
         self._state_key = None
 
-    def __repr__(self) -> str:
-        enc = "None" if self._encoder is None else f"{self._encoder.__class__.__name__}(...)"
-        sp = "None" if self._group_type != "all_matrix" else f"n_params={self._shared_params_len}"
-        return (
-            f"{self.__class__.__name__}(beta={self._beta!r}, group_type={self._group_type!r}, "
-            f"encoder={enc}, shared_params={sp}, eps={self._eps!r})"
-        )
-
-    def _resolve_segment_sizes(self, n: int) -> tuple[int, ...]:
-        if self._group_type == "whole_model":
-            return (n,)
-        sizes = cast(tuple[int, ...], self._fixed_block_sizes)
-        if sum(sizes) != n:
-            raise ValueError(
-                "The Jacobian width `n` must equal the sum of block sizes implied by "
-                f"`encoder` or `shared_params` for this `group_type`. Found n={n}, "
-                f"sum(block_sizes)={sum(sizes)}.",
-            )
-        return sizes
-
-    def _ensure_state(
-        self,
-        m: int,
-        n: int,
-        sizes: tuple[int, ...],
-        device: torch.device,
-        dtype: torch.dtype,
-    ) -> None:
-        key = (m, n, sizes, device, dtype)
-        num_groups = len(sizes)
-        if self._state_key != key or self._phi_t is None:
-            self._phi_t = torch.zeros(m, m, num_groups, device=device, dtype=dtype)
-            self._state_key = key
+    def forward(self, gramian: PSDMatrix, /) -> Tensor:
+        device = gramian.device
+        dtype = gramian.dtype
+        cpu = torch.device("cpu")
 
-    def forward(self, matrix: Matrix, /) -> Tensor:
-        grads = matrix
-        m, n = grads.shape
-        if m == 0 or n == 0:
-            return torch.zeros(n, dtype=grads.dtype, device=grads.device)
+        G = cast(PSDMatrix, gramian.to(device=cpu))
+        m = G.shape[0]
 
-        sizes = self._resolve_segment_sizes(n)
-        device = grads.device
-        dtype = grads.dtype
-        self._ensure_state(m, n, sizes, device, dtype)
+        self._ensure_state(m, dtype)
         phi_t = cast(Tensor, self._phi_t)
-        beta = self.beta
-        eps = self.eps
 
-        pc_grads = grads.clone()
-        offsets = [0]
-        for s in sizes:
-            offsets.append(offsets[-1] + s)
+        beta = self._beta
+        eps = self._eps
+
+        # C[i, :] holds coefficients such that g_i^PC = sum_k C[i,k] * g_k (original gradients).
+        # Initially each modified gradient equals the original, so C = I.
+        C = torch.eye(m, device=cpu, dtype=dtype)
 
         for i in range(m):
+            # Dot products of g_i^PC with every original g_j, shape (m,).
+            cG = C[i] @ G
+
             others = [j for j in range(m) if j != i]
-            for k in range(len(sizes)):
-                perm = torch.randperm(len(others))
-                shuffled_js = [others[idx] for idx in perm.tolist()]
-                beg, end = offsets[k], offsets[k + 1]
-                for j in shuffled_js:
-                    slice_i = pc_grads[i, beg:end]
-                    slice_j = grads[j, beg:end]
-
-                    norm_i = slice_i.norm()
-                    norm_j = slice_j.norm()
-                    denom = norm_i * norm_j + eps
-                    phi_ijk = slice_i.dot(slice_j) / denom
-
-                    phi_hat = phi_t[i, j, k]
-                    if phi_ijk < phi_hat:
-                        sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt()
-                        sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt()
-                        denom_w = norm_j * sqrt_1_hat2 + eps
-                        w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w
-                        pc_grads[i, beg:end] = slice_i + slice_j * w
-
-                    phi_t[i, j, k] = (1.0 - beta) * phi_hat + beta * phi_ijk
-
-        return pc_grads.sum(dim=0)
+            perm = torch.randperm(len(others))
+            shuffled_js = [others[idx] for idx in perm.tolist()]
+
+            for j in shuffled_js:
+                dot_ij = cG[j]
+                norm_i_sq = (cG * C[i]).sum()
+                norm_i = norm_i_sq.clamp(min=0.0).sqrt()
+                norm_j = G[j, j].clamp(min=0.0).sqrt()
+                denom = norm_i * norm_j + eps
+                phi_ijk = dot_ij / denom
+
+                phi_hat = phi_t[i, j]
+                if phi_ijk < phi_hat:
+                    sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt()
+                    sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt()
+                    denom_w = norm_j * sqrt_1_hat2 + eps
+                    w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w
+                    C[i, j] = C[i, j] + w
+                    cG = cG + w * G[j]
+
+                phi_t[i, j] = (1.0 - beta) * phi_hat + beta * phi_ijk
+
+        weights = C.sum(dim=0)
+        return weights.to(device)
+
+    def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
+        key = (m, dtype)
+        if self._state_key != key or self._phi_t is None:
+            self._phi_t = torch.zeros(m, m, dtype=dtype)
+            self._state_key = key
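One step in the new forward is worth spelling out: the weighting returns C.sum(dim=0) instead of reconstructing the modified gradients. This works because summing linearly combined rows equals combining the original rows with the summed coefficients. A small self-contained check (all values illustrative; C is a stand-in for GradVac's coefficient matrix, not the real corrections):

    import torch

    m, n = 3, 10
    J = torch.randn(m, n)                        # illustrative Jacobian, one task gradient per row
    C = torch.eye(m) + 0.1 * torch.rand(m, m)    # stand-in coefficient matrix
    summed_modified = (C @ J).sum(dim=0)         # sum of modified rows (what the previous implementation computed)
    weighted_original = C.sum(dim=0) @ J         # weights applied to original rows (what the Gramian-based weighting enables)
    assert torch.allclose(summed_modified, weighted_original)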
