Skip to content

Commit cc4fe0d

Browse files
Merge pull request #2886 from devitocodes/tma-write
compiler: Enhance detect_accesses and patch symbolic padding
2 parents bd39274 + fc4c548 commit cc4fe0d

8 files changed

Lines changed: 167 additions & 14 deletions

File tree

devito/ir/cgen/printer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,11 @@ def _print_SafeInv(self, expr):
220220
val = self._print(expr.val)
221221
return f'SAFEINV({val}, {base})'
222222

223+
def _print_RoundUp(self, expr):
224+
value = self._print(expr.value)
225+
step = self._print(expr.step)
226+
return f'ROUND_UP({value}, {step})'
227+
223228
def _print_Mod(self, expr):
224229
"""Print a Mod as a C-like %-based operation."""
225230
args = [f'({self._print(a)})' for a in expr.args]

devito/ir/support/utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from devito.symbolics import CallFromPointer, retrieve_indexed, retrieve_terminals, search
77
from devito.tools import DefaultOrderedDict, as_tuple, filter_sorted, flatten, split
88
from devito.types import (
9-
Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension
9+
Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension, TensorMove
1010
)
1111

1212
__all__ = [
@@ -137,7 +137,14 @@ def detect_accesses(exprs):
137137
"""
138138
# Compute M : F -> S
139139
mapper = defaultdict(Stencil)
140-
for e in retrieve_indexed(exprs, deep=True):
140+
141+
# Search among the Indexeds (Most accesses typically stem from Indexeds)
142+
plain_indexeds = retrieve_indexed(exprs, deep=True)
143+
144+
# Search among higher order objects, which still represent meaningful accesses
145+
high_order_indexeds = [i.indexed for i in search(exprs, TensorMove)]
146+
147+
for e in (*plain_indexeds, *high_order_indexeds):
141148
f = e.function
142149

143150
for a, d0 in zip(e.indices, f.dimensions, strict=False):
@@ -164,13 +171,16 @@ def detect_accesses(exprs):
164171
d, others = split(dims, lambda i: d0 in i._defines) # noqa: B023
165172

166173
if any(i.is_Indexed for i in a.args) or len(d) != 1:
167-
# Case 1) -- with indirect accesses there's not much we can infer
174+
# Case 1) -- with indirect accesses there's not much we
175+
# can infer
168176
continue
169177
else:
170178
# Case 2)
171179
d, = d
172180
_, o = split(others, lambda i: i.is_Custom)
173-
off = sum(i for i in a.args if i.is_integer or i.free_symbols & o)
181+
off = sum(
182+
i for i in a.args if i.is_integer or i.free_symbols & o
183+
)
174184
else:
175185
d, = dims
176186

devito/passes/iet/misc.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
1515
from devito.passes.iet.engine import iet_pass
1616
from devito.passes.iet.languages.C import CPrinter
17-
from devito.symbolics import Cast, ValueLimit, evalrel, has_integer_args, limits_mapper
17+
from devito.symbolics import (
18+
Cast, RoundUp, ValueLimit, evalrel, has_integer_args, limits_mapper
19+
)
1820
from devito.tools import Bunch, as_mapper, as_tuple, filter_ordered, split
1921
from devito.types import FIndexed
2022

@@ -255,6 +257,12 @@ def _(expr, langbb, printer):
255257
f'(0.0{ext}) : ((1.0{ext}) / (a)))'),), {}
256258

257259

260+
@_lower_macro_math.register(RoundUp)
def _(expr, langbb, printer):
    """
    Provide the C macro backing a RoundUp.

    The macro yields `a` unchanged when `a` is already a multiple of `b`,
    otherwise `a` plus whatever is missing to reach the next multiple of `b`.

    Returns a one-entry tuple of `(macro signature, macro body)` pairs along
    with an empty dict. None of `expr`, `langbb`, `printer` are used here.
    """
    # NOTE(review): the role of the trailing dict isn't visible from here; it
    # mirrors the shape returned by the sibling handler just above -- confirm
    signature = 'ROUND_UP(a,b)'
    body = '((((a)%(b)) == 0) ? (a) : ((a) + (b) - ((a)%(b))))'
    return ((signature, body),), {}
264+
265+
258266
@iet_pass
259267
def minimize_symbols(iet):
260268
"""

devito/symbolics/extended_sympy.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite',
2424
'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction',
2525
'MathFunction', 'InlineIf', 'Reserved', 'ReservedWord', 'Keyword',
26-
'String', 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace',
27-
'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit',
28-
'VectorAccess']
26+
'String', 'Macro', 'Class', 'MacroArgument', 'RoundUp', 'Deref',
27+
'Namespace', 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin',
28+
'ValueLimit', 'VectorAccess']
2929

3030

3131
class CondEq(sympy.Eq):
@@ -623,6 +623,49 @@ def __str__(self):
623623
__repr__ = __str__
624624

625625

626+
class RoundUp(Function):

    """
    Symbolic representation of rounding a value up to the next multiple of a
    given step.

    Parameters
    ----------
    value : expr-like
        The quantity to be rounded up.
    step : int
        The rounding granularity. It must be an integer greater than or
        equal to 1.

    Raises
    ------
    ValueError
        If `step` is not an integer, or if it is smaller than 1.

    Notes
    -----
    If both `value` and `step` are numbers, the rounding is performed on the
    spot and the resulting number is returned in place of a RoundUp object.
    """

    def __new__(cls, value, step, **kwargs):
        value = sympify(value)
        step = sympify(step)

        # Validate the type *before* the range. Comparing a non-numeric
        # `step` against 1 builds a symbolic relational, whose truth value
        # cannot be determined, so the `if` below would fail with an
        # unrelated TypeError rather than the intended ValueError
        if not is_integer(step):
            raise ValueError("`step` must be an integer")
        if step < 1:
            # Note: a zero `step` is rejected as well
            raise ValueError("Cannot round up with non-positive `step`")

        if value.is_number and step.is_number:
            # Nothing symbolic is left, so evaluate eagerly
            remainder = value % step
            if remainder == 0:
                return value
            return value + step - remainder

        return super().__new__(cls, value, step, **kwargs)

    @property
    def value(self):
        """The quantity being rounded up (first argument)."""
        return self.args[0]

    @property
    def step(self):
        """The rounding granularity (second argument)."""
        return self.args[1]

    @property
    def is_commutative(self):
        # Commutative if and only if both arguments are
        return self.value.is_commutative and self.step.is_commutative

    def __str__(self):
        return f"ROUND_UP({self.value}, {self.step})"

    __repr__ = __str__
667+
668+
626669
class ValueLimit(ReservedWord, sympy.Expr):
627670

628671
"""

devito/types/basic.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -921,12 +921,17 @@ def __padding_setup_smart__(self, **kwargs):
921921
return nopadding
922922

923923
mmts = configuration['platform'].max_mem_trans_size(self.__padding_dtype__)
924-
remainder = self._size_nopad[d] % mmts
924+
925+
snp = self._size_nopad[d]
926+
remainder = snp % mmts
925927
if remainder == 0:
926928
# Already a multiple of `mmts`, no need to pad
927929
return nopadding
930+
else:
931+
from devito.symbolics import RoundUp # noqa
932+
v = RoundUp(snp, mmts) - snp
928933

929-
dpadding = (0, (mmts - remainder))
934+
dpadding = (0, v)
930935
padding = [(0, 0)]*self.ndim
931936
padding[self.dimensions.index(d)] = dpadding
932937

devito/types/parallel.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,9 +407,47 @@ class TensorMove(Expr, Reserved, Terminal):
407407

408408
"""
409409
Represent the LOAD/STORE of a multi-dimensional block of data from/to a higher
410-
level of the memory hierarchy
410+
level of the memory hierarchy.
411+
412+
Parameters
413+
----------
414+
base : IndexedBase
415+
The base of the AbstractFunction subject of the TensorMove.
416+
tid0 : Dimension
417+
A representation of thread(s) issuing the TensorMove.
418+
coords : tuple
419+
The base address of the TensorMove (one point per Dimension).
411420
"""
412421

422+
__rargs__ = ('base', 'tid0', 'coords')
423+
424+
def __new__(cls, base, tid0, coords, **kwargs):
    """
    Build the object out of exactly three positional arguments, in the
    order `base`, `tid0`, `coords`.

    Any extra keyword argument is accepted, but it is not handed over to
    the parent constructor.
    """
    args = (base, tid0, coords)
    return super().__new__(cls, *args)
426+
427+
@property
def base(self):
    """The base of the function subject of the move (first argument)."""
    return self.args[0]
430+
431+
@property
def tid0(self):
    """The representation of the issuing thread(s) (second argument)."""
    return self.args[1]
434+
435+
@property
def coords(self):
    """The base address of the move, one point per Dimension (third argument)."""
    return self.args[2]
438+
439+
@property
def function(self):
    """The function that `base` refers to."""
    base = self.base
    return base.function
442+
443+
@cached_property
def indexed(self):
    """
    The underlying function indexed at `coords`. Being a cached property,
    it is computed upon first access only.
    """
    function = self.function
    return function[self.coords]
446+
447+
@property
def ndim(self):
    """The `ndim` of the underlying function."""
    function = self.function
    return function.ndim
450+
413451
func = Reserved._rebuild
414452

415453
def _ccode(self, printer):

tests/test_data.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
)
99
from devito.data import LEFT, RIGHT, Decomposition, convert_index, loc_data_idx
1010
from devito.data.allocators import DataReference
11+
from devito.ir import ccode
1112
from devito.tools import as_tuple
1213
from devito.types import Scalar
14+
from devito.types.misc import TempArray
1315

1416

1517
class TestDataBasic:
@@ -336,6 +338,34 @@ def test_w_halo_w_autopadding(self):
336338
assert u1._size_nodomain == ((3, 3), (3, 3), (3, 9))
337339
assert u1.shape_allocated == (10, 10, 16)
338340

341+
@switchconfig(autopadding=True, platform='bdw')
def test_temp_array_smart_padding_no_overshoot(self):
    """
    Check that, once the symbolic size along the innermost Dimension is
    replaced with the actual one, a TempArray gets no trailing padding
    and ends up with the same innermost allocated extent as a Function
    defined over the same Grid with the same halo.
    """
    halo = 4
    mmts = configuration['platform'].max_mem_trans_size(np.float32)
    # NOTE(review): presumably picked so that domain plus halo on both
    # sides adds up to a multiple of `mmts` -- confirm
    z_size = 2*mmts - 2*halo

    grid = Grid(shape=(4, 4, z_size))
    u = Function(name='u', grid=grid, space_order=halo)
    r = TempArray(name='r', dimensions=grid.dimensions, halo=u.halo,
                  dtype=u.dtype)

    *_, z = grid.dimensions
    concrete = {z.symbolic_size: z_size}

    trailing_padding = r.padding[z][1]
    assert trailing_padding.subs(concrete) == 0

    allocated = r.shape_allocated[-1]
    assert allocated.subs(concrete) == u.shape_allocated[-1]
356+
357+
@switchconfig(autopadding=True, platform='bdw')
def test_temp_array_smart_padding_codegen_avoids_negative_mod(self):
    """
    Check that the innermost allocated extent of a TempArray is printed
    through `ROUND_UP`, in terms of `z_size`, and without any `(-z_size)`
    operand.
    """
    grid = Grid(shape=(4, 4, 592))
    u = Function(name='u', grid=grid, space_order=0)
    r = TempArray(name='r', dimensions=grid.dimensions, halo=u.halo,
                  dtype=u.dtype)

    innermost = r.shape_allocated[-1]
    generated = ccode(innermost)

    assert 'ROUND_UP(' in generated
    assert '(-z_size)' not in generated
    assert 'z_size' in generated
368+
339369
def test_w_halo_custom(self):
340370
grid = Grid(shape=(4, 4))
341371

tests/test_symbolics.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88
from devito import ( # noqa
99
Abs, Conj, Constant, Dimension, Eq, Function, Ge, Grid, Gt, Imag, Le, Lt, Max, Min,
1010
Operator, Real, SubDimension, SubDomain, TimeFunction, configuration, cos, norm, sin,
11-
solve
11+
solve, switchconfig
1212
)
1313
from devito.finite_differences.differentiable import Mul, SafeInv, Weights
1414
from devito.ir import Expression, FindNodes, ccode
1515
from devito.ir.support.guards import GuardExpr, pairwise_or, simplify_and
1616
from devito.mpi.halo_scheme import HaloTouch
1717
from devito.symbolics import ( # noqa
1818
INT, BaseCast, CallFromPointer, Cast, DefFunction, FieldFromComposite,
19-
FieldFromPointer, IntDiv, ListInitializer, Namespace, ReservedWord, Rvalue, SizeOf,
20-
VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions,
19+
FieldFromPointer, IntDiv, ListInitializer, Namespace, ReservedWord, RoundUp, Rvalue,
20+
SizeOf, VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions,
2121
retrieve_indexed, uxreplace
2222
)
2323
from devito.tools import CustomDtype, as_tuple
@@ -390,6 +390,20 @@ def test_safeinv():
390390
assert str(v) == 'u[x, y]'
391391

392392

393+
def test_roundup():
    """
    Check that a RoundUp over a symbolic value is printed as a `ROUND_UP`
    call, and that an Operator using it carries both the call and the
    definition of the `ROUND_UP` macro.
    """
    grid = Grid(shape=(11, 11))
    u = Function(name='u', grid=grid)
    a = dSymbol('a', dtype=np.int32)

    expr = RoundUp(a, 16)
    with switchconfig(platform='bdw', language='openmp'):
        op = Operator(Eq(u, u + expr))

    # Standalone printing
    assert ccode(expr) == 'ROUND_UP(a, 16)'

    # Printing within the generated code
    generated = str(op)
    assert '#define ROUND_UP(a,b)' in generated
    assert 'ROUND_UP(a, 16)' in generated
405+
406+
393407
def test_def_function():
394408
foo0 = DefFunction('foo', arguments=['a', 'b'], template=['int'])
395409
foo1 = DefFunction('foo', arguments=['a', 'b'], template=['int'])

0 commit comments

Comments
 (0)