
Commit 8d1f6e7

refactor(gradvac): literal group types, eps/beta rules, and plotter UX

- Use group_type "whole_model" | "all_layer" | "all_matrix" instead of 0/1/2
- Remove DEFAULT_GRADVAC_EPS from the public API; keep the default of 1e-8; allow eps=0
- Validate beta via a setter; tighten GradVac repr/str expectations
- Fix all_layer leaf sizing via children() and parameters() instead of private fields
- Trim redundant GradVac.rst prose; align the docs with the new API
- Tests: add GradVac cases and a value regression seeded with torch.manual_seed
- Plotter: factory dict with fresh aggregator instances per update; legend built from the selected keys; MathJax labels and live angle/length readouts in the sidebar

This commit includes the GradVac implementation built on the Aggregator class.
1 parent 1030d57 commit 8d1f6e7

7 files changed: 269 additions & 187 deletions
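
Before the per-file diffs, a minimal usage sketch of the refactored GradVac surface described in the message above. The _gradvac.py diff is collapsed below, so apart from the group_type, beta, and eps names taken from the commit message, every detail here (default group_type, accepted beta range, output shape) is an assumption:

import torch

from torchjd.aggregation import GradVac

# group_type is now a literal string rather than 0/1/2.
aggregator = GradVac(group_type="whole_model")  # or "all_layer" / "all_matrix"

# beta (the EMA rate of the similarity targets) is validated by a setter, so a
# bad value should fail fast at construction time. Illustrative value:
aggregator = GradVac(beta=0.01)

# An Aggregator maps a matrix of task gradients (one row per task) to a single
# update vector; the new value-regression test seeds exactly like this.
torch.manual_seed(0)
jacobian = torch.randn(3, 5)
update = aggregator(jacobian)  # expected shape: (5,)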

docs/source/docs/aggregation/gradvac.rst

Lines changed: 0 additions & 13 deletions
@@ -3,19 +3,6 @@
 GradVac
 =======
 
-.. autodata:: torchjd.aggregation.DEFAULT_GRADVAC_EPS
-
-The constructor argument ``group_type`` (default ``0``) sets **parameter granularity** for the
-per-block cosine statistics in GradVac:
-
-* ``0`` — **whole model** (``whole_model``): one block per task gradient row. Omit ``encoder`` and
-  ``shared_params``.
-* ``1`` — **all layer** (``all_layer``): one block per leaf submodule with parameters under
-  ``encoder`` (same traversal as ``encoder.modules()`` in the reference formulation).
-* ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in order. Use
-  the same tensors as for the shared-parameter Jacobian columns (e.g. the parameters you would pass
-  to a shared-gradient helper).
-
 .. autoclass:: torchjd.aggregation.GradVac
    :members:
    :undoc-members:
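
The deleted granularity notes survive, per the commit message, as docstring prose keyed on the new literal strings. As a rough illustration only (none of this helper code is torchjd API), the three partitions for a toy encoder could be derived like this, with all_layer sized through children() and parameters() as the commit message describes:

from torch import nn

encoder = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
shared_params = list(encoder.parameters())

# "whole_model": one block spanning all shared parameters per task gradient.
whole_model = [sum(p.numel() for p in shared_params)]

# "all_layer": one block per leaf submodule that owns parameters.
all_layer = [
    sum(p.numel() for p in m.parameters())
    for m in encoder.modules()
    if next(m.children(), None) is None and next(m.parameters(), None) is not None
]

# "all_matrix": one block per parameter tensor, in order.
all_matrix = [p.numel() for p in shared_params]

print(whole_model)  # [58]
print(all_layer)    # [40, 18]  (ReLU owns no parameters, so it forms no block)
print(all_matrix)   # [32, 8, 16, 2]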

src/torchjd/aggregation/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -66,7 +66,7 @@
 from ._dualproj import DualProj, DualProjWeighting
 from ._flattening import Flattening
 from ._graddrop import GradDrop
-from ._gradvac import DEFAULT_GRADVAC_EPS, GradVac
+from ._gradvac import GradVac
 from ._imtl_g import IMTLG, IMTLGWeighting
 from ._krum import Krum, KrumWeighting
 from ._mean import Mean, MeanWeighting
@@ -88,7 +88,6 @@
     "ConFIG",
     "Constant",
     "ConstantWeighting",
-    "DEFAULT_GRADVAC_EPS",
     "DualProj",
     "DualProjWeighting",
     "Flattening",

src/torchjd/aggregation/_gradvac.py

Lines changed: 104 additions & 90 deletions
Large diffs are not rendered by default.
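
Since this diff is collapsed, here is a minimal sketch of the pairwise step that GradVac (Wang et al., 2021) performs and that this file implements: an exponential moving average with rate beta tracks a target cosine similarity for each task pair, and a gradient is nudged toward its partner whenever the observed similarity falls below that target. Names and structure below are illustrative; the real module also handles the all_layer and all_matrix groupings and the eps guard described in the commit message:

import torch


def gradvac_adjust(
    g_i: torch.Tensor,
    g_j: torch.Tensor,
    target: float,
    beta: float = 0.01,
    eps: float = 1e-8,
) -> tuple[torch.Tensor, float]:
    # Observed cosine similarity between the two task gradients.
    cos = torch.dot(g_i, g_j) / (g_i.norm() * g_j.norm() + eps)
    if cos.item() < target:
        # Rescale g_j's contribution so that the adjusted g_i reaches the
        # target similarity (Eq. 2 of the GradVac paper).
        coef = (
            g_i.norm()
            * (target * (1 - cos**2).sqrt() - cos * (1 - target**2) ** 0.5)
            / (g_j.norm() * (1 - target**2) ** 0.5 + eps)
        )
        g_i = g_i + coef * g_j
    # EMA update of the per-pair target similarity.
    new_target = (1 - beta) * target + beta * cos.item()
    return g_i, new_target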

tests/plots/_utils.py

Lines changed: 20 additions & 5 deletions
@@ -1,3 +1,5 @@
+from collections.abc import Callable
+
 import numpy as np
 import torch
 from plotly import graph_objects as go
@@ -7,14 +9,22 @@
 
 
 class Plotter:
-    def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0) -> None:
-        self.aggregators = aggregators
+    def __init__(
+        self,
+        aggregator_factories: dict[str, Callable[[], Aggregator]],
+        selected_keys: list[str],
+        matrix: torch.Tensor,
+        seed: int = 0,
+    ) -> None:
+        self._aggregator_factories = aggregator_factories
+        self.selected_keys = selected_keys
         self.matrix = matrix
         self.seed = seed
 
     def make_fig(self) -> Figure:
         torch.random.manual_seed(self.seed)
-        results = [agg(self.matrix) for agg in self.aggregators]
+        aggregators = [self._aggregator_factories[key]() for key in self.selected_keys]
+        results = [agg(self.matrix) for agg in aggregators]
 
         fig = go.Figure()
 
@@ -23,14 +33,19 @@ def make_fig(self) -> Figure:
         fig.add_trace(cone)
 
         for i in range(len(self.matrix)):
-            scatter = make_vector_scatter(self.matrix[i], "blue", f"g{i + 1}")
+            scatter = make_vector_scatter(
+                self.matrix[i],
+                "blue",
+                f"g{i + 1}",
+                textposition="top right",
+            )
             fig.add_trace(scatter)
 
         for i in range(len(results)):
             scatter = make_vector_scatter(
                 results[i],
                 "black",
-                str(self.aggregators[i]),
+                self.selected_keys[i],
                 showlegend=True,
                 dash=True,
             )
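
The factory indirection matters because some aggregators are stateful: NashMTL tracks optimization history across calls, and GradVac's EMA targets (assuming they live on the instance) would otherwise bleed between renders. Building fresh instances inside make_fig keeps every render reproducible under the fixed seed. A usage sketch, with the aggregator choices purely illustrative:

import torch

from torchjd.aggregation import GradVac, Mean

# Plotter as defined in tests/plots/_utils.py above; the classes themselves
# serve as zero-argument factories.
matrix = torch.tensor([[-4.0, 1.0], [4.0, 1.0]])
plotter = Plotter(
    aggregator_factories={"Mean": Mean, "GradVac": GradVac},
    selected_keys=["Mean"],
    matrix=matrix,
)
fig = plotter.make_fig()  # instantiates a fresh Mean() for this render

plotter.selected_keys = ["Mean", "GradVac"]  # what the checklist callback does
fig = plotter.make_fig()  # fresh Mean() and GradVac(); no carried-over state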

tests/plots/interactive_plotter.py

Lines changed: 96 additions & 29 deletions
@@ -1,6 +1,7 @@
 import logging
 import os
 import webbrowser
+from collections.abc import Callable
 from threading import Timer
 
 import numpy as np
@@ -12,11 +13,13 @@
 from torchjd.aggregation import (
     IMTLG,
     MGDA,
+    Aggregator,
     AlignedMTL,
     CAGrad,
     ConFIG,
     DualProj,
     GradDrop,
+    GradVac,
     Mean,
     NashMTL,
     PCGrad,
@@ -30,6 +33,14 @@
 MAX_LENGTH = 25.0
 
 
+def _format_angle_display(angle: float) -> str:
+    return f"{angle:.4f} rad ({np.degrees(angle):.1f}°)"
+
+
+def _format_length_display(r: float) -> str:
+    return f"{r:.4f}"
+
+
 def main() -> None:
     log = logging.getLogger("werkzeug")
     log.setLevel(logging.CRITICAL)
@@ -42,26 +53,30 @@ def main() -> None:
         ],
     )
 
-    aggregators = [
-        AlignedMTL(),
-        CAGrad(c=0.5),
-        ConFIG(),
-        DualProj(),
-        GradDrop(),
-        IMTLG(),
-        Mean(),
-        MGDA(),
-        NashMTL(n_tasks=matrix.shape[0]),
-        PCGrad(),
-        Random(),
-        Sum(),
-        TrimmedMean(trim_number=1),
-        UPGrad(),
-    ]
-
-    aggregators_dict = {str(aggregator): aggregator for aggregator in aggregators}
-
-    plotter = Plotter([], matrix)
+    n_tasks = matrix.shape[0]
+    aggregator_factories: dict[str, Callable[[], Aggregator]] = {
+        "AlignedMTL-min": lambda: AlignedMTL(scale_mode="min"),
+        "AlignedMTL-median": lambda: AlignedMTL(scale_mode="median"),
+        "AlignedMTL-RMSE": lambda: AlignedMTL(scale_mode="rmse"),
+        str(CAGrad(c=0.5)): lambda: CAGrad(c=0.5),
+        str(ConFIG()): lambda: ConFIG(),
+        str(DualProj()): lambda: DualProj(),
+        str(GradDrop()): lambda: GradDrop(),
+        str(GradVac()): lambda: GradVac(),
+        str(IMTLG()): lambda: IMTLG(),
+        str(Mean()): lambda: Mean(),
+        str(MGDA()): lambda: MGDA(),
+        str(NashMTL(n_tasks=n_tasks)): lambda: NashMTL(n_tasks=n_tasks),
+        str(PCGrad()): lambda: PCGrad(),
+        str(Random()): lambda: Random(),
+        str(Sum()): lambda: Sum(),
+        str(TrimmedMean(trim_number=1)): lambda: TrimmedMean(trim_number=1),
+        str(UPGrad()): lambda: UPGrad(),
+    }
+
+    aggregator_strings = list(aggregator_factories.keys())
+
+    plotter = Plotter(aggregator_factories, [], matrix)
 
     app = Dash(__name__)
 
@@ -96,7 +111,6 @@ def main() -> None:
         gradient_slider_inputs.append(Input(angle_input, "value"))
         gradient_slider_inputs.append(Input(r_input, "value"))
 
-    aggregator_strings = [str(aggregator) for aggregator in aggregators]
     checklist = dcc.Checklist(aggregator_strings, [], id="aggregator-checklist")
 
     control_div = html.Div(
@@ -115,32 +129,40 @@ def update_seed(value: int) -> Figure:
         plotter.seed = value
         return plotter.make_fig()
 
+    n_gradients = len(matrix)
+    gradient_value_outputs: list[Output] = []
+    for i in range(n_gradients):
+        gradient_value_outputs.append(Output(f"g{i + 1}-angle-value", "children"))
+        gradient_value_outputs.append(Output(f"g{i + 1}-length-value", "children"))
+
     @callback(
         Output("aggregations-fig", "figure", allow_duplicate=True),
+        *gradient_value_outputs,
         *gradient_slider_inputs,
         prevent_initial_call=True,
     )
-    def update_gradient_coordinate(*values: str) -> Figure:
+    def update_gradient_coordinate(*values: str) -> tuple[Figure, ...]:
         values_ = [float(value) for value in values]
 
+        display_parts: list[str] = []
         for j in range(len(values_) // 2):
             angle = values_[2 * j]
             r = values_[2 * j + 1]
             x, y = angle_to_coord(angle, r)
             plotter.matrix[j, 0] = x
             plotter.matrix[j, 1] = y
+            display_parts.append(_format_angle_display(angle))
+            display_parts.append(_format_length_display(r))
 
-        return plotter.make_fig()
+        return (plotter.make_fig(), *display_parts)
 
     @callback(
         Output("aggregations-fig", "figure", allow_duplicate=True),
         Input("aggregator-checklist", "value"),
        prevent_initial_call=True,
     )
     def update_aggregators(value: list[str]) -> Figure:
-        aggregator_keys = value
-        new_aggregators = [aggregators_dict[key] for key in aggregator_keys]
-        plotter.aggregators = new_aggregators
+        plotter.selected_keys = list(value)
         return plotter.make_fig()
 
     Timer(1, open_browser).start()
@@ -173,11 +195,56 @@ def make_gradient_div(
        style={"width": "250px"},
    )
 
+    label_style: dict[str, str | int] = {
+        "display": "inline-block",
+        "width": "52px",
+        "margin-right": "8px",
+        "vertical-align": "middle",
+    }
+    value_style: dict[str, str] = {
+        "display": "inline-block",
+        "margin-left": "10px",
+        "min-width": "140px",
+        "font-family": "monospace",
+        "font-size": "13px",
+        "vertical-align": "middle",
+    }
+    row_style: dict[str, str] = {"display": "block", "margin-bottom": "6px"}
    div = html.Div(
        [
-            html.P(f"g{i + 1}", style={"display": "inline-block", "margin-right": 20}),
-            angle_input,
-            r_input,
+            dcc.Markdown(
+                f"$g_{{{i + 1}}}$",
+                mathjax=True,
+                style={
+                    "margin": "0 0 6px 0",
+                    "font-weight": "bold",
+                    "display": "block",
+                },
+            ),
+            html.Div(
+                [
+                    html.Span("Angle", style=label_style),
+                    angle_input,
+                    html.Span(
+                        id=f"g{i + 1}-angle-value",
+                        children=_format_angle_display(angle),
+                        style=value_style,
+                    ),
+                ],
+                style=row_style,
+            ),
+            html.Div(
+                [
+                    html.Span("Length", style=label_style),
+                    r_input,
+                    html.Span(
+                        id=f"g{i + 1}-length-value",
+                        children=_format_length_display(r),
+                        style=value_style,
+                    ),
+                ],
+                style={**row_style, "margin-bottom": "12px"},
+            ),
        ],
    )
    return div, angle_input, r_input
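
One helper this diff relies on but never shows is angle_to_coord, imported from the plotting utilities. Presumably it is the plain polar-to-Cartesian conversion that both the sliders and the new sidebar readouts are built around; a self-contained sketch under that assumption:

import numpy as np


def angle_to_coord(angle: float, r: float) -> tuple[float, float]:
    # Assumed behavior: map an (angle in radians, length) pair to x/y.
    return r * float(np.cos(angle)), r * float(np.sin(angle))


x, y = angle_to_coord(np.pi / 4, 10.0)
print(f"({x:.3f}, {y:.3f})")  # (7.071, 7.071)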
