Skip to content

Commit bab9aae

Browse files
authored
Merge pull request #2891 from devitocodes/hotfix-efunc-abstract
compiler: Fix efunc abstraction for MinMax pragma-based kernels
2 parents cf14f16 + 3228ffb commit bab9aae

2 files changed

Lines changed: 20 additions & 4 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
from devito.tools import (
2121
DefaultOrderedDict, Stamp, as_mapper, flatten, is_integer, split, timed_pass, toposort
2222
)
23-
from devito.types import Array, Eq, Symbol
23+
from devito.types import Array, Eq, Symbol, Temp
2424
from devito.types.dimension import BOTTOM, ModuloDimension
2525

2626
__all__ = ['clusterize']
@@ -639,8 +639,8 @@ def normalize_reductions_minmax(cluster, sregistry):
639639
# NOTE: we need to create two different reduction variables here
640640
# (instead of using say `n[0]` and `n[1]` directly) because that's
641641
# essentially what OpenMP/OpenACC expect -- two different symbols
642-
rmin = Symbol(name=sregistry.make_name(prefix='rmin'), dtype=lhs.dtype)
643-
rmax = Symbol(name=sregistry.make_name(prefix='rmax'), dtype=lhs.dtype)
642+
rmin = Temp(name=sregistry.make_name(prefix='rmin'), dtype=lhs.dtype)
643+
rmax = Temp(name=sregistry.make_name(prefix='rmax'), dtype=lhs.dtype)
644644

645645
expr0 = Eq(rmin, limits_mapper[lhs.dtype].max)
646646
expr1 = Eq(rmax, limits_mapper[lhs.dtype].min)

devito/ir/iet/visitors.py

Lines changed: 17 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1377,7 +1377,7 @@ class Uxreplace(Transformer):
13771377
def visit_Expression(self, o):
13781378
return o._rebuild(expr=uxreplace(o.expr, self.mapper))
13791379

1380-
def visit_Iteration(self, o):
1380+
def _visit_Iteration_common(self, o):
13811381
nodes = self._visit(o.nodes)
13821382
dimension = uxreplace(o.dim, self.mapper)
13831383
limits = [uxreplace(i, self.mapper) for i in o.limits]
@@ -1386,9 +1386,25 @@ def visit_Iteration(self, o):
13861386
uindices = [uxreplace(i, self.mapper) for i in o.uindices]
13871387
uindices = filter_ordered(i for i in uindices if isinstance(i, Dimension))
13881388

1389+
return nodes, dimension, limits, pragmas, uindices
1390+
1391+
def visit_Iteration(self, o):
1392+
nodes, dimension, limits, pragmas, uindices = \
1393+
self._visit_Iteration_common(o)
1394+
13891395
return o._rebuild(nodes=nodes, dimension=dimension, limits=limits,
13901396
pragmas=pragmas, uindices=uindices)
13911397

1398+
def visit_PragmaIteration(self, o):
1399+
nodes, dimension, limits, pragmas, uindices = \
1400+
self._visit_Iteration_common(o)
1401+
1402+
reduction = [(uxreplace(var, self.mapper), imask, op)
1403+
for var, imask, op in (o.reduction or [])]
1404+
1405+
return o._rebuild(nodes=nodes, dimension=dimension, limits=limits,
1406+
pragmas=pragmas, uindices=uindices, reduction=reduction)
1407+
13921408
def visit_Definition(self, o):
13931409
try:
13941410
return o._rebuild(function=self.mapper[o.function])

0 commit comments

Comments
 (0)