Skip to content

Commit 8d1995d

Browse files
authored
Merge pull request #2892 from devitocodes/halo-comm-order
compiler: fix halo placement for non out dim exchange
2 parents d9dd186 + 076af37 commit 8d1995d

7 files changed

Lines changed: 44 additions & 18 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -442,8 +442,8 @@ def callback(self, clusters, prefix, seen=None):
442442
d = prefix[-1].dim
443443

444444
# Construct a representation of the halo accesses
445-
processed = []
446-
for c in clusters:
445+
processed = list(clusters)
446+
for n, c in enumerate(clusters):
447447
if c.properties.is_sequential(d) or \
448448
c in seen:
449449
continue
@@ -480,10 +480,16 @@ def callback(self, clusters, prefix, seen=None):
480480

481481
halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties)
482482

483-
processed.append(halo_touch)
484-
seen.update({halo_touch, c})
483+
# Insert `halo_touch` at the top of the IterationSpace within which
484+
# `c` is scheduled
485+
index = 0
486+
for i in reversed(range(n)):
487+
if not processed[i].ispace.is_subset(c.ispace):
488+
index = i + 1
489+
break
490+
processed.insert(index, halo_touch)
485491

486-
processed.extend(clusters)
492+
seen.update({halo_touch, c})
487493

488494
return processed
489495

devito/mpi/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def nprocs_local(self):
261261

262262
@property
263263
def topology(self):
264-
return self._topology
264+
return DimensionTuple(*self._topology, getters=self.dimensions)
265265

266266
@property
267267
def topology_logical(self):

devito/mpi/halo_scheme.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,12 +617,11 @@ def classify(exprs, ispace):
617617
f"scheme for `{f}` along Dimension `{d}`")
618618
elif hl.pop() is STENCIL:
619619
halos.append(Halo(d, s))
620-
else:
620+
elif d._defines & set(ispace.itdims):
621621
raw_loc_indices[d].append(s)
622622

623623
loc_indices, loc_dirs = process_loc_indices(raw_loc_indices,
624624
ispace.directions)
625-
626625
mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)
627626

628627
return mapper

devito/symbolics/search.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def __init__(self, query: Callable[[Expression], bool], deep: bool = False) -> N
6060
def _next(self, expr: Expression) -> Iterable[Expression]:
6161
if self.deep and expr.is_Indexed:
6262
return expr.indices
63+
elif self.deep and q_dimension(expr):
64+
return expr.bound_symbols
6365
elif q_leaf(expr):
6466
return ()
6567
return expr.args

tests/test_dle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,7 @@ def test_dynamic_nthreads(self):
755755
('[Eq(f, 2*f)]', [2, 0, 0], False),
756756
('[Eq(u, 2*u)]', [0, 2, 0, 0], False),
757757
('[Eq(u, 2*u + f)]', [0, 3, 0, 0, 0, 0, 0], True),
758-
('[Eq(u, 2*u), Eq(f, u.dzr)]', [0, 2, 0, 0, 0], False)
758+
('[Eq(u, 2*u), Eq(f, u.dzr)]', [0, 2, 0, 0, 2, 0, 0], False)
759759
])
760760
def test_collapsing(self, eqns, expected, blocking):
761761
grid = Grid(shape=(3, 3, 3))

tests/test_mpi.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44
import pytest
5-
from test_dse import TestTTI
65

76
from conftest import _R, assert_blocking, assert_structure, body0
87
from devito import (
@@ -2203,6 +2202,25 @@ def test_lift_halo_update_outside_distributed(self, mode):
22032202
halo_update = tloop.nodes[0].body[0].body[0].body[0]
22042203
assert isinstance(halo_update, HaloUpdateList)
22052204

2205+
@pytest.mark.parallel(mode=4)
2206+
def test_halo_inner_dim(self, mode):
2207+
grid = Grid((11, 11, 11))
2208+
2209+
np.random.seed(0)
2210+
v = TimeFunction(name="v", grid=grid, space_order=4,
2211+
time_order=1, save=Buffer(1))
2212+
v.data[:] = np.random.randn(*grid.shape)
2213+
e = TimeFunction(name="dummy", grid=grid, space_order=4, time_order=0)
2214+
2215+
eq = [Eq(v.forward, v + 1), Eq(e, v.forward.dydz)]
2216+
2217+
op = Operator(eq, opt=('advanced', {'blocklevels': 0}))
2218+
2219+
assert_structure(op, ['txyz', 't', 'txyz', 'txyz'], 'txyzxyzz')
2220+
op(time=100)
2221+
2222+
assert np.isclose(norm(e), 23484.863, rtol=0, atol=1e-1)
2223+
22062224

22072225
class TestOperatorAdvanced:
22082226

@@ -2734,9 +2752,10 @@ def test_haloupdate_same_timestep_v2(self, mode):
27342752

27352753
titer = op.body.body[-1].body[0]
27362754
assert titer.dim is grid.time_dim
2737-
assert titer.nodes[0].body[0].body[0].is_List
2738-
assert len(titer.nodes[0].body[0].body[0].body[0].body) == 1
2739-
assert titer.nodes[0].body[0].body[0].body[0].body[0].is_Call
2755+
block = titer.nodes[0].body[0].body[1]
2756+
assert block.is_List
2757+
assert len(block.body) == 3
2758+
assert block.body[0].body[0].is_Call
27402759

27412760
op.apply(time=0)
27422761

@@ -3139,7 +3158,7 @@ def test_fission_due_to_antidep(self, mode):
31393158
assert_structure(op1, ['t',
31403159
't,x0_blk0,y0_blk0,x,y,z',
31413160
't,x0_blk0,y0_blk0,x,y,z'],
3142-
't,x0_blk0,y0_blk0,x,y,z,z')
3161+
'tx0_blk0y0_blk0xyzz')
31433162

31443163
def init(f, v=1):
31453164
f.data[:] = np.indices(grid.shape).sum(axis=0) % (.004*v) + .01
@@ -3513,9 +3532,9 @@ def test_issue_2448_backward(self, mode):
35133532

35143533
class TestTTIOp:
35153534

3516-
@pytest.mark.skipif(TestTTI is None, reason="Requires installing the tests")
35173535
@pytest.mark.parallel(mode=1)
35183536
def test_halo_structure(self, mode):
3537+
from test_dse import TestTTI
35193538
solver = TestTTI().tti_operator(opt='advanced', space_order=8)
35203539
op = solver.op_fwd(save=False)
35213540

tests/test_operator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,12 +1582,12 @@ def test_no_fission_as_illegal(self, exprs):
15821582
(('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])',
15831583
'Eq(tv[t,x,y,z], tu[t,x,y,z+2])',
15841584
'Eq(tu[t,x,y,0], tu[t,x,y,0] + 1.)'),
1585-
'+++++', ['txyz', 'txyz', 'txy'], 'txyzz'),
1585+
'+++++++', ['txyz', 'txyz', 'txy'], 'txyzxyz'),
15861586
# 7) WAR 1->2, 2->3
15871587
(('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])',
15881588
'Eq(tv[t,x,y,z], tu[t,x,y,z+2])',
15891589
'Eq(tw[t,x,y,z], tv[t,x,y,z-1] + 1.)'),
1590-
'++++++', ['txyz', 'txyz', 'txyz'], 'txyzzz'),
1590+
'++++++++', ['txyz', 'txyz', 'txyz'], 'txyzxyzz'),
15911591
# 8) WAR 1->2; WAW 1->3
15921592
(('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])',
15931593
'Eq(tv[t,x,y,z], tu[t,x+2,y,z])',
@@ -1597,7 +1597,7 @@ def test_no_fission_as_illegal(self, exprs):
15971597
(('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])',
15981598
'Eq(tv[t,x,y,z], tu[t,x,y,z-2])',
15991599
'Eq(tw[t,x,y,z], tv[t,x,y+1,z] + 1.)'),
1600-
'+++++++', ['txyz', 'txyz', 'txyz'], 'txyzzyz'),
1600+
'+++++++++', ['txyz', 'txyz', 'txyz'], 'txyzxyzyz'),
16011601
# 10) WAR 1->2; WAW 1->3
16021602
(('Eq(tu[t-1,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])',
16031603
'Eq(tv[t,x,y,z], tu[t,x,y,z+2])',

0 commit comments

Comments
 (0)