
DotGeneralConfig lhs_rank/rhs_rank fields become stale, causing runtime RankMismatch panics #664

@shinaoka


Summary

DotGeneralConfig.lhs_rank / rhs_rank are redundant metadata that duplicate the actual operand rank. They become stale during compilation, causing runtime panics on complex N-ary einsum contractions (84+ tensors).

Symptoms

When running einsum benchmarks with many tensors, the following panics occur:

  • RankMismatch(3, 6) — a Permute instruction has perm.len() == 3 but the input tensor has rank 6
  • index out of bounds: the len is 3 but the index is 3 (DotGeneralConfig.lhs_batch_dims contains index 3, but the tensor has rank 3, so the valid indices are 0-2)
  • Segfaults from cascading invalid memory access

Root Cause

  1. Graph construction: DotGeneralConfig stores lhs_rank/rhs_rank set from logical operand shapes at graph-build time
  2. Output rank computation: TracedTensor::dot_general (traced.rs:759) computes output rank from these stored ranks → downstream Permute perm lengths are based on this computed rank
  3. Runtime divergence: The GEMM backend (gemm/mod.rs) uses actual tensor shape.len() to compute free dims and output shape, which may differ from the stored rank
  4. Mismatch: Downstream Permute instructions have perm lengths based on the (wrong) logical output rank, but receive tensors with a different actual rank
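
The divergence in steps 2 and 3 can be sketched in a few lines. This is only an illustration: the struct below is a simplified, hypothetical stand-in for DotGeneralConfig (the contracting-dim field names and both helper functions are assumptions, not the tenferro API), and the shapes are chosen purely to show a stored-rank result of 3 against an actual-rank result of 6.

```rust
/// Simplified, hypothetical stand-in for DotGeneralConfig.
/// Field names other than lhs_rank/rhs_rank/lhs_batch_dims are assumptions.
#[derive(Debug, Clone)]
pub struct DotGeneralConfig {
    pub lhs_contracting_dims: Vec<usize>,
    pub rhs_contracting_dims: Vec<usize>,
    pub lhs_batch_dims: Vec<usize>,
    pub rhs_batch_dims: Vec<usize>,
    /// Redundant metadata: can go stale if the operand is rewritten.
    pub lhs_rank: usize,
    pub rhs_rank: usize,
}

/// Output rank = batch dims + lhs free dims + rhs free dims.
/// Returns None if the config names more dims than the operand has.
fn output_rank(cfg: &DotGeneralConfig, lhs_rank: usize, rhs_rank: usize) -> Option<usize> {
    let batch = cfg.lhs_batch_dims.len();
    let lhs_free = lhs_rank.checked_sub(batch + cfg.lhs_contracting_dims.len())?;
    let rhs_free =
        rhs_rank.checked_sub(cfg.rhs_batch_dims.len() + cfg.rhs_contracting_dims.len())?;
    Some(batch + lhs_free + rhs_free)
}

/// What a graph-build-time computation sees: the stored ranks.
pub fn rank_from_stored(cfg: &DotGeneralConfig) -> Option<usize> {
    output_rank(cfg, cfg.lhs_rank, cfg.rhs_rank)
}

/// What a runtime backend sees: the actual operand shapes.
pub fn rank_from_shapes(
    cfg: &DotGeneralConfig,
    lhs_shape: &[usize],
    rhs_shape: &[usize],
) -> Option<usize> {
    output_rank(cfg, lhs_shape.len(), rhs_shape.len())
}
```

With stored ranks of 3 and 2 and one contracting dim per operand, rank_from_stored yields 3, so a downstream Permute would be built with three entries. If the operands that actually arrive have ranks 5 and 3, rank_from_shapes yields 6, and the Permute no longer matches its input.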

The StableHLO IR has no shape table (StableHloInstruction carries no operand shape/rank metadata — stablehlo.rs:111), so compiler passes cannot detect or repair stale rank fields.

Affected Code

| Location | Role |
| --- | --- |
| tenferro-tensor/src/config.rs:22-24 | Stores lhs_rank/rhs_rank |
| tenferro/src/traced.rs:759 | Computes output rank from stored ranks |
| tenferro-einsum/src/builder.rs:272 | Sets rank from LabeledVal.shape.len() |
| tenferro-tensor/src/cpu/gemm/mod.rs:103 | Uses actual tensor rank at runtime |
| tenferro/src/exec.rs:176 | Applies Permute with stale perm length |

Reproduction

Run the einsum benchmark with instances having 84+ tensors (e.g., lm_batch_likelihood_brackets_4_4d, str_matrix_chain_multiplication_100).

Recommended Fix

  1. Add shape/rank tracking to StableHLO (side table or per-instruction metadata)
  2. Remove lhs_rank/rhs_rank from DotGeneralConfig; derive rank from the shape table at every usage site
  3. Add validation at rewrite boundaries: perm.len() == producer_rank, all dim indices in range, batch/contract sets disjoint
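
The checks in item 3 might look roughly like the following. This is a minimal sketch, not existing tenferro code: the error enum and both function names are hypothetical, and producer_rank / rank are assumed to come from the shape table proposed in item 1.

```rust
/// Hypothetical error type for rewrite-boundary validation.
#[derive(Debug, PartialEq, Eq)]
pub enum ValidationError {
    PermLength { perm_len: usize, producer_rank: usize },
    NotAPermutation,
    DimOutOfRange { dim: usize, rank: usize },
    BatchContractOverlap { dim: usize },
}

/// Checks perm.len() == producer_rank and that perm is a permutation of 0..rank.
pub fn validate_permute(perm: &[usize], producer_rank: usize) -> Result<(), ValidationError> {
    if perm.len() != producer_rank {
        return Err(ValidationError::PermLength {
            perm_len: perm.len(),
            producer_rank,
        });
    }
    let mut seen = vec![false; producer_rank];
    for &p in perm {
        if p >= producer_rank {
            return Err(ValidationError::DimOutOfRange { dim: p, rank: producer_rank });
        }
        if seen[p] {
            return Err(ValidationError::NotAPermutation);
        }
        seen[p] = true;
    }
    Ok(())
}

/// Checks one operand's batch/contracting dims: all in range, and the two sets disjoint.
pub fn validate_operand_dims(
    batch: &[usize],
    contracting: &[usize],
    rank: usize,
) -> Result<(), ValidationError> {
    for &d in batch.iter().chain(contracting) {
        if d >= rank {
            return Err(ValidationError::DimOutOfRange { dim: d, rank });
        }
    }
    for &d in batch {
        if contracting.contains(&d) {
            return Err(ValidationError::BatchContractOverlap { dim: d });
        }
    }
    Ok(())
}
```

Run after each rewrite pass, checks of this kind would turn the two panics listed under Symptoms into early, descriptive errors: a three-entry perm against a rank-6 producer is rejected as PermLength, and a batch dim of 3 on a rank-3 operand as DimOutOfRange.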

Why simple removal is insufficient today

lhs_rank/rhs_rank cannot be simply deleted because:

  • AD rules (ad/contraction.rs) need operand rank but only receive ValRefs, not shapes
  • Compiler passes (compiler.rs) need rank but StableHloInstruction has no shape metadata
  • max(dim_indices) + 1 does not recover the rank (e.g., rank 4 with contracting_dims=[1], batch_dims=[] has free dims [0,2,3], but max(dim_indices) + 1 = 1 + 1 = 2 ≠ 4)
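
The counterexample in the last bullet can be checked directly. The two helpers below are hypothetical and exist only for this illustration; they show that the dim indices alone give a guess of 2, while the free dims can only be listed once the true rank of 4 is known.

```rust
/// Naive rank guess from the dim indices alone: max index + 1 (0 if there are none).
pub fn naive_rank_guess(contracting_dims: &[usize], batch_dims: &[usize]) -> usize {
    contracting_dims
        .iter()
        .chain(batch_dims)
        .copied()
        .max()
        .map_or(0, |m| m + 1)
}

/// Free dims are every axis in 0..rank that is neither contracted nor batched,
/// so they cannot be computed without knowing the true rank.
pub fn free_dims(rank: usize, contracting_dims: &[usize], batch_dims: &[usize]) -> Vec<usize> {
    (0..rank)
        .filter(|d| !contracting_dims.contains(d) && !batch_dims.contains(d))
        .collect()
}
```

Any trailing free axes are invisible to the guess, which is why the rank has to come from shape information rather than from the config's index lists.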

Additional Finding

transpose_folding has a separate correctness issue: it remaps batch/contracting dims through the transpose permutation but does not account for free-dim order changes. This can silently change the DotGeneral output axis order without changing rank. (Not the cause of the RankMismatch panics, but should be fixed.)
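
A small sketch of how the axis order can change under folding. It assumes StableHLO conventions (transposed axis i is original axis perm[i]; DotGeneral emits an operand's free dims in ascending operand order) and models only one operand. The function names are hypothetical, and whether tenferro's transpose_folding behaves exactly this way is an assumption.

```rust
/// Free dims of an operand in ascending order, as they would appear in the output.
fn free_dims(rank: usize, batch: &[usize], contracting: &[usize]) -> Vec<usize> {
    (0..rank)
        .filter(|d| !batch.contains(d) && !contracting.contains(d))
        .collect()
}

/// Output order of the operand's free axes, expressed as ORIGINAL axis numbers,
/// when the transpose is materialized before the DotGeneral.
pub fn free_order_with_transpose(
    perm: &[usize],
    batch: &[usize],
    contracting: &[usize],
) -> Vec<usize> {
    free_dims(perm.len(), batch, contracting)
        .into_iter()
        .map(|d| perm[d]) // transposed axis d is original axis perm[d]
        .collect()
}

/// Output order after folding: batch/contracting dims are remapped through perm
/// and the DotGeneral reads the original operand, so its free dims come out in
/// ascending original order.
pub fn free_order_after_folding(
    perm: &[usize],
    batch: &[usize],
    contracting: &[usize],
) -> Vec<usize> {
    let remap = |dims: &[usize]| -> Vec<usize> { dims.iter().map(|&d| perm[d]).collect() };
    free_dims(perm.len(), &remap(batch), &remap(contracting))
}
```

For perm = [1, 0, 2] with contracting dim 2, the unfolded graph emits the free axes in original order [1, 0], while the folded one emits [0, 1]: same rank, different axis order. A fold that preserves semantics would need either a compensating output permutation or a restriction to transposes that keep the free dims in relative order.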

Related PRs
