Skip to content

Commit 5582088

Browse files
committed
compiler: Enhance detect_accesses
1 parent bd39274 commit 5582088

2 files changed

Lines changed: 53 additions & 5 deletions

File tree

devito/ir/support/utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from devito.symbolics import CallFromPointer, retrieve_indexed, retrieve_terminals, search
77
from devito.tools import DefaultOrderedDict, as_tuple, filter_sorted, flatten, split
88
from devito.types import (
9-
Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension
9+
Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension, TensorMove
1010
)
1111

1212
__all__ = [
@@ -137,7 +137,14 @@ def detect_accesses(exprs):
137137
"""
138138
# Compute M : F -> S
139139
mapper = defaultdict(Stencil)
140-
for e in retrieve_indexed(exprs, deep=True):
140+
141+
# Search among the Indexeds (most accesses typically stem from Indexeds)
142+
plain_indexeds = retrieve_indexed(exprs, deep=True)
143+
144+
# Search among higher-order objects, which still represent meaningful accesses
145+
high_order_indexeds = [i.indexed for i in search(exprs, TensorMove)]
146+
147+
for e in (*plain_indexeds, *high_order_indexeds):
141148
f = e.function
142149

143150
for a, d0 in zip(e.indices, f.dimensions, strict=False):
@@ -164,13 +171,16 @@ def detect_accesses(exprs):
164171
d, others = split(dims, lambda i: d0 in i._defines) # noqa: B023
165172

166173
if any(i.is_Indexed for i in a.args) or len(d) != 1:
167-
# Case 1) -- with indirect accesses there's not much we can infer
174+
# Case 1) -- with indirect accesses there's not much we
175+
# can infer
168176
continue
169177
else:
170178
# Case 2)
171179
d, = d
172180
_, o = split(others, lambda i: i.is_Custom)
173-
off = sum(i for i in a.args if i.is_integer or i.free_symbols & o)
181+
off = sum(
182+
i for i in a.args if i.is_integer or i.free_symbols & o
183+
)
174184
else:
175185
d, = dims
176186

devito/types/parallel.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,9 +407,47 @@ class TensorMove(Expr, Reserved, Terminal):
407407

408408
"""
409409
Represent the LOAD/STORE of a multi-dimensional block of data from/to a higher
410-
level of the memory hierarchy
410+
level of the memory hierarchy.
411+
412+
Parameters
413+
----------
414+
base : IndexedBase
415+
The base of the AbstractFunction that is the subject of the TensorMove.
416+
tid0 : Dimension
417+
A representation of the thread(s) issuing the TensorMove.
418+
coords : tuple
419+
The base address of the TensorMove (one point per Dimension).
411420
"""
412421

422+
__rargs__ = ('base', 'tid0', 'coords')
423+
424+
def __new__(cls, base, tid0, coords, **kwargs):
425+
return super().__new__(cls, base, tid0, coords)
426+
427+
@property
428+
def base(self):
429+
return self.args[0]
430+
431+
@property
432+
def tid0(self):
433+
return self.args[1]
434+
435+
@property
436+
def coords(self):
437+
return self.args[2]
438+
439+
@property
440+
def function(self):
441+
return self.base.function
442+
443+
@cached_property
444+
def indexed(self):
445+
return self.function[self.coords]
446+
447+
@property
448+
def ndim(self):
449+
return self.function.ndim
450+
413451
func = Reserved._rebuild
414452

415453
def _ccode(self, printer):

0 commit comments

Comments
 (0)