Skip to content

Commit c31aaa6

Browse files
committed
api: add support for no interp (interp_order=0)
1 parent 2413d19 commit c31aaa6

3 files changed

Lines changed: 22 additions & 3 deletions

File tree

devito/types/basic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,11 @@ def _evaluate(self, **kwargs):
10741074
io = self.interp_order
10751075
retval = self.subs({i.subs(subs): self.indices_ref[d]
10761076
for d, i in mapper.items()})
1077+
1078+
if io == 0:
1079+
# No interpolation, just substitution (e.g nearest grid point)
1080+
return retval
1081+
10771082
if self.is_harmonic:
10781083
retval = retval._inv(retval, safe=self.is_harmonic_safe)
10791084

devito/types/dense.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,8 +1095,6 @@ def __init_finalize__(self, *args, **kwargs):
10951095
interp_order = kwargs.get('interp_order', 2)
10961096
if not is_integer(interp_order):
10971097
raise TypeError("`interp_order` must be an integer")
1098-
elif interp_order < 1:
1099-
raise ValueError("`interp_order` must be >= 2")
11001098
elif interp_order > self._space_order and self._space_order > 1:
11011099
raise ValueError("`interp_order` must be <= `space_order`")
11021100
self._interp_order = interp_order
@@ -1121,7 +1119,7 @@ def _fd_priority(self):
11211119
return 1 if self.staggered.on_node else 2
11221120

11231121
def _eval_at(self, func):
1124-
if self.staggered == func.staggered:
1122+
if self.staggered == func.staggered or self.interp_order == 0:
11251123
return self
11261124

11271125
mapper = {}

tests/test_differentiable.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,19 @@ def test_avg_mode(ndim, io):
152152
assert sympy.simplify(b_avg.args[0] - expected) == 0
153153
assert isinstance(b_avg, SafeInv)
154154
assert b_avg.base == b
155+
156+
157+
def test_no_interp():
158+
grid = Grid((10, 10))
159+
x = grid.dimensions[0]
160+
a = Function(name="a", grid=grid, staggered=NODE, interp_order=0)
161+
sa = Function(name="as", grid=grid, staggered=x)
162+
163+
assert a._eval_at(sa) == a
164+
assert sa._eval_at(a) == sa._subs(x, x - x.spacing/2)
165+
assert (a*sa)._eval_at(sa) == a*sa
166+
assert (a + sa)._eval_at(sa) == a + sa
167+
168+
a_shift = a._subs(x, x + x.spacing / 2)
169+
# Should just do nearest grid point, so shift back to original
170+
assert a_shift.evaluate == a

0 commit comments

Comments
 (0)