Skip to content

Commit cf14f16

Browse files
Merge pull request #2890 from devitocodes/minmax-element
dsl: Add ReduceMinMax construct for joint minmax reductions
2 parents db3225e + 433e7c0 commit cf14f16

7 files changed

Lines changed: 135 additions & 25 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 48 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from devito.ir.clusters.analysis import analyze
1111
from devito.ir.clusters.cluster import Cluster, ClusterGroup
1212
from devito.ir.clusters.visitors import Queue, cluster_pass
13-
from devito.ir.equations import OpMax, OpMin, identity_mapper
13+
from devito.ir.equations import OpMax, OpMin, OpMinMax, identity_mapper
1414
from devito.ir.support import (
1515
Any, Backward, Forward, IterationSpace, Scope, erange, pull_dims
1616
)
@@ -531,8 +531,19 @@ def _update(reductions):
531531
# The IterationSpace within which the global distributed reduction
532532
# must be carried out
533533
ispace = c.ispace.prefix(lambda d: d in var.free_symbols) # noqa: B023
534-
expr = [Eq(var, DistReduce(var, op=op, grid=grid, ispace=ispace))]
535-
fifo.append(c.rebuild(exprs=expr, ispace=ispace))
534+
535+
if op is OpMinMax:
536+
# MinMax is not natively supported by MPI, so we perform two
537+
# separate reductions (not optimal, but it will do for now)
538+
var0, var1 = var, var._translate()
539+
exprs = [
540+
Eq(var0, DistReduce(var0, op=OpMin, grid=grid, ispace=ispace)),
541+
Eq(var1, DistReduce(var1, op=OpMax, grid=grid, ispace=ispace))
542+
]
543+
else:
544+
exprs = [Eq(var, DistReduce(var, op=op, grid=grid, ispace=ispace))]
545+
546+
fifo.append(c.rebuild(exprs=exprs, ispace=ispace))
536547

537548
processed.append(c)
538549

@@ -547,7 +558,7 @@ def normalize(clusters, sregistry=None, options=None, platform=None, **kwargs):
547558
if options['mapify-reduce']:
548559
clusters = normalize_reductions_dense(clusters, sregistry, platform)
549560
else:
550-
clusters = normalize_reductions_minmax(clusters)
561+
clusters = normalize_reductions_minmax(clusters, sregistry)
551562
clusters = normalize_reductions_sparse(clusters, sregistry)
552563

553564
return clusters
@@ -591,7 +602,7 @@ def pull_indexeds(expr, subs, mapper, parent=None):
591602

592603

593604
@cluster_pass(mode='dense')
594-
def normalize_reductions_minmax(cluster):
605+
def normalize_reductions_minmax(cluster, sregistry):
595606
"""
596607
Initialize the reduction variables to their neutral element and use them
597608
to compute the reduction.
@@ -603,6 +614,7 @@ def normalize_reductions_minmax(cluster):
603614

604615
init = []
605616
processed = []
617+
post = []
606618
for e in cluster.exprs:
607619
lhs, rhs = e.args
608620
f = lhs.function
@@ -623,10 +635,32 @@ def normalize_reductions_minmax(cluster):
623635

624636
processed.append(e.func(lhs, Max(lhs, rhs)))
625637

638+
elif e.operation is OpMinMax:
639+
# NOTE: we need to create two different reduction variables here
640+
# (instead of using say `n[0]` and `n[1]` directly) because that's
641+
# essentially what OpenMP/OpenACC expect -- two different symbols
642+
rmin = Symbol(name=sregistry.make_name(prefix='rmin'), dtype=lhs.dtype)
643+
rmax = Symbol(name=sregistry.make_name(prefix='rmax'), dtype=lhs.dtype)
644+
645+
expr0 = Eq(rmin, limits_mapper[lhs.dtype].max)
646+
expr1 = Eq(rmax, limits_mapper[lhs.dtype].min)
647+
ispace = cluster.ispace.project(lambda i: i not in dims)
648+
init.append(cluster.rebuild(exprs=[expr0, expr1], ispace=ispace))
649+
650+
processed.extend([
651+
e.func(rmin, Min(rmin, rhs), operation=OpMin),
652+
e.func(rmax, Max(rmax, rhs), operation=OpMax)
653+
])
654+
655+
# Copy-back the final result to `lhs` at the end of the reduction
656+
expr0 = Eq(lhs, rmin)
657+
expr1 = Eq(lhs._translate(), rmax)
658+
post.append(cluster.rebuild(exprs=[expr0, expr1], ispace=ispace))
659+
626660
else:
627661
processed.append(e)
628662

629-
return init + [cluster.rebuild(processed)]
663+
return init + [cluster.rebuild(processed)] + post
630664

631665

632666
def normalize_reductions_dense(cluster, sregistry, platform):
@@ -674,19 +708,20 @@ def _normalize_reductions_dense(cluster, mapper, sregistry, platform):
674708
if e.is_Reduction:
675709
lhs, rhs = e.args
676710

711+
wf = lhs.function
677712
try:
678-
f = rhs.function
713+
rf = rhs.function
679714
except AttributeError:
680-
f = None
715+
rf = None
681716

682-
if lhs.function.is_Array:
717+
if wf.is_Array and set(candidates).intersection(wf.dimensions):
683718
# Probably a compiler-generated reduction, e.g. via
684719
# recursive compilation; it's an Array already, so nothing to do
685720
processed.append(e)
686721
elif rhs in mapper:
687722
# Seen this RHS already, so reuse the Array that was created for it
688723
processed.append(e.func(lhs, mapper[rhs].indexify()))
689-
elif f and f.is_Array and sum(flatten(f._size_nodomain)) == 0:
724+
elif rf and rf.is_Array and sum(flatten(rf._size_nodomain)) == 0:
690725
# Special case: the RHS is an Array with no halo/padding, meaning
691726
# that the written data values are contiguous in memory, hence
692727
# we can simply reuse the Array itself as we're already in the
@@ -698,8 +733,9 @@ def _normalize_reductions_dense(cluster, mapper, sregistry, platform):
698733
grid = cluster.grid
699734
except ValueError:
700735
grid = None
701-
a = mapper[rhs] = Array(name=name, dtype=e.dtype, dimensions=dims,
702-
grid=grid)
736+
a = mapper[rhs] = Array(
737+
name=name, dtype=e.dtype, dimensions=dims, grid=grid
738+
)
703739

704740
# Populate the Array (the "map" part)
705741
processed.append(e.func(a.indexify(), rhs, operation=None))

devito/ir/equations/equation.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212
from devito.symbolics import IntDiv, limits_mapper, uxreplace
1313
from devito.tools import Pickable, Tag, frozendict
14-
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
14+
from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min
1515

1616
__all__ = [
1717
'ClusterizedEq',
@@ -20,6 +20,7 @@
2020
'OpInc',
2121
'OpMax',
2222
'OpMin',
23+
'OpMinMax',
2324
'identity_mapper',
2425
]
2526

@@ -69,7 +70,7 @@ def operation(self):
6970

7071
@property
7172
def is_Reduction(self):
72-
return self.operation in (OpInc, OpMin, OpMax)
73+
return self.operation in (OpInc, OpMin, OpMax, OpMinMax)
7374

7475
@property
7576
def is_Increment(self):
@@ -113,7 +114,8 @@ def detect(cls, expr):
113114
reduction_mapper = {
114115
Inc: OpInc,
115116
ReduceMax: OpMax,
116-
ReduceMin: OpMin
117+
ReduceMin: OpMin,
118+
ReduceMinMax: OpMinMax
117119
}
118120
try:
119121
return reduction_mapper[type(expr)]
@@ -130,6 +132,7 @@ def detect(cls, expr):
130132
OpInc = Operation('+')
131133
OpMax = Operation('max')
132134
OpMin = Operation('min')
135+
OpMinMax = Operation('minmax')
133136

134137

135138
identity_mapper = {

devito/ir/iet/nodes.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from devito.data import FULL
1515
from devito.ir.cgen import ccode
16-
from devito.ir.equations import DummyEq, OpInc, OpMax, OpMin
16+
from devito.ir.equations import DummyEq, OpInc, OpMax, OpMin, OpMinMax
1717
from devito.ir.support import (
1818
AFFINE, INBOUND, PARALLEL, PARALLEL_IF_ATOMIC, PARALLEL_IF_PVT, SEQUENTIAL,
1919
VECTORIZED, Forward, PrefetchUpdate, Property, WithLock, detect_io
@@ -457,7 +457,7 @@ def reads(self):
457457
@cached_property
458458
def write(self):
459459
"""The Function written by the Expression."""
460-
return self.expr.lhs.base.function
460+
return self.output.base.function
461461

462462
@cached_property
463463
def dimensions(self):
@@ -467,17 +467,17 @@ def dimensions(self):
467467
@property
468468
def is_scalar(self):
469469
"""True if the LHS is a scalar, False otherwise."""
470-
return isinstance(self.expr.lhs, (AbstractSymbol, IndexedBase, LocalObject))
470+
return isinstance(self.output, (AbstractSymbol, IndexedBase, LocalObject))
471471

472472
@property
473473
def is_tensor(self):
474474
"""True if the LHS is an array entry, False otherwise."""
475-
return self.expr.lhs.is_Indexed
475+
return self.output.is_Indexed
476476

477477
@property
478478
def is_reduction(self):
479479
"""True if the RHS performs a reduction operation, False otherwise."""
480-
return self.operation in (OpInc, OpMin, OpMax)
480+
return self.operation in (OpInc, OpMin, OpMax, OpMinMax)
481481

482482
@property
483483
def is_initializable(self):

devito/passes/iet/engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from devito.symbolics import FieldFromComposite, FieldFromPointer, IndexedPointer, search
1818
from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass
1919
from devito.types import (
20-
Array, Bundle, ComponentAccess, CompositeObject, FunctionMap, IncrDimension,
20+
Array, Auto, Bundle, ComponentAccess, CompositeObject, FunctionMap, IncrDimension,
2121
Indirection, ModuloDimension, NPThreads, NThreadsBase, Pointer, SharedData, Symbol,
2222
Temp, ThreadArray, Wildcard
2323
)
@@ -658,6 +658,7 @@ def _(i, mapper, sregistry):
658658
mapper[i] = i._rebuild(name=sregistry.make_name(prefix='ind'))
659659

660660

661+
@abstract_object.register(Auto)
661662
@abstract_object.register(Temp)
662663
@abstract_object.register(Wildcard)
663664
def _(i, mapper, sregistry):

devito/types/basic.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1926,6 +1926,31 @@ def _subs(self, old, new, **hints):
19261926
pass
19271927
return super()._subs(old, new, **hints)
19281928

1929+
def _translate(self, mapper=None):
1930+
"""
1931+
Translate the indices of the current Indexed according to the provided
1932+
`{Dimension -> offset}` mapper. For example, if the current Indexed is
1933+
`f[x+1]` and the mapper is `{x: -1}`, then the result of the translation
1934+
will be `f[x]`.
1935+
1936+
If `mapper` is None, then the translation will be unitary increment
1937+
along the fastest varying Dimension. For example, if the current
1938+
Indexed is `f[x+1, y+2]`, then the result of the translation will be
1939+
`f[x+1, y+3]` since `x` is the fastest varying Dimension.
1940+
"""
1941+
mapper = mapper or {self.dimensions[-1]: 1}
1942+
1943+
if any(d not in mapper for d in self.dimensions):
1944+
raise ValueError(
1945+
f"Cannot translate {self} with mapper {mapper} since not "
1946+
"all dimensions are covered"
1947+
)
1948+
1949+
translations = [mapper.get(d, 0) for d in self.dimensions]
1950+
indices = [sum(i) for i in zip(self.indices, translations, strict=True)]
1951+
1952+
return self.base[indices]
1953+
19291954

19301955
class IrregularFunctionInterface:
19311956

devito/types/equation.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from devito.tools import Pickable, as_tuple, frozendict
88
from devito.types.lazy import Evaluable
99

10-
__all__ = ['Eq', 'Inc', 'ReduceMax', 'ReduceMin']
10+
__all__ = ['Eq', 'Inc', 'ReduceMax', 'ReduceMin', 'ReduceMinMax']
1111

1212

1313
class Eq(sympy.Eq, Evaluable, Pickable):
@@ -62,8 +62,8 @@ class Eq(sympy.Eq, Evaluable, Pickable):
6262
__rargs__ = ('lhs', 'rhs')
6363
__rkwargs__ = ('subdomain', 'coefficients', 'implicit_dims')
6464

65-
def __new__(cls, lhs, rhs=0, subdomain=None, coefficients=None, implicit_dims=None,
66-
**kwargs):
65+
def __new__(cls, lhs, rhs=0, subdomain=None, coefficients=None,
66+
implicit_dims=None, **kwargs):
6767
if coefficients is not None:
6868
_ = deprecations.coeff_warn
6969
kwargs['evaluate'] = False
@@ -237,3 +237,23 @@ class ReduceMax(Reduction):
237237

238238
class ReduceMin(Reduction):
239239
pass
240+
241+
242+
class ReduceMinMax(Reduction):

    """
    A coupled min/max Reduction.

    The left-hand side must have room for two components, one for the minimum and
    one for the maximum; the behaviour is otherwise undefined.
    The right-hand side is the expression to be reduced.

    Raises
    ------
    ValueError
        If the left-hand side is not backed by an AbstractFunction. Note that
        the size of the left-hand side is not checked here.
    """

    def __new__(cls, lhs, rhs=0, **kwargs):
        # Objects with no `.function` at all (e.g., plain numbers) must be
        # rejected with the same ValueError rather than an AttributeError
        function = getattr(lhs, 'function', None)
        if not getattr(function, 'is_AbstractFunction', False):
            raise ValueError(
                f"The left-hand side of a {cls.__name__} must be a "
                "Function of size at least 2"
            )

        return super().__new__(cls, lhs, rhs=rhs, **kwargs)

tests/test_dle.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from conftest import _R, assert_blocking, assert_structure, skipif
99
from devito import (
1010
CustomDimension, DefaultDimension, Dimension, Eq, Function, Grid, Inc, Operator,
11-
PrecomputedSparseTimeFunction, ReduceMax, ReduceMin, SpaceDimension,
11+
PrecomputedSparseTimeFunction, ReduceMax, ReduceMin, ReduceMinMax, SpaceDimension,
1212
SparseTimeFunction, SubDimension, TimeFunction, configuration, cos, dimensions, info
1313
)
1414
from devito.exceptions import InvalidArgument
@@ -999,6 +999,31 @@ def test_array_minmax_reduction(self):
999999
assert n.data[0] == 26
10001000
assert n.data[1] == 0
10011001

1002+
def test_array_minmax_reduction_simultaneous(self):
1003+
"""
1004+
Test the combined min/max reduction DSL construct.
1005+
"""
1006+
grid = Grid(shape=(3, 3, 3))
1007+
i = Dimension(name='i')
1008+
1009+
f = Function(name='f', grid=grid)
1010+
n = Function(name='n', grid=grid, shape=(2,), dimensions=(i,))
1011+
1012+
f.data[:] = np.arange(-5, 22).reshape((3, 3, 3))
1013+
1014+
eqn = [ReduceMinMax(n[0], f)]
1015+
1016+
op = Operator(eqn)
1017+
1018+
if 'openmp' in configuration['language']:
1019+
iterations = FindNodes(Iteration).visit(op)
1020+
expected = "reduction(min:rmin0) reduction(max:rmax0)"
1021+
assert expected in iterations[0].pragmas[0].ccode.value
1022+
1023+
op()
1024+
assert n.data[0] == -5
1025+
assert n.data[1] == 21
1026+
10021027
def test_incs_no_atomic(self):
10031028
"""
10041029
Test that `Inc`'s don't get a `#pragma omp atomic` if performing

0 commit comments

Comments
 (0)