Skip to content

Commit a1c3565

Browse files
committed
api: add support for no interp (interp_order=0)
1 parent 2413d19 commit a1c3565

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

devito/types/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,11 +1067,11 @@ def _evaluate(self, **kwargs):
10671067
"""
10681068
mapper = self._grid_map
10691069
subs = mapper.pop('subs', {})
1070+
io = self.interp_order
10701071
# Average values if at a location not on the Function's grid
1071-
if not mapper:
1072+
if not mapper or io == 0:
10721073
return self
10731074

1074-
io = self.interp_order
10751075
retval = self.subs({i.subs(subs): self.indices_ref[d]
10761076
for d, i in mapper.items()})
10771077
if self.is_harmonic:

devito/types/dense.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,8 +1095,6 @@ def __init_finalize__(self, *args, **kwargs):
10951095
interp_order = kwargs.get('interp_order', 2)
10961096
if not is_integer(interp_order):
10971097
raise TypeError("`interp_order` must be an integer")
1098-
elif interp_order < 1:
1099-
raise ValueError("`interp_order` must be >= 2")
11001098
elif interp_order > self._space_order and self._space_order > 1:
11011099
raise ValueError("`interp_order` must be <= `space_order`")
11021100
self._interp_order = interp_order
@@ -1121,7 +1119,7 @@ def _fd_priority(self):
11211119
return 1 if self.staggered.on_node else 2
11221120

11231121
def _eval_at(self, func):
1124-
if self.staggered == func.staggered:
1122+
if self.staggered == func.staggered or self.interp_order == 0:
11251123
return self
11261124

11271125
mapper = {}

tests/test_differentiable.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,15 @@ def test_avg_mode(ndim, io):
152152
assert sympy.simplify(b_avg.args[0] - expected) == 0
153153
assert isinstance(b_avg, SafeInv)
154154
assert b_avg.base == b
155+
156+
157+
def test_no_interp():
158+
grid = Grid((10, 10))
159+
x = grid.dimensions[0]
160+
a = Function(name="a", grid=grid, staggered=NODE, interp_order=0)
161+
sa = Function(name="as", grid=grid, staggered=x)
162+
163+
assert a._eval_at(sa) == a
164+
assert sa._eval_at(a) == sa._subs(x, x - x.spacing/2)
165+
assert (a*sa)._eval_at(sa) == a*sa
166+
assert (a + sa)._eval_at(sa) == a + sa

0 commit comments

Comments
 (0)