feat: static vs symbolic graph mode with NaryEinsum op #658

@shinaoka

Description

Summary

Add a graph construction mode that controls whether shapes are resolved at build time (static) or deferred to execution time (symbolic), and correspondingly whether N-ary einsum is decomposed eagerly or kept as a single op. Also improve tensor/traced-tensor construction ergonomics.

Current State

  • Ops carry DimExpr (symbolic dimension expressions) for shape fields, evaluated at execution time from actual input shapes.
  • N-ary einsum is decomposed into binary DotGeneral ops at graph build time via ContractionTree + build_einsum_fragment. The contraction path is optimized at build time using concrete shape_hint values.
  • No NaryEinsum op variant exists in StdTensorOp.
  • Creating a Tensor from typed data requires verbose wrapping: Tensor::F64(TypedTensor::from_vec(shape, data)). No From impls or convenience constructors exist. Every test file copies the same f64_tensor helper.

Proposal

1. Tensor construction ergonomics

Three layers of improvement, all following Rust From/Into idioms:

Layer 1: From<TypedTensor<T>> for Tensor

impl From<TypedTensor<f64>> for Tensor {
    fn from(t: TypedTensor<f64>) -> Self { Tensor::F64(t) }
}
// likewise for f32, Complex64, Complex32

Layer 2: Tensor::from_vec<T> generic constructor

impl Tensor {
    pub fn from_vec<T: IntoTensor>(shape: Vec<usize>, data: Vec<T>) -> Self {
        T::wrap(TypedTensor::from_vec(shape, data))
    }
}

// Usage:
let t = Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
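The `T::wrap` call above presupposes an `IntoTensor` trait that maps each element type to its `Tensor` variant. A minimal sketch of that trait, with `TypedTensor` and `Tensor` reduced to stand-ins (the real types carry more machinery), might look like:

```rust
// Sketch only: `IntoTensor` dispatches an element type to its Tensor variant.
// `TypedTensor`/`Tensor` here are minimal stand-ins, not the actual crate types.

#[derive(Debug, PartialEq)]
pub struct TypedTensor<T> {
    shape: Vec<usize>,
    data: Vec<T>,
}

impl<T> TypedTensor<T> {
    pub fn from_vec(shape: Vec<usize>, data: Vec<T>) -> Self {
        Self { shape, data }
    }
}

#[derive(Debug, PartialEq)]
pub enum Tensor {
    F64(TypedTensor<f64>),
    F32(TypedTensor<f32>),
}

/// Maps an element type to the matching `Tensor` variant.
pub trait IntoTensor: Sized {
    fn wrap(t: TypedTensor<Self>) -> Tensor;
}

impl IntoTensor for f64 {
    fn wrap(t: TypedTensor<f64>) -> Tensor { Tensor::F64(t) }
}

impl IntoTensor for f32 {
    fn wrap(t: TypedTensor<f32>) -> Tensor { Tensor::F32(t) }
}

impl Tensor {
    // The generic constructor from Layer 2, now compilable against the sketch.
    pub fn from_vec<T: IntoTensor>(shape: Vec<usize>, data: Vec<T>) -> Self {
        T::wrap(TypedTensor::from_vec(shape, data))
    }
}

fn main() {
    // Element type alone selects the variant; no explicit Tensor::F64 wrapping.
    let t = Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]);
    assert!(matches!(t, Tensor::F64(_)));
}
```

A sealed trait would keep `IntoTensor` closed to the four supported element types; that detail is omitted here.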

Layer 3: TracedTensor::from_vec<T> shorthand

impl TracedTensor {
    pub fn from_vec<T: IntoTensor>(shape: Vec<usize>, data: Vec<T>) -> Self {
        Self::from_tensor(Tensor::from_vec(shape, data))
    }
}

// Before:
let a = TracedTensor::from_tensor(
    Tensor::F64(TypedTensor::from_vec(vec![4], vec![1.0, 2.0, 3.0, 4.0]))
);
// After:
let a = TracedTensor::from_vec(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]);

2. Size mode: static vs symbolic

Size mode is determined per-tensor at construction time and propagates through the graph. Separate constructor functions make the choice explicit:

// --- From concrete tensor ---
let a = TracedTensor::from_tensor_static(tensor_a);   // shape fixed at build time
let b = TracedTensor::from_tensor_symbolic(tensor_b);  // shape deferred to eval time

// --- Placeholder (no data) ---
let c = TracedTensor::static_input(DType::F64, &[200, 300]);  // concrete shape
let d = TracedTensor::symbolic_input(DType::F64, 2);          // rank only

// --- Shorthand with static size (from_vec defaults to static) ---
let e = TracedTensor::from_vec(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]);

3. Propagation rule

All inputs static → binary decomposition at build time. Any symbolic input → NaryEinsum op.

let mut engine = Engine::new(CpuBackend::new());

// All static → einsum decomposes into binary DotGeneral at build time
let a = TracedTensor::from_tensor_static(tensor_a);
let b = TracedTensor::from_tensor_static(tensor_b);
let c = TracedTensor::from_tensor_static(tensor_c);
let out = einsum(&mut engine, &[&a, &b, &c], "ij,jk,kl->il")?;
// graph: DotGeneral × 2 (path optimized at build time)

// Any symbolic → NaryEinsum as single op
let a = TracedTensor::symbolic_input(DType::F64, 2);
let b = TracedTensor::from_tensor_static(tensor_b);
let c = TracedTensor::from_tensor_static(tensor_c);
let out = einsum(&mut engine, &[&a, &b, &c], "ij,jk,kl->il")?;
// graph: NaryEinsum { subscripts: "ij,jk,kl->il" } (path optimized at eval time)
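The dispatch itself reduces to a single all-static check over the inputs. A hedged sketch, with `TracedTensor` and the strategy enum as illustrative stand-ins rather than the actual crate API:

```rust
// Illustrative sketch of the propagation rule: each traced tensor carries an
// `is_static` flag, and einsum chooses a lowering strategy from its inputs.

struct TracedTensor {
    is_static: bool,
}

#[derive(Debug, PartialEq)]
enum EinsumStrategy {
    DecomposeAtBuildTime, // all inputs static: emit binary DotGeneral ops now
    NaryEinsumOp,         // any input symbolic: keep a single NaryEinsum op
}

fn choose_strategy(inputs: &[&TracedTensor]) -> EinsumStrategy {
    if inputs.iter().all(|t| t.is_static) {
        EinsumStrategy::DecomposeAtBuildTime
    } else {
        EinsumStrategy::NaryEinsumOp
    }
}

fn main() {
    let s = TracedTensor { is_static: true };
    let d = TracedTensor { is_static: false };
    assert_eq!(choose_strategy(&[&s, &s]), EinsumStrategy::DecomposeAtBuildTime);
    assert_eq!(choose_strategy(&[&s, &d, &s]), EinsumStrategy::NaryEinsumOp);
}
```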

Naming convention (Rust best practice)

Separate functions rather than a bool/enum parameter, following the same pattern as Vec::new vs Vec::with_capacity:

| Constructor | Size mode | Required args |
| --- | --- | --- |
| `from_tensor_static(tensor)` | static | concrete `Tensor` |
| `from_tensor_symbolic(tensor)` | symbolic | concrete `Tensor` (shape forgotten) |
| `from_vec(shape, data)` | static | shape + typed data |
| `static_input(dtype, &shape)` | static | dtype + concrete shape |
| `symbolic_input(dtype, rank)` | symbolic | dtype + rank only |

Key Design Points

  1. New op variant: Add NaryEinsum { subscripts: String } (or similar) to StdTensorOp.
  2. Size mode propagation: Each TracedTensor carries a is_static: bool. Einsum checks all inputs to decide decomposition strategy.
  3. Execution-time planning: When executing a NaryEinsum op, resolve concrete shapes, run contraction path optimization (via ContractionTree::optimize), then execute the resulting binary contractions.
  4. AD rules: NaryEinsum needs reverse-mode AD rules. Could either:
    • Decompose into binary ops first, then reuse existing DotGeneral AD rules.
    • Define AD directly on the N-ary einsum (more complex but potentially more efficient).
  5. README update: Update the project README to document the static/symbolic size mode API and usage examples.
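Point 3 (execution-time planning) amounts to: once concrete shapes arrive, choose a contraction path and lower it to a sequence of binary contractions. The sketch below uses a naive left-to-right order only to make the lowering shape concrete; the proposal would plug in ContractionTree::optimize at the marked spot. Names and indexing are illustrative, not the actual crate API.

```rust
// Illustrative sketch of eval-time lowering for a NaryEinsum op.
// Operands are numbered 0..num_inputs; each intermediate result gets the
// next fresh id, so a pair (a, b) means "contract node a with node b".

fn plan_contractions(num_inputs: usize) -> Vec<(usize, usize)> {
    // Placeholder path: left-to-right. A real implementation would pick the
    // order from the now-known concrete shapes (e.g. ContractionTree::optimize).
    let mut pairs = Vec::new();
    let mut acc = 0; // id of the running intermediate, starting at operand 0
    for (next_id, i) in (1..num_inputs).enumerate() {
        pairs.push((acc, i));
        acc = num_inputs + next_id; // the pair's result gets a fresh id
    }
    pairs
}

fn main() {
    // Three operands ("ij,jk,kl->il") lower to two binary contractions:
    // first 0·1, then that intermediate (id 3) with operand 2.
    assert_eq!(plan_contractions(3), vec![(0, 1), (3, 2)]);
}
```

Because the plan is recomputed per execution, two evaluations with different symbolic shapes can legitimately pick different contraction orders, which is the point of deferring the optimization.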
