-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path_nash_mtl.py
More file actions
211 lines (168 loc) · 8.06 KB
/
_nash_mtl.py
File metadata and controls
211 lines (168 loc) · 8.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon.
# See NOTICES for the full license text.
from typing import cast
from torchjd._linalg import Matrix
from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import Weighting
check_dependencies_are_installed(["cvxpy", "ecos"])
import cvxpy as cp
import numpy as np
import torch
from cvxpy import Expression, SolverError
from torch import Tensor
from ._aggregator_bases import WeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
class NashMTL(WeightedAggregator):
    """
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing Algorithm 1 of
    `Multi-Task Learning as a Bargaining Game <https://arxiv.org/pdf/2202.01017.pdf>`_.

    :param n_tasks: The number of tasks, corresponding to the number of rows in the provided
        matrices.
    :param max_norm: Maximum value of the norm of :math:`J^T w`.
    :param update_weights_every: A parameter determining how often the actual weighting should be
        performed. A larger value means that the same weights will be re-used for more calls to the
        aggregator.
    :param optim_niter: The number of iterations of the underlying optimization process.

    .. note::
        This aggregator is not installed by default. When not installed, trying to import it should
        result in the following error:
        ``ImportError: cannot import name 'NashMTL' from 'torchjd.aggregation'``.
        To install it, use ``pip install "torchjd[nash_mtl]"``.

    .. warning::
        This implementation was adapted from the `official implementation
        <https://github.com/AvivNavon/nash-mtl/tree/main>`_, which has some flaws. Use with caution.

    .. warning::
        This aggregator is stateful. Its output will thus depend not only on the input matrix, but
        also on its state. It thus depends on previously seen matrices. It should be reset between
        experiments.
    """

    def __init__(
        self,
        n_tasks: int,
        max_norm: float = 1.0,
        update_weights_every: int = 1,
        optim_niter: int = 20,
    ) -> None:
        # Build the stateful weighting first, then hand it to the base aggregator.
        weighting = _NashMTLWeighting(
            n_tasks=n_tasks,
            max_norm=max_norm,
            update_weights_every=update_weights_every,
            optim_niter=optim_niter,
        )
        super().__init__(weighting=weighting)

        # Kept only so that __repr__ can report the construction parameters.
        self._n_tasks = n_tasks
        self._max_norm = max_norm
        self._update_weights_every = update_weights_every
        self._optim_niter = optim_niter

        # The computed weights must not be treated as constant w.r.t. the input matrix,
        # so any backward pass through this aggregator raises instead.
        self.register_full_backward_pre_hook(raise_non_differentiable_error)

    def reset(self) -> None:
        """Resets the internal state of the algorithm."""
        weighting = cast(_NashMTLWeighting, self.weighting)
        weighting.reset()

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        params = (
            f"n_tasks={self._n_tasks}, max_norm={self._max_norm}, "
            f"update_weights_every={self._update_weights_every}, optim_niter={self._optim_niter}"
        )
        return f"{cls_name}({params})"
class _NashMTLWeighting(Weighting[Matrix]):
    """
    :class:`~torchjd.aggregation.Weighting` that extracts weights using the step decision
    of Algorithm 1 of `Multi-Task Learning as a Bargaining Game
    <https://arxiv.org/pdf/2202.01017.pdf>`_.

    :param n_tasks: The number of tasks, corresponding to the number of rows in the provided
        matrices.
    :param max_norm: Maximum value of the norm of :math:`J^T w`.
    :param update_weights_every: A parameter determining how often the actual weighting should be
        performed. A larger value means that the same weights will be re-used for more calls to the
        weighting.
    :param optim_niter: The number of iterations of the underlying optimization process.
    """

    def __init__(
        self,
        n_tasks: int,
        max_norm: float,
        update_weights_every: int,
        optim_niter: int,
    ) -> None:
        super().__init__()
        self.n_tasks = n_tasks
        self.optim_niter = optim_niter
        self.update_weights_every = update_weights_every
        self.max_norm = max_norm
        # Mutable state below; `reset` restores exactly these values.
        # `prvs_alpha_param` stays None until `_init_optim_problem` builds the cvxpy
        # problem lazily on the first call to `forward`.
        self.prvs_alpha_param = None
        self.normalization_factor = np.ones((1,))
        self.init_gtg = np.eye(self.n_tasks)
        # Call counter, incremented by 1 on every `forward` call. Kept as a float,
        # matching the original implementation.
        self.step = 0.0
        # Last computed weight vector, re-used as warm start and as the cached result
        # between weight updates.
        self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)

    def _stop_criteria(self, gtg: np.ndarray, alpha_t: np.ndarray) -> bool:
        """
        Return True when the iterative re-linearization loop should stop: the solver
        produced no value, the first-order optimality residual is small, or alpha
        barely moved between two consecutive linearizations.
        """
        return bool(
            (self.alpha_param.value is None)
            # Residual of the condition gtg @ alpha = 1 / alpha; the 1e-10 term guards
            # against division by zero for (near-)zero entries of alpha.
            or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3)
            or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6),
        )

    def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray:
        """
        Compute the weight vector for the (normalized) Gram matrix ``gtg`` by repeatedly
        solving the convex linearized problem, warm-starting each round from the
        previous iterate. Returns the last successful iterate.
        """
        self.G_param.value = gtg
        self.normalization_factor_param.value = self.normalization_factor
        alpha_t = self.prvs_alpha
        for _ in range(self.optim_niter):
            # Re-linearize around the current iterate.
            self.alpha_param.value = alpha_t
            self.prvs_alpha_param.value = alpha_t
            try:
                self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100)
            except (SolverError, ValueError):
                # On macOS, SolverError can happen with: Solver 'ECOS' failed.
                # No idea why. The corresponding matrix is of shape [9, 11] with rank 5.
                # ValueError happens with for example matrix [[0., 0.], [0., 1.]].
                # Maybe other exceptions can happen in other cases.
                # Fall back to the previous iterate when the solver fails.
                self.alpha_param.value = self.prvs_alpha_param.value
            if self._stop_criteria(gtg, alpha_t):
                break
            alpha_t = self.alpha_param.value
        # Cache the result; it doubles as the warm start for the next call.
        if alpha_t is not None:
            self.prvs_alpha = alpha_t
        return self.prvs_alpha

    def _calc_phi_alpha_linearization(self) -> Expression:
        """
        Build the cvxpy expression of the linearized (first-order) bargaining objective
        around ``prvs_alpha_param``.
        """
        G_prvs_alpha = self.G_param @ self.prvs_alpha_param
        # First-order coefficient of the objective at prvs_alpha:
        # 1/alpha + G^T (1 / (G alpha)), written with cvxpy matrix products.
        prvs_phi_tag = 1 / self.prvs_alpha_param + (1 / G_prvs_alpha) @ self.G_param
        phi_alpha = prvs_phi_tag @ (self.alpha_param - self.prvs_alpha_param)
        return phi_alpha

    def _init_optim_problem(self) -> None:
        """
        Create the cvxpy variable, parameters, constraints and problem. Called lazily
        from the first invocation of ``forward``.
        """
        self.alpha_param = cp.Variable(shape=(self.n_tasks,), nonneg=True)
        self.prvs_alpha_param = cp.Parameter(shape=(self.n_tasks,), value=self.prvs_alpha)
        self.G_param = cp.Parameter(shape=(self.n_tasks, self.n_tasks), value=self.init_gtg)
        self.normalization_factor_param = cp.Parameter(shape=(1,), value=np.array([1.0]))
        self.phi_alpha = self._calc_phi_alpha_linearization()
        G_alpha = self.G_param @ self.alpha_param
        # Each constraint is alpha_i * normalization * (G alpha)_i >= 1, expressed in
        # log form so that it is convex (a DCP-valid exponential-cone constraint).
        constraint = [
            -cp.log(a * self.normalization_factor_param) - cp.log(G_a) <= 0
            for a, G_a in zip(self.alpha_param, G_alpha, strict=True)
        ]
        obj = cp.Minimize(cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param)
        self.prob = cp.Problem(obj, constraint)

    def forward(self, matrix: Tensor, /) -> Tensor:
        """
        Compute the weight vector for ``matrix`` (one row per task). The weights are
        actually recomputed only once every ``update_weights_every`` calls; in between,
        the previously computed weights are re-used.
        """
        if self.step == 0:
            # Lazy one-time construction of the cvxpy problem.
            self._init_optim_problem()
        if (self.step % self.update_weights_every) == 0:
            self.step += 1
            G = matrix
            # NOTE(review): despite its name, GTG holds G @ G^T — the Gram matrix of
            # the rows (task gradients) of the input matrix.
            GTG = torch.mm(G, G.t())
            # Normalize by the (Frobenius) norm before handing the matrix to the solver.
            self.normalization_factor = torch.norm(GTG).detach().cpu().numpy().reshape((1,))
            GTG = GTG / self.normalization_factor.item()
            alpha = self._solve_optimization(GTG.cpu().detach().numpy())
        else:
            self.step += 1
            alpha = self.prvs_alpha
        alpha = torch.from_numpy(alpha).to(device=matrix.device, dtype=matrix.dtype)
        # Rescale the weights so that || alpha @ matrix || does not exceed max_norm
        # (clipping is disabled when max_norm <= 0).
        if self.max_norm > 0:
            norm = torch.linalg.norm(alpha @ matrix)
            if norm > self.max_norm:
                alpha = (alpha / norm) * self.max_norm
        return alpha

    def reset(self) -> None:
        """Resets the internal state of the algorithm."""
        # Mirrors the state initialization performed in __init__.
        self.prvs_alpha_param = None
        self.normalization_factor = np.ones((1,))
        self.init_gtg = np.eye(self.n_tasks)
        self.step = 0.0
        self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)