1010from devito .ir .clusters .analysis import analyze
1111from devito .ir .clusters .cluster import Cluster , ClusterGroup
1212from devito .ir .clusters .visitors import Queue , cluster_pass
13- from devito .ir .equations import OpMax , OpMin , identity_mapper
13+ from devito .ir .equations import OpMax , OpMin , OpMinMax , identity_mapper
1414from devito .ir .support import (
1515 Any , Backward , Forward , IterationSpace , Scope , erange , pull_dims
1616)
@@ -531,8 +531,19 @@ def _update(reductions):
531531 # The IterationSpace within which the global distributed reduction
532532 # must be carried out
533533 ispace = c .ispace .prefix (lambda d : d in var .free_symbols ) # noqa: B023
534- expr = [Eq (var , DistReduce (var , op = op , grid = grid , ispace = ispace ))]
535- fifo .append (c .rebuild (exprs = expr , ispace = ispace ))
534+
535+ if op is OpMinMax :
 536+ # MinMax is not natively supported by MPI, so we perform two separate
 537+ # reductions, one Min and one Max (suboptimal, but sufficient for now)
538+ var0 , var1 = var , var ._translate ()
539+ exprs = [
540+ Eq (var0 , DistReduce (var0 , op = OpMin , grid = grid , ispace = ispace )),
541+ Eq (var1 , DistReduce (var1 , op = OpMax , grid = grid , ispace = ispace ))
542+ ]
543+ else :
544+ exprs = [Eq (var , DistReduce (var , op = op , grid = grid , ispace = ispace ))]
545+
546+ fifo .append (c .rebuild (exprs = exprs , ispace = ispace ))
536547
537548 processed .append (c )
538549
@@ -547,7 +558,7 @@ def normalize(clusters, sregistry=None, options=None, platform=None, **kwargs):
547558 if options ['mapify-reduce' ]:
548559 clusters = normalize_reductions_dense (clusters , sregistry , platform )
549560 else :
550- clusters = normalize_reductions_minmax (clusters )
561+ clusters = normalize_reductions_minmax (clusters , sregistry )
551562 clusters = normalize_reductions_sparse (clusters , sregistry )
552563
553564 return clusters
@@ -591,7 +602,7 @@ def pull_indexeds(expr, subs, mapper, parent=None):
591602
592603
593604@cluster_pass (mode = 'dense' )
594- def normalize_reductions_minmax (cluster ):
605+ def normalize_reductions_minmax (cluster , sregistry ):
595606 """
596607 Initialize the reduction variables to their neutral element and use them
597608 to compute the reduction.
@@ -603,6 +614,7 @@ def normalize_reductions_minmax(cluster):
603614
604615 init = []
605616 processed = []
617+ post = []
606618 for e in cluster .exprs :
607619 lhs , rhs = e .args
608620 f = lhs .function
@@ -623,10 +635,32 @@ def normalize_reductions_minmax(cluster):
623635
624636 processed .append (e .func (lhs , Max (lhs , rhs )))
625637
638+ elif e .operation is OpMinMax :
 639+ # NOTE: we need to create two different reduction variables here
 640+ # (instead of using, say, `n[0]` and `n[1]` directly) because that is
 641+ # essentially what OpenMP/OpenACC expect -- two distinct symbols
642+ rmin = Symbol (name = sregistry .make_name (prefix = 'rmin' ), dtype = lhs .dtype )
643+ rmax = Symbol (name = sregistry .make_name (prefix = 'rmax' ), dtype = lhs .dtype )
644+
645+ expr0 = Eq (rmin , limits_mapper [lhs .dtype ].max )
646+ expr1 = Eq (rmax , limits_mapper [lhs .dtype ].min )
647+ ispace = cluster .ispace .project (lambda i : i not in dims )
648+ init .append (cluster .rebuild (exprs = [expr0 , expr1 ], ispace = ispace ))
649+
650+ processed .extend ([
651+ e .func (rmin , Min (rmin , rhs ), operation = OpMin ),
652+ e .func (rmax , Max (rmax , rhs ), operation = OpMax )
653+ ])
654+
 655+ # Copy the final results back at the end of the reduction: `rmin` into `lhs`, `rmax` into `lhs._translate()`
656+ expr0 = Eq (lhs , rmin )
657+ expr1 = Eq (lhs ._translate (), rmax )
658+ post .append (cluster .rebuild (exprs = [expr0 , expr1 ], ispace = ispace ))
659+
626660 else :
627661 processed .append (e )
628662
629- return init + [cluster .rebuild (processed )]
663+ return init + [cluster .rebuild (processed )] + post
630664
631665
632666def normalize_reductions_dense (cluster , sregistry , platform ):
@@ -674,19 +708,20 @@ def _normalize_reductions_dense(cluster, mapper, sregistry, platform):
674708 if e .is_Reduction :
675709 lhs , rhs = e .args
676710
711+ wf = lhs .function
677712 try :
678- f = rhs .function
713+ rf = rhs .function
679714 except AttributeError :
680- f = None
715+ rf = None
681716
682- if lhs . function . is_Array :
717+ if wf . is_Array and set ( candidates ). intersection ( wf . dimensions ) :
683718 # Probably a compiler-generated reduction, e.g. via
684719 # recursive compilation; it's an Array already, so nothing to do
685720 processed .append (e )
686721 elif rhs in mapper :
687722 # Seen this RHS already, so reuse the Array that was created for it
688723 processed .append (e .func (lhs , mapper [rhs ].indexify ()))
689- elif f and f .is_Array and sum (flatten (f ._size_nodomain )) == 0 :
724+ elif rf and rf .is_Array and sum (flatten (rf ._size_nodomain )) == 0 :
690725 # Special case: the RHS is an Array with no halo/padding, meaning
691726 # that the written data values are contiguous in memory, hence
692727 # we can simply reuse the Array itself as we're already in the
@@ -698,8 +733,9 @@ def _normalize_reductions_dense(cluster, mapper, sregistry, platform):
698733 grid = cluster .grid
699734 except ValueError :
700735 grid = None
701- a = mapper [rhs ] = Array (name = name , dtype = e .dtype , dimensions = dims ,
702- grid = grid )
736+ a = mapper [rhs ] = Array (
737+ name = name , dtype = e .dtype , dimensions = dims , grid = grid
738+ )
703739
704740 # Populate the Array (the "map" part)
705741 processed .append (e .func (a .indexify (), rhs , operation = None ))
0 commit comments