Skip to content

Commit 3a8e684

Browse files
committed
Add intdiv_c and mod_c
* Note that in python, x % 0 raises ZeroDivisionError. The implementation of mod_c matches this behavior when t2 is the zero vector.
1 parent 96c54e4 commit 3a8e684

File tree

2 files changed

+95
-1
lines changed

2 files changed

+95
-1
lines changed

src/torchjd/sparse/_linalg.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,50 @@ def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None:
2222

2323
# TODO: Verify that the round operation cannot fail
2424
return X_rounded.to(torch.int64)
25+
26+
27+
def mod_c(t1: Tensor, t2: Tensor) -> Tensor:
28+
"""
29+
Computes the combined modulo r = t1 %c t2, such that
30+
t1 = d * t2 + r with d = t1 //c t2 and
31+
0 <= r[i] <= t1[i] for all i.
32+
33+
:param t1: Non-negative integer vector.
34+
:param t2: Non-negative integer vector.
35+
36+
Examples:
37+
[8, 12]^T %c [2, 3]^T = [0, 0]^T
38+
[8, 12]^T %c [2, 4]^T = [2, 0]^T
39+
[8, 12]^T %c [3, 3]^T = [2, 6]^T
40+
[8, 12]^T %c [2, 0]^T = [0, 12]^T
41+
[8, 12]^T %c [0, 2]^T = [8, 0]^T
42+
[8, 12]^T %c [0, 0]^T => ZeroDivisionError
43+
"""
44+
45+
return t1 - intdiv_c(t1, t2) * t2
46+
47+
48+
def intdiv_c(t1: Tensor, t2: Tensor) -> Tensor:
49+
"""
50+
Computes the combined integer division d = t1 // t2, such that
51+
t1 = d * t2 + r with r = t1 %c t2
52+
0 <= r[i] <= t1[i] for all i.
53+
54+
:param t1: Non-negative integer vector.
55+
:param t2: Non-negative integer vector.
56+
57+
Examples:
58+
[8, 12]^T //c [2, 3]^T = 4
59+
[8, 12]^T //c [2, 4]^T = 3
60+
[8, 12]^T //c [3, 3]^T = 2
61+
[8, 12]^T //c [2, 0]^T = 4
62+
[8, 12]^T //c [0, 2]^T = 6
63+
[8, 12]^T //c [0, 0]^T => ZeroDivisionError
64+
"""
65+
66+
non_zero_indices = torch.nonzero(t2)
67+
if len(non_zero_indices) == 0:
68+
raise ZeroDivisionError("division by zero")
69+
else:
70+
min_divider = (t1[non_zero_indices] // t2[non_zero_indices]).min()
71+
return min_divider

tests/unit/sparse/test_structured_sparse_tensor.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from pytest import mark
2+
from pytest import mark, raises
33
from torch import Tensor, tensor
44
from torch.ops import aten # type: ignore
55
from torch.testing import assert_close
@@ -12,6 +12,7 @@
1212
)
1313
from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim
1414
from torchjd.sparse._coalesce import fix_zero_stride_columns
15+
from torchjd.sparse._linalg import intdiv_c, mod_c
1516
from torchjd.sparse._structured_sparse_tensor import (
1617
StructuredSparseTensor,
1718
fix_ungrouped_dims,
@@ -420,3 +421,49 @@ def test_fix_zero_stride_columns(
420421
physical, strides = fix_zero_stride_columns(physical, strides)
421422
assert torch.equal(physical, expected_physical)
422423
assert torch.equal(strides, expected_strides)
424+
425+
426+
@mark.parametrize(
427+
["t1", "t2", "expected"],
428+
[
429+
(tensor([8, 12]), tensor([2, 3]), tensor([0, 0])),
430+
(tensor([8, 12]), tensor([2, 4]), tensor([2, 0])),
431+
(tensor([8, 12]), tensor([3, 3]), tensor([2, 6])),
432+
(tensor([8, 12]), tensor([2, 0]), tensor([0, 12])),
433+
(tensor([8, 12]), tensor([0, 2]), tensor([8, 0])),
434+
],
435+
)
436+
def test_mod_c(
437+
t1: Tensor,
438+
t2: Tensor,
439+
expected: Tensor,
440+
):
441+
assert torch.equal(mod_c(t1, t2), expected)
442+
443+
444+
def test_mod_c_by_0_raises():
445+
with raises(ZeroDivisionError):
446+
mod_c(tensor([3, 4]), tensor([0, 0]))
447+
448+
449+
@mark.parametrize(
450+
["t1", "t2", "expected"],
451+
[
452+
(tensor([8, 12]), tensor([2, 3]), 4),
453+
(tensor([8, 12]), tensor([2, 4]), 3),
454+
(tensor([8, 12]), tensor([3, 3]), 2),
455+
(tensor([8, 12]), tensor([2, 0]), 4),
456+
(tensor([8, 12]), tensor([0, 2]), 6),
457+
],
458+
)
459+
def test_intdiv_c(
460+
t1: Tensor,
461+
t2: Tensor,
462+
expected: Tensor,
463+
):
464+
assert intdiv_c(t1, t2) == expected
465+
466+
467+
def test_intdiv_c_by_0_raises():
468+
with raises(ZeroDivisionError):
469+
intdiv_c(tensor([3, 4]), tensor([0, 0]))

0 commit comments

Comments
 (0)