Skip to content

Commit 076af37

Browse files
committed
compiler: fix terminal detection from dynamic dim bounds
1 parent 013bb51 commit 076af37

5 files changed

Lines changed: 35 additions & 32 deletions

File tree

devito/mpi/halo_scheme.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from operator import attrgetter
66

77
import sympy
8-
from sympy import Max, Min, S
8+
from sympy import Max, Min
99

1010
from devito import configuration
1111
from devito.data import CENTER, CORE, LEFT, OWNED, RIGHT
@@ -514,31 +514,6 @@ def merge(self, hs):
514514
fmapper[f] = fmapper.get(f, hse).merge(hse)
515515
return HaloScheme.build(fmapper, self.honored)
516516

517-
def _is_iter_carried(self, scope):
518-
"""
519-
True if the HaloScheme is iteration-carried, i.e., it induces
520-
a halo exchange that requires values from the previous iteration(s); False
521-
otherwise.
522-
"""
523-
524-
def rule0(dep):
525-
# E.g., `dep=W<f,[t1, x]> -> R<f,[t0, x-1]>`, `d=t` => OK
526-
return not any(dep.distance_mapper[d] is S.Infinity for d in dep.cause)
527-
528-
def rule1(dep, loc_indices):
529-
# E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>`, `loc_indices={t: t0}` => OK
530-
return any(dep.distance_mapper[d] == 0 and
531-
dep.source[d] is not v and
532-
dep.sink[d] is not v
533-
for d, v in loc_indices.items())
534-
535-
for f, v in self.fmapper.items():
536-
for dep in scope.d_flow.project(f):
537-
if not rule0(dep) and not rule1(dep, v.loc_indices):
538-
return False
539-
540-
return True
541-
542517

543518
def classify(exprs, ispace):
544519
"""

devito/passes/iet/mpi.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def _hoist_invariant(iet):
206206
# Ensure there's another HaloScheme that could cover for
207207
# us should we get hoisted while still satisfying the
208208
# data dependences
209-
if hsf1.issubset(hsf0) and hsf1._is_iter_carried(scope):
209+
if hsf1.issubset(hsf0) and _is_iter_carried(hsf1, scope):
210210
hs, hsf = hs1, hsf1
211211
elif hsf0.issubset(hsf1) and hs0 is halo_spots[0]:
212212
# Special case
@@ -474,6 +474,32 @@ def _check_control_flow(hs0, hs1, cond_mapper):
474474
return cond0 != cond1
475475

476476

477+
def _is_iter_carried(hsf, scope):
478+
"""
479+
True if the provided HaloScheme `hsf` is iteration-carried, i.e., it induces
480+
a halo exchange that requires values from the previous iteration(s); False
481+
otherwise.
482+
"""
483+
484+
def rule0(dep):
485+
# E.g., `dep=W<f,[t1, x]> -> R<f,[t0, x-1]>`, `d=t` => OK
486+
return not any(dep.distance_mapper[d] is S.Infinity for d in dep.cause)
487+
488+
def rule1(dep, loc_indices):
489+
# E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>`, `loc_indices={t: t0}` => OK
490+
return any(dep.distance_mapper[d] == 0 and
491+
dep.source[d] is not v and
492+
dep.sink[d] is not v
493+
for d, v in loc_indices.items())
494+
495+
for f, v in hsf.fmapper.items():
496+
for dep in scope.d_flow.project(f):
497+
if not rule0(dep) and not rule1(dep, v.loc_indices):
498+
return False
499+
500+
return True
501+
502+
477503
def _is_mergeable(hsf0, hsf1, scope):
478504
"""
479505
True if `hsf1` can be merged into `hsf0`, i.e., if they are compatible
@@ -489,7 +515,7 @@ def _is_mergeable(hsf0, hsf1, scope):
489515
return False
490516

491517
# Finally, check the data dependences would be satisfied
492-
return hsf1._is_iter_carried(scope)
518+
return _is_iter_carried(hsf1, scope)
493519

494520

495521
def _semantical_eq_loc_indices(hsf0, hsf1):

devito/symbolics/search.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def __init__(self, query: Callable[[Expression], bool], deep: bool = False) -> N
6060
def _next(self, expr: Expression) -> Iterable[Expression]:
6161
if self.deep and expr.is_Indexed:
6262
return expr.indices
63+
elif self.deep and q_dimension(expr):
64+
return expr.bound_symbols
6365
elif q_leaf(expr):
6466
return ()
6567
return expr.args

tests/test_dle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,7 @@ def test_dynamic_nthreads(self):
755755
('[Eq(f, 2*f)]', [2, 0, 0], False),
756756
('[Eq(u, 2*u)]', [0, 2, 0, 0], False),
757757
('[Eq(u, 2*u + f)]', [0, 3, 0, 0, 0, 0, 0], True),
758-
('[Eq(u, 2*u), Eq(f, u.dzr)]', [0, 2, 0, 0, 0], False)
758+
('[Eq(u, 2*u), Eq(f, u.dzr)]', [0, 2, 0, 0, 2, 0, 0], False)
759759
])
760760
def test_collapsing(self, eqns, expected, blocking):
761761
grid = Grid(shape=(3, 3, 3))

tests/test_operator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,12 +1582,12 @@ def test_no_fission_as_illegal(self, exprs):
15821582
(('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])',
15831583
'Eq(tv[t,x,y,z], tu[t,x,y,z+2])',
15841584
'Eq(tu[t,x,y,0], tu[t,x,y,0] + 1.)'),
1585-
'+++++', ['txyz', 'txyz', 'txy'], 'txyzz'),
1585+
'+++++++', ['txyz', 'txyz', 'txy'], 'txyzxyz'),
15861586
# 7) WAR 1->2, 2->3
15871587
(('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])',
15881588
'Eq(tv[t,x,y,z], tu[t,x,y,z+2])',
15891589
'Eq(tw[t,x,y,z], tv[t,x,y,z-1] + 1.)'),
1590-
'++++++', ['txyz', 'txyz', 'txyz'], 'txyzzz'),
1590+
'++++++++', ['txyz', 'txyz', 'txyz'], 'txyzxyzz'),
15911591
# 8) WAR 1->2; WAW 1->3
15921592
(('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])',
15931593
'Eq(tv[t,x,y,z], tu[t,x+2,y,z])',
@@ -1597,7 +1597,7 @@ def test_no_fission_as_illegal(self, exprs):
15971597
(('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])',
15981598
'Eq(tv[t,x,y,z], tu[t,x,y,z-2])',
15991599
'Eq(tw[t,x,y,z], tv[t,x,y+1,z] + 1.)'),
1600-
'+++++++', ['txyz', 'txyz', 'txyz'], 'txyzzyz'),
1600+
'+++++++++', ['txyz', 'txyz', 'txyz'], 'txyzxyzyz'),
16011601
# 10) WAR 1->2; WAW 1->3
16021602
(('Eq(tu[t-1,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])',
16031603
'Eq(tv[t,x,y,z], tu[t,x,y,z+2])',

0 commit comments

Comments
 (0)