
feat: EagerTensor — eager execution with tape-based AD #693

@shinaoka

Description

Summary

Add EagerTensor type that combines eager (immediate) execution with tape-based automatic differentiation. This fills the missing quadrant in tenferro-rs:

             without AD       with AD
Eager        Tensor           EagerTensor  ← NEW
Traced       TracedTensor     TracedTensor

Motivation

tensor4all-rs migration

tensor4all-rs currently depends on the old tenferro v1 which has eager + AD
via NativeTensor::requires_grad() / backward() / grad(). The new
tenferro-rs (v2) removed this in favor of graph-only AD (TracedTensor + tidu).

This forces the migration to use TracedTensor for AD, requiring extensive
workarounds for dynamic shapes (symbolic shape_hint, DynamicTruncate,
ShapeGuardContext, etc.). With EagerTensor, the migration becomes a near
1:1 replacement of the old NativeTensor API.

Truncated SVD and data-dependent shapes

TN algorithms use adaptive truncated SVD where the output rank depends on
singular values. In TracedTensor (graph), this requires DynamicTruncate
(#691) or eval barriers. In EagerTensor, it works naturally:

let (u, s, vt) = matrix.svd(&mut ctx)?;   // immediate
let rank = determine_rank(s.data(), rtol, max_rank);  // concrete
let u_k = u.slice(0..rank)?;              // immediate, recorded on tape

N-ary einsum

NaryEinsum (#681) also benefits EagerTensor. With concrete shapes, the
contraction path can be planned immediately. The AD rule (subscript
manipulation only, no shape info needed) works identically in eager and
traced modes.
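To illustrate why no shape information is needed, the transpose rule for a binary einsum can be sketched as pure string rewriting, ignoring edge cases like repeated indices (`einsum_transpose_spec` is a hypothetical helper, not an existing tenferro function):

```rust
// Hypothetical sketch: to get the cotangent of input `wrt` of an einsum,
// swap that input's subscripts with the output's and contract the output
// cotangent against the remaining inputs. Pure subscript manipulation,
// so the same rule serves eager and traced modes.
fn einsum_transpose_spec(spec: &str, wrt: usize) -> String {
    let (lhs, out) = spec.split_once("->").expect("spec must contain ->");
    let inputs: Vec<&str> = lhs.split(',').collect();
    // The output cotangent takes the place of the differentiated input;
    // every other input keeps its subscripts; the result carries the
    // differentiated input's subscripts.
    let mut new_inputs = vec![out];
    for (i, s) in inputs.iter().enumerate() {
        if i != wrt {
            new_inputs.push(s);
        }
    }
    format!("{}->{}", new_inputs.join(","), inputs[wrt])
}
```

For matmul (`"ij,jk->ik"`) this reproduces the familiar rules dA = dC·Bᵀ and dB = Aᵀ·dC as einsum specs.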

Design

Core types

/// Tensor with optional tape attachment for eager AD.
pub struct EagerTensor {
    data: Arc<Tensor>,               // always holds concrete data
    node: Option<TapeNodeId>,        // None = detached, Some = tracked on tape
}

/// Records operations for backward AD pass.
pub struct Tape {
    nodes: Vec<TapeNode>,
}

/// A recorded operation with primal data for backward.
struct TapeNode {
    op: StdTensorOp,
    inputs: Vec<TapeNodeId>,         // indices into tape.nodes
    input_data: Vec<Arc<Tensor>>,    // primal values for AD rules
    n_outputs: usize,
}

pub type TapeNodeId = usize;

Usage pattern

// Without AD (no tape, like bare Tensor)
let a = EagerTensor::new(vec![3, 3], data);
let (u, s, vt) = a.svd(&mut ctx)?;

// With AD (tape records ops)
let mut tape = Tape::new();
let a = tape.track(Tensor::new(vec![3, 3], data));
let b = tape.track(Tensor::new(vec![3, 3], data));
let c = a.matmul(&b, &mut ctx)?;             // exec + tape record
let loss = c.reduce_sum(&[0, 1], &mut ctx)?;  // exec + tape record
let grads = tape.backward(&loss, &[&a, &b])?;  // tape reverse walk

How tape.backward() works

  1. Walk tape.nodes in reverse order (from loss to inputs)
  2. For each node, call op.transpose_rule() from chainrules
    • transpose_rule receives cotangent outputs and primal input data
    • Returns cotangent inputs
    • All concrete (no graph, no compile)
  3. Accumulate cotangents for shared inputs (fan-out → add)
  4. Return gradients for requested inputs

impl Tape {
    pub fn backward(
        &self,
        output: &EagerTensor,
        wrt: &[&EagerTensor],
        ctx: &mut impl TensorBackend,
    ) -> Result<Vec<Tensor>> {
        // Initialize cotangent for output = ones (scalar) or provided seed
        let mut cotangents: Vec<Option<Tensor>> = vec![None; self.nodes.len()];
        cotangents[output.node.unwrap()] = Some(ones_like(&output.data));

        // Reverse walk
        for i in (0..self.nodes.len()).rev() {
            let Some(ref ct_out) = cotangents[i] else { continue };
            let node = &self.nodes[i];

            // Call chainrules transpose_rule
            let ct_inputs = node.op.transpose_rule(
                /* builder: */ ...,  // ← see "Integration with chainrules" below
                &ct_out_slices,
                &input_refs,
                &node.mode,
                &mut ad_ctx,
            );

            // Accumulate into input cotangents
            for (input_id, ct) in node.inputs.iter().zip(ct_inputs) {
                if let Some(ct_val) = ct {
                    // accumulate: cotangents[input_id] += ct_val
                }
            }
        }

        // Extract gradients for requested inputs
        Ok(wrt.iter().map(|t| cotangents[t.node.unwrap()].clone().unwrap()).collect())
    }
}
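The reverse walk above can be distilled into a self-contained toy over scalar f64 values (`ScalarTape` and its ops are illustrative stand-ins for `Tape`/`TapeNode`, not the proposed API):

```rust
// Minimal tape sketch: each node stores its op, input node ids, and the
// primal values needed by the transpose rules, mirroring TapeNode above.
#[derive(Clone, Copy)]
enum Op { Leaf, Add(usize, usize), Mul(usize, usize) }

struct ScalarTape { ops: Vec<Op>, vals: Vec<f64> }

impl ScalarTape {
    fn new() -> Self { Self { ops: vec![], vals: vec![] } }
    fn leaf(&mut self, v: f64) -> usize { self.push(Op::Leaf, v) }
    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.vals[a] + self.vals[b];     // execute immediately
        self.push(Op::Add(a, b), v)              // then record on tape
    }
    fn mul(&mut self, a: usize, b: usize) -> usize {
        let v = self.vals[a] * self.vals[b];
        self.push(Op::Mul(a, b), v)
    }
    fn push(&mut self, op: Op, v: f64) -> usize {
        self.ops.push(op);
        self.vals.push(v);
        self.vals.len() - 1
    }
    /// Reverse walk: seed the output cotangent with 1, apply each node's
    /// transpose rule, accumulate into the inputs (fan-out → add).
    fn backward(&self, output: usize) -> Vec<f64> {
        let mut ct = vec![0.0; self.vals.len()];
        ct[output] = 1.0;
        for i in (0..self.ops.len()).rev() {
            let g = ct[i];
            if g == 0.0 { continue; }
            match self.ops[i] {
                Op::Leaf => {}
                Op::Add(a, b) => { ct[a] += g; ct[b] += g; }
                Op::Mul(a, b) => {
                    ct[a] += g * self.vals[b];   // uses saved primal data
                    ct[b] += g * self.vals[a];
                }
            }
        }
        ct
    }
}
```

For f(x, y) = x·y + x this yields ∂f/∂x = y + 1 and ∂f/∂y = x, with the fan-out of x handled by accumulation.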

Integration with chainrules AD rules

The key design question: PrimitiveOp::transpose_rule takes a
FragmentBuilder and emits graph ops. For eager execution, we need
to EXECUTE those ops immediately instead of building a graph.

Two approaches:

Approach A: Eager FragmentBuilder

Create an EagerFragmentBuilder that implements the same interface as
FragmentBuilder but executes ops immediately:

struct EagerFragmentBuilder<'a> {
    ctx: &'a mut dyn TensorBackend,
    values: Vec<Tensor>,  // concrete results
}

impl EagerFragmentBuilder<'_> {
    fn add_op(&mut self, op: StdTensorOp, inputs: Vec<ValRef>, mode: OpMode) -> Result<Vec<LocalValId>> {
        let input_tensors = self.resolve_inputs(&inputs);
        let result = execute_op(&op, &input_tensors, self.ctx)?;
        let id = self.values.len();
        self.values.push(result);
        Ok(vec![id])
    }
}

The AD rules (linearize/transpose_rule) call builder.add_op(...).
With EagerFragmentBuilder, each add_op executes immediately.
Same AD rule code, different builder → different execution model.
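A minimal sketch of what that trait abstraction could look like, assuming a toy op set and illustrative names (`Builder`, `EagerBuilder`, `GraphBuilder` are not the computegraph-rs API): the AD rule body is written once against the trait, and the choice of builder decides whether ops are recorded or executed.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Op { Add, Mul }

// The builder abstraction: AD rules emit ops through this trait and never
// know whether the ops are executed or recorded.
trait Builder {
    type Val: Copy;
    fn add_op(&mut self, op: Op, a: Self::Val, b: Self::Val) -> Self::Val;
}

/// Eager builder: executes each op on concrete data as it is emitted.
struct EagerBuilder;
impl Builder for EagerBuilder {
    type Val = f64;
    fn add_op(&mut self, op: Op, a: f64, b: f64) -> f64 {
        match op { Op::Add => a + b, Op::Mul => a * b }
    }
}

/// Graph builder: records (op, lhs, rhs) triples and hands back node ids.
struct GraphBuilder { nodes: Vec<(Op, usize, usize)> }
impl Builder for GraphBuilder {
    type Val = usize;
    fn add_op(&mut self, op: Op, a: usize, b: usize) -> usize {
        self.nodes.push((op, a, b));
        self.nodes.len() - 1
    }
}

/// One transpose-rule body, generic over the builder:
/// ct_lhs = ct * rhs, ct_rhs = ct * lhs (real-valued Mul).
fn mul_transpose<B: Builder>(b: &mut B, ct: B::Val, lhs: B::Val, rhs: B::Val) -> (B::Val, B::Val) {
    (b.add_op(Op::Mul, ct, rhs), b.add_op(Op::Mul, ct, lhs))
}
```

The associated `Val` type is the key design point: eager builders pass concrete tensors through the rule, graph builders pass value ids, and the rule code is oblivious to the difference.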

Approach B: Build small graph, execute immediately

For each tape node's backward:

  1. Call transpose_rule with a regular FragmentBuilder → small Fragment
  2. Compile + execute the Fragment immediately
  3. Discard the Fragment

This reuses all existing infrastructure but has compile overhead per node.

Approach A is better — no compile overhead, direct execution.

However, Approach A requires that FragmentBuilder be abstracted into a
trait (currently it's a concrete struct in computegraph-rs). This may
require changes to computegraph-rs.

Approach C: Execute transpose_rule outputs directly

For simple AD rules (most elementwise ops), transpose_rule emits 1-2 ops.
We can intercept these and execute them directly without a builder:

// Instead of going through FragmentBuilder, execute the transpose rule
// by pattern-matching on the ops it would emit.
fn eager_transpose(
    op: &StdTensorOp,
    cotangent: &Tensor,
    primal_inputs: &[&Tensor],
    ctx: &mut impl TensorBackend,
) -> Result<Vec<Option<Tensor>>> {
    // Direct implementation using Tensor eager methods
    match op {
        StdTensorOp::Add => Ok(vec![Some(cotangent.clone()), Some(cotangent.clone())]),
        StdTensorOp::Mul => {
            let ct_lhs = cotangent.mul(&primal_inputs[1].conj(), ctx)?;
            let ct_rhs = primal_inputs[0].conj().mul(cotangent, ctx)?;
            Ok(vec![Some(ct_lhs), Some(ct_rhs)])
        }
        // ... dispatch remaining ops to the chainrules transpose rules
    }
}

This duplicates the AD rule logic, which we want to avoid.

Recommended: Approach A with a trait abstraction for the builder.

JVP (forward-mode AD)

impl Tape {
    /// Forward-mode: compute function value + JVP simultaneously
    pub fn jvp(
        f: impl FnOnce(&mut Tape, &mut impl TensorBackend) -> EagerTensor,
        x: &EagerTensor,
        tangent: &Tensor,
        ctx: &mut impl TensorBackend,
    ) -> (Tensor, Tensor) {
        // Forward execution with dual numbers:
        // each op: (primal, tangent) → (result, tangent_result)
        // Uses linearize() from chainrules
    }
}
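The dual-number propagation the comments describe can be sketched self-contained on scalars (`Dual` and this `jvp` signature are illustrative, not the proposed API): each value carries a (primal, tangent) pair, and every op propagates both, which is exactly what `linearize()` provides per primitive.

```rust
// A dual number: primal value plus tangent (directional derivative).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual { p: f64, t: f64 }

impl Dual {
    fn add(self, o: Dual) -> Dual {
        Dual { p: self.p + o.p, t: self.t + o.t }
    }
    fn mul(self, o: Dual) -> Dual {
        // product rule: d(xy) = x dy + y dx
        Dual { p: self.p * o.p, t: self.p * o.t + self.t * o.p }
    }
}

/// JVP of f at x in direction v: run f once on Dual { p: x, t: v }.
/// Returns (f(x), J_f(x) · v) from a single forward pass.
fn jvp(f: impl Fn(Dual) -> Dual, x: f64, v: f64) -> (f64, f64) {
    let y = f(Dual { p: x, t: v });
    (y.p, y.t)
}
```

For f(x) = x² + x at x = 3 with v = 1, a single pass yields both the value 12 and the derivative 2x + 1 = 7.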

HVP (Hessian-vector product)

// HVP = JVP through VJP (forward-over-reverse)
let hvp = Tape::jvp(
    |tape, ctx| {
        let grads = tape.backward(&loss, &[&x], ctx).unwrap();
        grads.into_iter().next().unwrap()
    },
    &x, &v, &mut ctx,
);

This composes naturally: JVP of the backward pass.
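A scalar sketch of that composition, with the reverse pass stood in for by a hand-written gradient expressed in dual arithmetic (`grad_f` plays the role of `tape.backward`; all names are illustrative): for f(x) = x³, the gradient is 3x² and the HVP is f''(x)·v = 6xv.

```rust
#[derive(Clone, Copy)]
struct Dual { p: f64, t: f64 }

fn dmul(a: Dual, b: Dual) -> Dual {
    // product rule, as in forward mode
    Dual { p: a.p * b.p, t: a.p * b.t + a.t * b.p }
}

// The reverse pass of f(x) = x^3, written generically in dual arithmetic
// so the forward-mode pass can differentiate *through* it:
// grad_f(x) = 3 * x * x.
fn grad_f(x: Dual) -> Dual {
    let three = Dual { p: 3.0, t: 0.0 };
    dmul(three, dmul(x, x))
}

/// HVP = tangent of the gradient: push Dual { p: x, t: v } through grad_f.
fn hvp(x: f64, v: f64) -> f64 {
    grad_f(Dual { p: x, t: v }).t
}
```

At x = 2, v = 1 this gives 6·2·1 = 12, matching f''(x)·v; in the real design, `grad_f` would be the tape's backward pass re-executed on dual values.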

Checkpoint integration

let mut tape = Tape::new();
for step in 0..L {
    let matrix = contract(&tensors, &mut ctx)?;
    let (u, s, vt) = matrix.svd(&mut ctx)?;
    let rank = determine_rank(s.data(), rtol, max_rank);
    let u_k = u.slice(0..rank)?;

    // Checkpoint: save this point for memory efficiency
    tape.checkpoint();  // frees intermediate primal data before this point
                        // backward will recompute from checkpoint
}
let grads = tape.backward(&loss, &inputs, &mut ctx)?;

Checkpoint in tape context is simpler than in graph context:

  • Forward: primal data is saved at checkpoint, freed in between
  • Backward: recompute from checkpoint to get intermediate primals
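A back-of-envelope accounting sketch of the memory trade-off (a hypothetical model, not the Tape implementation): with a checkpoint every k steps, the forward pass keeps ~N/k checkpoints alive and the backward pass re-materializes at most one k-step segment at a time, so peak live tensors are ~N/k + k, minimized at k ≈ √N.

```rust
// Hypothetical memory model: peak live primal tensors for N tape steps
// with a checkpoint every `interval` steps.
fn peak_live_tensors(n_steps: usize, interval: usize) -> usize {
    let checkpoints = n_steps.div_ceil(interval); // saved during forward
    checkpoints + interval                        // + one recomputed segment
}
```

For N = 100, an interval of 10 (√N spacing) keeps ~20 tensors live instead of 100, at the cost of one extra forward recomputation per segment.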

Relationship to existing types

             chainrules (PrimitiveOp: linearize, transpose_rule)
                  ↑ AD rules are shared
             StdTensorOp (op definitions are shared)
            /              |              \
    Tensor           EagerTensor      TracedTensor
    (no AD)          (eager + AD)     (graph + AD)
    CPU eager        CPU eager+tape   graph compile
    #684             ← THIS ISSUE     existing

What this does NOT change

Prerequisites

  • Abstract FragmentBuilder into a trait (or Approach A equivalent)
    so that AD rules can execute eagerly
  • Decide where EagerTensor lives: tenferro-tensor or tenferro

Implementation steps

  1. Tape struct + TapeNode + TapeNodeId (data structures)
  2. EagerTensor struct (wraps Arc<Tensor> + optional tape ref)
  3. Op recording: each eager method records to tape when tracked
  4. tape.backward() — reverse walk with chainrules transpose_rule
  5. tape.jvp() — forward walk with chainrules linearize
  6. Checkpoint support in Tape
  7. Tests: gradient correctness via finite-difference for all ops
  8. HVP test: forward-over-reverse

Testing

  • EagerTensor without tape: behaves like Tensor
  • EagerTensor with tape: matmul + reduce_sum → backward gives correct gradient
  • AD through SVD (eager): finite-difference validation
  • AD through truncated SVD: SVD → slice → downstream ops → backward
  • NaryEinsum AD (eager): subscript-based transpose
  • HVP: JVP through backward
  • Checkpoint: √N spacing → O(√N) memory
  • Comparison with TracedTensor AD: same numerical results
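The finite-difference validation used throughout the test list can be sketched with a central-difference helper on scalars (`check_grad` is an illustrative test utility, not the tenferro harness):

```rust
/// Central difference: (f(x+eps) - f(x-eps)) / (2 eps), accurate to O(eps^2).
fn finite_diff(f: impl Fn(f64) -> f64, x: f64, eps: f64) -> f64 {
    (f(x + eps) - f(x - eps)) / (2.0 * eps)
}

/// Compare an analytic gradient against the finite-difference estimate.
fn check_grad(f: impl Fn(f64) -> f64, analytic: impl Fn(f64) -> f64, x: f64) -> bool {
    let fd = finite_diff(&f, x, 1e-5);
    (fd - analytic(x)).abs() < 1e-6
}
```

The tensor version would perturb one entry at a time and compare against the corresponding entry of `tape.backward()`'s output.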

Related
