Summary
`DotGeneralConfig.lhs_rank` / `rhs_rank` are redundant metadata that duplicate the actual operand rank. They become stale during compilation, causing runtime panics on complex N-ary einsum contractions (84+ tensors).
Symptoms
When running einsum benchmarks with many tensors, the following panics occur:
- `RankMismatch(3, 6)` — a `Permute` instruction has `perm.len() == 3` but the input tensor has rank 6
- `index out of bounds: the len is 3 but the index is 3` — `DotGeneralConfig.lhs_batch_dims` contains index 3 but the tensor has rank 3 (valid indices 0-2)
- Segfaults from cascading invalid memory accesses
Root Cause
- Graph construction: `DotGeneralConfig` stores `lhs_rank`/`rhs_rank`, set from logical operand shapes at graph-build time
- Output rank computation: `TracedTensor::dot_general` (traced.rs:759) computes the output rank from these stored ranks → downstream `Permute` perm lengths are based on this computed rank
- Runtime divergence: the GEMM backend (gemm/mod.rs) uses the actual tensor `shape.len()` to compute free dims and the output shape, which may differ from the stored rank
- Mismatch: downstream `Permute` instructions have perm lengths based on the (wrong) logical output rank, but receive tensors with a different actual rank
The StableHLO IR has no shape table (`StableHloInstruction` carries no operand shape/rank metadata — stablehlo.rs:111), so compiler passes cannot detect or repair stale rank fields.
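The divergence can be modeled in a few lines. This is a minimal sketch with hypothetical field and function names that mirror `DotGeneralConfig` (config.rs) and the rank computation in traced.rs, not the real implementation:

```rust
// Hypothetical mirror of DotGeneralConfig: stored ranks plus dim index sets.
struct DotGeneralConfig {
    lhs_rank: usize, // stored at graph-build time; can go stale
    lhs_batch_dims: Vec<usize>,
    lhs_contracting_dims: Vec<usize>,
    rhs_rank: usize,
    rhs_batch_dims: Vec<usize>,
    rhs_contracting_dims: Vec<usize>,
}

// Output rank derived from the *stored* ranks (traced.rs-style):
// batch dims + lhs free dims + rhs free dims.
fn logical_output_rank(c: &DotGeneralConfig) -> usize {
    let lhs_free = c.lhs_rank - c.lhs_batch_dims.len() - c.lhs_contracting_dims.len();
    let rhs_free = c.rhs_rank - c.rhs_batch_dims.len() - c.rhs_contracting_dims.len();
    c.lhs_batch_dims.len() + lhs_free + rhs_free
}

// Output rank as a shape-driven backend would compute it, from actual shapes.
fn runtime_output_rank(c: &DotGeneralConfig, lhs: &[usize], rhs: &[usize]) -> usize {
    let lhs_free = lhs.len() - c.lhs_batch_dims.len() - c.lhs_contracting_dims.len();
    let rhs_free = rhs.len() - c.rhs_batch_dims.len() - c.rhs_contracting_dims.len();
    c.lhs_batch_dims.len() + lhs_free + rhs_free
}

fn main() {
    // Suppose a rewrite gives the lhs operand rank 6, but the config still says 3.
    let cfg = DotGeneralConfig {
        lhs_rank: 3, lhs_batch_dims: vec![], lhs_contracting_dims: vec![2],
        rhs_rank: 2, rhs_batch_dims: vec![], rhs_contracting_dims: vec![0],
    };
    let lhs_shape = [2, 3, 4, 5, 6, 7]; // actual rank 6 after rewriting
    let rhs_shape = [7, 8];

    let logical = logical_output_rank(&cfg); // 3, from stale stored ranks
    let runtime = runtime_output_rank(&cfg, &lhs_shape, &rhs_shape); // 6, from shapes
    // A downstream Permute built with perm.len() == 3 then receives a
    // rank-6 tensor: exactly the observed RankMismatch(3, 6).
    println!("logical={logical} runtime={runtime}");
    assert_eq!((logical, runtime), (3, 6));
}
```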
Affected Code
| Location | Role |
| --- | --- |
| `tenferro-tensor/src/config.rs:22-24` | Stores `lhs_rank`/`rhs_rank` |
| `tenferro/src/traced.rs:759` | Computes output rank from stored ranks |
| `tenferro-einsum/src/builder.rs:272` | Sets rank from `LabeledVal.shape.len()` |
| `tenferro-tensor/src/cpu/gemm/mod.rs:103` | Uses actual tensor rank at runtime |
| `tenferro/src/exec.rs:176` | Applies `Permute` with stale perm length |
Reproduction
Run the einsum benchmark on instances with 84+ tensors (e.g., `lm_batch_likelihood_brackets_4_4d`, `str_matrix_chain_multiplication_100`).
Recommended Fix
- Add shape/rank tracking to StableHLO (a side table or per-instruction metadata)
- Remove `lhs_rank`/`rhs_rank` from `DotGeneralConfig`; derive rank from the shape table at every usage site
- Add validation at rewrite boundaries: `perm.len() == producer_rank`, all dim indices in range, batch/contracting sets disjoint
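The boundary validation in the last bullet can be sketched as plain functions. Error and function names here are hypothetical; the real passes would run these checks after every rewrite:

```rust
// Hypothetical structured errors for the three invariants listed above.
#[derive(Debug, PartialEq)]
enum ValidationError {
    PermLenMismatch { perm_len: usize, producer_rank: usize },
    DimOutOfRange { dim: usize, rank: usize },
    BatchContractOverlap { dim: usize },
}

// Invariant 1: a permutation must have exactly producer_rank entries.
fn validate_permute(perm: &[usize], producer_rank: usize) -> Result<(), ValidationError> {
    if perm.len() != producer_rank {
        return Err(ValidationError::PermLenMismatch {
            perm_len: perm.len(),
            producer_rank,
        });
    }
    Ok(())
}

// Invariants 2 and 3: dim indices in range, batch/contracting sets disjoint.
fn validate_dot_dims(
    batch: &[usize],
    contracting: &[usize],
    rank: usize,
) -> Result<(), ValidationError> {
    for &d in batch.iter().chain(contracting) {
        if d >= rank {
            return Err(ValidationError::DimOutOfRange { dim: d, rank });
        }
    }
    if let Some(&d) = batch.iter().find(|d| contracting.contains(*d)) {
        return Err(ValidationError::BatchContractOverlap { dim: d });
    }
    Ok(())
}

fn main() {
    // The observed panics become structured errors at the rewrite boundary:
    assert_eq!(
        validate_permute(&[0, 1, 2], 6),
        Err(ValidationError::PermLenMismatch { perm_len: 3, producer_rank: 6 })
    );
    assert_eq!(
        validate_dot_dims(&[3], &[], 3),
        Err(ValidationError::DimOutOfRange { dim: 3, rank: 3 })
    );
    assert!(validate_dot_dims(&[0], &[1], 4).is_ok());
}
```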
Why simple removal is insufficient today
`lhs_rank`/`rhs_rank` cannot simply be deleted because:
- AD rules (ad/contraction.rs) need operand rank but only receive `ValRef`s, not shapes
- Compiler passes (compiler.rs) need rank, but `StableHloInstruction` has no shape metadata
- `max(dim_indices) + 1` does not recover the rank (e.g., rank 4 with `contracting_dims=[1]` and `batch_dims=[]` has free dims `[0, 2, 3]`, but `max(1) + 1 = 2 ≠ 4`)
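The counterexample in the last bullet, executed in plain Rust (no project types assumed): free dims never appear in the config, so any rank reconstructed from the recorded dim indices alone underestimates the true rank.

```rust
fn main() {
    // Rank 4 operand; only the contracting dim is recorded in the config.
    let rank = 4usize;
    let contracting_dims = [1usize];
    let batch_dims: [usize; 0] = [];

    // Attempted reconstruction: max over all recorded dim indices, plus one.
    let guessed = contracting_dims
        .iter()
        .chain(batch_dims.iter())
        .max()
        .map_or(0, |&m| m + 1);

    // The free dims [0, 2, 3] are implicit -- nothing in the config names them.
    let free_dims: Vec<usize> = (0..rank)
        .filter(|d| !contracting_dims.contains(d) && !batch_dims.contains(d))
        .collect();

    assert_eq!(free_dims, vec![0, 2, 3]);
    assert_eq!(guessed, 2); // max(1) + 1 = 2, not the true rank 4
}
```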
Additional Finding
`transpose_folding` has a separate correctness issue: it remaps batch/contracting dims through the transpose permutation but does not account for free-dim order changes. This can silently change the `DotGeneral` output axis order without changing rank. (Not the cause of the `RankMismatch` panics, but it should be fixed.)
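A small illustration of the hazard, with a hypothetical `free_dims` helper and assuming the permutation maps output axes to input axes: after folding a transpose that swaps two free axes, the remapped config is byte-for-byte identical, so nothing records that the operand's free axes now arrive in a different order.

```rust
// Free dims are whatever is neither a batch nor a contracting dim.
fn free_dims(rank: usize, batch: &[usize], contracting: &[usize]) -> Vec<usize> {
    (0..rank)
        .filter(|d| !batch.contains(d) && !contracting.contains(d))
        .collect()
}

fn main() {
    // Operand axes (a, b, k), contracting on k (index 2).
    // Free dims before folding: [0, 1] -> output free axes carry (a, b).
    let before = free_dims(3, &[], &[2]);

    // Fold in a transpose with perm = [1, 0, 2]: swaps a and b, keeps k.
    // Remapping the contracting dim through the perm still yields index 2.
    let perm = [1usize, 0, 2];
    let folded_contracting: Vec<usize> = [2usize].iter().map(|&d| perm[d]).collect();
    let after = free_dims(3, &[], &folded_contracting);

    // Same index sets, same rank -- yet free positions 0 and 1 now carry
    // (b, a), silently permuting the DotGeneral output's free axes.
    assert_eq!(before, after); // both [0, 1]: the config cannot see the swap
    println!("before={before:?} after={after:?}");
}
```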
Related PRs
- `canonical_gemm_layout`
- `dot_decomposer`