Skip to content

Commit 6e0dbcd

Browse files
shinaokaclaude
andauthored
feat: add alias map to differentiate() for checkpoint support (#17)
Add `aliases` parameter to `differentiate()` that maps input keys to derived keys. When differentiation reaches an aliased input leaf, it follows the alias to continue traversal through the old computation graph. This enables gradient flow through checkpoint boundaries. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ae8eb21 commit 6e0dbcd

7 files changed

Lines changed: 90 additions & 10 deletions

File tree

src/differentiate.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ use crate::LinearFragment;
2121
///
2222
/// let view = resolve(vec![primal_fragment]);
2323
/// let mut ctx = ();
24-
/// let linear = differentiate(&view, &[output_key], &[input_key], 1, &mut ctx);
24+
/// let aliases = std::collections::HashMap::new();
25+
/// let linear = differentiate(&view, &[output_key], &[input_key], 1, &mut ctx, &aliases);
2526
/// assert_eq!(linear.tangent_outputs.len(), 1);
2627
/// ```
2728
pub fn differentiate<Op: PrimitiveOp>(
@@ -30,12 +31,13 @@ pub fn differentiate<Op: PrimitiveOp>(
3031
wrt: &[Op::InputKey],
3132
pass: DiffPassId,
3233
ctx: &mut Op::ADContext,
34+
aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
3335
) -> LinearFragment<Op>
3436
where
3537
Op::InputKey: ADKey,
3638
{
3739
let mut builder = FragmentBuilder::<Op>::new();
38-
let topo_keys = topological_order(view, outputs);
40+
let topo_keys = topological_order(view, outputs, aliases);
3941
let mut tangent_env: HashMap<GlobalValKey<Op>, Option<LocalValId>> = HashMap::new();
4042
let mut processed_ops = HashSet::new();
4143

@@ -57,8 +59,13 @@ where
5759
};
5860

5961
match val_def {
60-
ValDef::Input { .. } => {
61-
tangent_env.insert(key, None);
62+
ValDef::Input { key: ref input_key } => {
63+
if let Some(aliased_key) = aliases.get(input_key) {
64+
let aliased_tangent = tangent_env.get(aliased_key).copied().flatten();
65+
tangent_env.insert(key, aliased_tangent);
66+
} else {
67+
tangent_env.insert(key, None);
68+
}
6269
}
6370
ValDef::Produced {
6471
op,
@@ -136,21 +143,31 @@ fn output_keys<Op: GraphOp>(op_key: &GlobalOpKey<Op>, n_outputs: usize) -> Vec<G
136143
fn topological_order<Op: GraphOp>(
137144
view: &ResolvedView<Op>,
138145
outputs: &[GlobalValKey<Op>],
146+
aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
139147
) -> Vec<GlobalValKey<Op>> {
140148
fn visit<Op: GraphOp>(
141149
key: &GlobalValKey<Op>,
142150
view: &ResolvedView<Op>,
151+
aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
143152
visited: &mut HashSet<GlobalValKey<Op>>,
144153
order: &mut Vec<GlobalValKey<Op>>,
145154
) {
146155
if !visited.insert(key.clone()) {
147156
return;
148157
}
149158

150-
if let Some(ValDef::Produced { input_keys, .. }) = view.resolve_val(key) {
151-
for input_key in input_keys {
152-
visit(&input_key, view, visited, order);
159+
match view.resolve_val(key) {
160+
Some(ValDef::Produced { input_keys, .. }) => {
161+
for input_key in input_keys {
162+
visit(&input_key, view, aliases, visited, order);
163+
}
164+
}
165+
Some(ValDef::Input { key: input_key }) => {
166+
if let Some(aliased_key) = aliases.get(&input_key) {
167+
visit(aliased_key, view, aliases, visited, order);
168+
}
153169
}
170+
None => {}
154171
}
155172

156173
order.push(key.clone());
@@ -159,7 +176,7 @@ fn topological_order<Op: GraphOp>(
159176
let mut visited = HashSet::new();
160177
let mut order = Vec::new();
161178
for output_key in outputs {
162-
visit(output_key, view, &mut visited, &mut order);
179+
visit(output_key, view, aliases, &mut visited, &mut order);
163180
}
164181
order
165182
}

tests/adcontext_tests.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::HashMap;
12
#[allow(dead_code)]
23
mod common;
34

@@ -147,7 +148,7 @@ fn differentiate_threads_ctx_to_all_ops() {
147148
let wrt = vec![ck("x")];
148149

149150
let mut ctx = CountingContext::default();
150-
let _linear = tidu::differentiate(&view, &[output_key], &wrt, 1, &mut ctx);
151+
let _linear = tidu::differentiate(&view, &[output_key], &wrt, 1, &mut ctx, &HashMap::new());
151152

152153
assert_eq!(
153154
ctx.linearize_count, 2,
@@ -162,7 +163,7 @@ fn transpose_threads_ctx_to_all_ops() {
162163
let wrt = vec![ck("x")];
163164

164165
let mut ctx = CountingContext::default();
165-
let linear = tidu::differentiate(&view, &[output_key], &wrt, 1, &mut ctx);
166+
let linear = tidu::differentiate(&view, &[output_key], &wrt, 1, &mut ctx, &HashMap::new());
166167

167168
ctx.linearize_count = 0;
168169
ctx.transpose_count = 0;

tests/complex_ad_tests.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::HashMap;
12
mod common;
23

34
use std::sync::Arc;
@@ -433,6 +434,7 @@ fn jvp_conj_z() {
433434
&[ck("z")],
434435
1,
435436
&mut (),
437+
&HashMap::new(),
436438
);
437439

438440
let dy_key = tangent_output_key(&linear, 0).expect("active tangent output");
@@ -455,6 +457,7 @@ fn vjp_conj_z() {
455457
&[ck("z")],
456458
2,
457459
&mut (),
460+
&HashMap::new(),
458461
);
459462
let transposed = transpose(&linear, &mut ());
460463

@@ -478,6 +481,7 @@ fn jvp_z_times_w() {
478481
&[ck("z"), ck("w")],
479482
3,
480483
&mut (),
484+
&HashMap::new(),
481485
);
482486

483487
let dy_key = tangent_output_key(&linear, 0).expect("active tangent output");
@@ -510,6 +514,7 @@ fn vjp_c_times_z_uses_conjugated_constant() {
510514
&[ck("z")],
511515
4,
512516
&mut (),
517+
&HashMap::new(),
513518
);
514519
let transposed = transpose(&linear, &mut ());
515520

@@ -537,6 +542,7 @@ fn vjp_abs_squared_returns_two_z() {
537542
&[ck("z")],
538543
5,
539544
&mut (),
545+
&HashMap::new(),
540546
);
541547
let transposed = transpose(&linear, &mut ());
542548

@@ -561,6 +567,7 @@ fn numerical_gradient_exp_z_matches_vjp_for_real_and_imag_losses() {
561567
&[ck("z")],
562568
6,
563569
&mut (),
570+
&HashMap::new(),
564571
);
565572
let transposed = transpose(&linear, &mut ());
566573
let ct_y_key = tangent_input_key(&transposed, 0);

tests/edge_case_tests.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::HashMap;
12
mod common;
23

34
use std::sync::Arc;
@@ -1146,6 +1147,7 @@ fn multi_output_sincos_jvp_sum_matches_expected() {
11461147
&[sk("x")],
11471148
1001,
11481149
&mut (),
1150+
&HashMap::new(),
11491151
);
11501152

11511153
let dy_key = tangent_output_key(&linear, 0).expect("active tangent output");
@@ -1170,6 +1172,7 @@ fn multi_output_sincos_vjp_matches_expected() {
11701172
&[sk("x")],
11711173
1002,
11721174
&mut (),
1175+
&HashMap::new(),
11731176
);
11741177
let transposed = transpose(&linear, &mut ());
11751178

@@ -1199,6 +1202,7 @@ fn multi_output_sincos_adjoint_consistency() {
11991202
&[sk("x")],
12001203
1003,
12011204
&mut (),
1205+
&HashMap::new(),
12021206
);
12031207
let dy_sin_key = tangent_output_key(&linear, 0).expect("active tangent output for sin");
12041208
let dy_cos_key = tangent_output_key(&linear, 1).expect("active tangent output for cos");
@@ -1249,6 +1253,7 @@ fn deep_chain_exp_10x() {
12491253
&[sk("x")],
12501254
1101,
12511255
&mut (),
1256+
&HashMap::new(),
12521257
);
12531258
let dy_key = tangent_output_key(&linear, 0).expect("active tangent output");
12541259
let dx_key = tangent_input_key(&linear, 0);
@@ -1306,6 +1311,7 @@ fn third_order_x_cubed() {
13061311
&[sk("x")],
13071312
1201,
13081313
&mut (),
1314+
&HashMap::new(),
13091315
);
13101316
let dy_key = tangent_output_key(&linear_1, 0).expect("active first-order tangent output");
13111317
let dx1_key = tangent_input_key(&linear_1, 0);
@@ -1317,6 +1323,7 @@ fn third_order_x_cubed() {
13171323
&[sk("x")],
13181324
1202,
13191325
&mut (),
1326+
&HashMap::new(),
13201327
);
13211328
let d2y_key = tangent_output_key(&linear_2, 0).expect("active second-order tangent output");
13221329
let dx2_key = tangent_input_key(&linear_2, 0);
@@ -1332,6 +1339,7 @@ fn third_order_x_cubed() {
13321339
&[sk("x")],
13331340
1203,
13341341
&mut (),
1342+
&HashMap::new(),
13351343
);
13361344
let d3y_key = tangent_output_key(&linear_3, 0).expect("active third-order tangent output");
13371345
let dx3_key = tangent_input_key(&linear_3, 0);
@@ -1364,6 +1372,7 @@ fn fourth_order_x_fourth() {
13641372
&[sk("x")],
13651373
1211,
13661374
&mut (),
1375+
&HashMap::new(),
13671376
);
13681377
let dy_key = tangent_output_key(&linear_1, 0).expect("active first-order tangent output");
13691378
let dx1_key = tangent_input_key(&linear_1, 0);
@@ -1375,6 +1384,7 @@ fn fourth_order_x_fourth() {
13751384
&[sk("x")],
13761385
1212,
13771386
&mut (),
1387+
&HashMap::new(),
13781388
);
13791389
let d2y_key = tangent_output_key(&linear_2, 0).expect("active second-order tangent output");
13801390
let dx2_key = tangent_input_key(&linear_2, 0);
@@ -1390,6 +1400,7 @@ fn fourth_order_x_fourth() {
13901400
&[sk("x")],
13911401
1213,
13921402
&mut (),
1403+
&HashMap::new(),
13931404
);
13941405
let d3y_key = tangent_output_key(&linear_3, 0).expect("active third-order tangent output");
13951406
let dx3_key = tangent_input_key(&linear_3, 0);
@@ -1406,6 +1417,7 @@ fn fourth_order_x_fourth() {
14061417
&[sk("x")],
14071418
1214,
14081419
&mut (),
1420+
&HashMap::new(),
14091421
);
14101422
let d4y_key = tangent_output_key(&linear_4, 0).expect("active fourth-order tangent output");
14111423
let dx4_key = tangent_input_key(&linear_4, 0);
@@ -1440,6 +1452,7 @@ fn third_order_for_then_f() {
14401452
&[sk("x")],
14411453
1221,
14421454
&mut (),
1455+
&HashMap::new(),
14431456
);
14441457
let transposed = transpose(&linear, &mut ());
14451458
let ct_x_key = tangent_output_key(&transposed, 0).expect("active cotangent output");
@@ -1452,6 +1465,7 @@ fn third_order_for_then_f() {
14521465
&[sk("x")],
14531466
1222,
14541467
&mut (),
1468+
&HashMap::new(),
14551469
);
14561470
let d_ct_x_key = tangent_output_key(&linear_2, 0).expect("active forward-over-reverse output");
14571471
let dx2_key = tangent_input_key(&linear_2, 0);
@@ -1467,6 +1481,7 @@ fn third_order_for_then_f() {
14671481
&[sk("x")],
14681482
1223,
14691483
&mut (),
1484+
&HashMap::new(),
14701485
);
14711486
let d2_ct_x_key = tangent_output_key(&linear_3, 0).expect("active third-order output");
14721487
let dx3_key = tangent_input_key(&linear_3, 0);
@@ -1503,6 +1518,7 @@ fn fofof_vector_x_cubed() {
15031518
&[vk("x")],
15041519
1401,
15051520
&mut (),
1521+
&HashMap::new(),
15061522
);
15071523
let dy_key = tangent_output_key(&linear_1, 0).expect("active first-order tangent output");
15081524
let dx1_key = tangent_input_key(&linear_1, 0);
@@ -1514,6 +1530,7 @@ fn fofof_vector_x_cubed() {
15141530
&[vk("x")],
15151531
1402,
15161532
&mut (),
1533+
&HashMap::new(),
15171534
);
15181535
let d2y_key = tangent_output_key(&linear_2, 0).expect("active second-order tangent output");
15191536
let dx2_key = tangent_input_key(&linear_2, 0);
@@ -1529,6 +1546,7 @@ fn fofof_vector_x_cubed() {
15291546
&[vk("x")],
15301547
1403,
15311548
&mut (),
1549+
&HashMap::new(),
15321550
);
15331551
let d3y_key = tangent_output_key(&linear_3, 0).expect("active third-order tangent output");
15341552
let dx3_key = tangent_input_key(&linear_3, 0);
@@ -1570,6 +1588,7 @@ fn fof_vector_adjoint_consistency() {
15701588
&[vk("x")],
15711589
1411,
15721590
&mut (),
1591+
&HashMap::new(),
15731592
);
15741593
let dy_key = tangent_output_key(&linear_1, 0).expect("active first-order tangent output");
15751594
let dx1_fof_key = tangent_input_key(&linear_1, 0);
@@ -1582,6 +1601,7 @@ fn fof_vector_adjoint_consistency() {
15821601
&[vk("x")],
15831602
1412,
15841603
&mut (),
1604+
&HashMap::new(),
15851605
);
15861606
let d2y_key = tangent_output_key(&linear_2, 0).expect("active second-order tangent output");
15871607
let dx2_key = tangent_input_key(&linear_2, 0);
@@ -1609,6 +1629,7 @@ fn fof_vector_adjoint_consistency() {
16091629
&[vk("x")],
16101630
1413,
16111631
&mut (),
1632+
&HashMap::new(),
16121633
);
16131634
let d_ct_x_key = tangent_output_key(&linear_3, 0).expect("active forward-over-reverse output");
16141635
let dx1_for_key = tangent_input_key(&linear_3, 0);
@@ -1637,6 +1658,7 @@ fn complex_vector_jvp_conj_elementwise() {
16371658
&[cvk("z")],
16381659
1301,
16391660
&mut (),
1661+
&HashMap::new(),
16401662
);
16411663

16421664
let dy_key = tangent_output_key(&linear, 0).expect("active tangent output");
@@ -1665,6 +1687,7 @@ fn complex_vector_vjp_sum_abs_squared() {
16651687
&[cvk("z")],
16661688
1302,
16671689
&mut (),
1690+
&HashMap::new(),
16681691
);
16691692
let transposed = transpose(&linear, &mut ());
16701693

@@ -1696,6 +1719,7 @@ fn complex_vector_adjoint_consistency() {
16961719
&[cvk("z")],
16971720
1303,
16981721
&mut (),
1722+
&HashMap::new(),
16991723
);
17001724
let dy_key = tangent_output_key(&linear, 0).expect("active tangent output");
17011725
let dz_key = tangent_input_key(&linear, 0);

0 commit comments

Comments
 (0)