-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path_aligned_mtl.py
More file actions
115 lines (90 loc) · 4.37 KB
/
_aligned_mtl.py
File metadata and controls
115 lines (90 loc) · 4.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Partly adapted from https://github.com/SamsungLabs/MTL/tree/master/code/optim/aligned — MIT License, Copyright (c) 2022 Samsung.
# See NOTICES for the full license text.
from typing import Literal, TypeAlias
import torch
from torch import Tensor
from torchjd._linalg import PSDMatrix
from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting
# Closed set of string values accepted by the ``scale_mode`` parameter of ``AlignedMTL`` and
# ``AlignedMTLWeighting``; each name selects which eigenvalue statistic scales the balance
# transformation.
SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"]
class AlignedMTL(GramianWeightedAggregator):
    r"""
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing Algorithm 1 of
    `Independent Component Alignment for Multi-Task Learning
    <https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf>`_.

    The weights are produced by an :class:`AlignedMTLWeighting` built from the same arguments.

    :param pref_vector: The preference vector to use. When omitted, it defaults to
        :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
    :param scale_mode: How the balance transformation is scaled: ``"min"`` (the default) takes
        the smallest eigenvalue, ``"median"`` takes the median eigenvalue, and ``"rmse"`` takes
        the mean eigenvalue (as in the original implementation).

    .. note::
        This implementation was adapted from the official implementation of SamsungLabs/MTL,
        which is not available anymore at the time of writing.
    """

    def __init__(
        self,
        pref_vector: Tensor | None = None,
        scale_mode: SUPPORTED_SCALE_MODE = "min",
    ) -> None:
        # The constructor arguments are kept only so that __repr__ / __str__ can report them.
        self._pref_vector = pref_vector
        self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
        weighting = AlignedMTLWeighting(pref_vector, scale_mode=scale_mode)
        super().__init__(weighting)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        arguments = f"pref_vector={self._pref_vector!r}, scale_mode={self._scale_mode!r}"
        return f"{cls_name}({arguments})"

    def __str__(self) -> str:
        # The short form only reflects the preference vector, not the scale mode.
        suffix = pref_vector_to_str_suffix(self._pref_vector)
        return f"AlignedMTL{suffix}"
class AlignedMTLWeighting(Weighting[PSDMatrix]):
    r"""
    :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
    :class:`~torchjd.aggregation.AlignedMTL`.

    :param pref_vector: The preference vector to use. If not provided, defaults to
        :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
    :param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
        the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
        uses the mean eigenvalue (as in the original implementation). Only the eigenvalues that
        are above the numerical rank tolerance are considered.
    """

    def __init__(
        self,
        pref_vector: Tensor | None = None,
        scale_mode: SUPPORTED_SCALE_MODE = "min",
    ) -> None:
        super().__init__()
        self._pref_vector = pref_vector
        self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
        # Weighting that turns the gramian into the preference weights; falls back to
        # MeanWeighting when no preference vector is given.
        self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())

    def forward(self, gramian: PSDMatrix, /) -> Tensor:
        """
        Compute the weights as the balance transformation of the gramian applied to the
        preference weights.

        :param gramian: The gramian from which the weights are extracted.
        :returns: The vector ``B @ w``, where ``w`` is given by the preference weighting and
            ``B`` by :meth:`_compute_balance_transformation`.
        :raises ValueError: If the configured ``scale_mode`` is not supported.
        """
        w = self.weighting(gramian)
        B = self._compute_balance_transformation(gramian, self._scale_mode)
        return B @ w

    @staticmethod
    def _compute_balance_transformation(
        M: Tensor,
        scale_mode: SUPPORTED_SCALE_MODE = "min",
    ) -> Tensor:
        r"""
        Build the balance transformation from the eigendecomposition of ``M``.

        With :math:`\lambda` and :math:`V` the eigenvalues above the rank tolerance and their
        eigenvectors, the result is
        :math:`\sqrt{s} \, V \operatorname{diag}(\lambda^{-1/2}) V^T`, where :math:`s` is the
        eigenvalue statistic selected by ``scale_mode``.

        :param M: Symmetric matrix to decompose (its upper triangle is the one that is read).
        :param scale_mode: ``"min"``, ``"median"`` or ``"rmse"`` (mean of the eigenvalues).
        :returns: The balance transformation, or the identity when no eigenvalue exceeds the
            tolerance.
        :raises ValueError: If ``scale_mode`` is not one of the supported values.
        """
        # Validate before doing any work, so that an invalid mode is reported even when the
        # matrix has numerical rank 0 (which returns early below).
        if scale_mode not in ("min", "median", "rmse"):
            raise ValueError(
                f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'.",
            )

        lambda_, V = torch.linalg.eigh(M, UPLO="U")  # More modern equivalent to torch.symeig

        # Relative tolerance for the numerical rank. The machine epsilon must be the one of the
        # eigenvalues' dtype: the epsilon of the global default dtype would be too loose for
        # float64 inputs and too tight for float32 inputs under a float64 default.
        tol = torch.max(lambda_) * len(M) * torch.finfo(lambda_.dtype).eps
        rank = int((lambda_ > tol).sum())

        if rank == 0:
            # No usable direction: fall back to the identity so the weights are left unchanged.
            return torch.eye(len(M), dtype=M.dtype, device=M.device)

        # Keep only the `rank` largest eigenvalues and their eigenvectors, in descending order.
        order = torch.argsort(lambda_, dim=-1, descending=True)
        lambda_, V = lambda_[order][:rank], V[:, order][:, :rank]

        sigma_inv = torch.diag(1 / lambda_.sqrt())

        if scale_mode == "min":
            # Eigenvalues are sorted in descending order, so the last one is the smallest.
            scale = lambda_[-1]
        elif scale_mode == "median":
            scale = torch.median(lambda_)
        else:  # "rmse", guaranteed by the validation above
            scale = lambda_.mean()

        B = scale.sqrt() * V @ sigma_inv @ V.T
        return B