Skip to content

Commit cd116ee

Browse files
ivangarcia44Ivan Garcia
andauthored
Add examples for reinterpret_cast and subview operators to show their behavior in relation to their input memref underlying memory and view (llvm#135244)
While working on llvm#134845 I was trying to understand the difference of how the reinterpret_cast and subview operators see the input memref, but it was not clear to me. I did a couple of experiments in which I learned that the subview takes into account the view of the input memref to create the view of the output memref, while the reinterpret_cast just uses the underlying memory of the input memref. I thought it would help future readers to see these two experiements as examples in the documentation to quickly figure out the difference between these two operators. @matthias-springer @joker-eph @sahas3 @Hanumanth04 @dixinzhou @rafaelubalmw --------- Co-authored-by: Ivan Garcia <[email protected]>
1 parent c8b3d79 commit cd116ee

File tree

1 file changed

+135
-6
lines changed

1 file changed

+135
-6
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 135 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,7 +1331,79 @@ def MemRef_ReinterpretCastOp
13311331
let description = [{
13321332
Modify offset, sizes and strides of an unranked/ranked memref.
13331333

1334-
Example:
1334+
Example 1:
1335+
1336+
Consecutive `reinterpret_cast` operations on memref's with static
1337+
dimensions.
1338+
1339+
We distinguish between *underlying memory* — the sequence of elements as
1340+
they appear in the contiguous memory of the memref — and the
1341+
*strided memref*, which refers to the underlying memory interpreted
1342+
according to specified offsets, sizes, and strides.
1343+
1344+
```mlir
1345+
%result1 = memref.reinterpret_cast %arg0 to
1346+
offset: [9],
1347+
sizes: [4, 4],
1348+
strides: [16, 2]
1349+
: memref<8x8xf32, strided<[8, 1], offset: 0>> to
1350+
memref<4x4xf32, strided<[16, 2], offset: 9>>
1351+
1352+
%result2 = memref.reinterpret_cast %result1 to
1353+
offset: [0],
1354+
sizes: [2, 2],
1355+
strides: [4, 2]
1356+
: memref<4x4xf32, strided<[16, 2], offset: 9>> to
1357+
memref<2x2xf32, strided<[4, 2], offset: 0>>
1358+
```
1359+
1360+
The underlying memory of `%arg0` consists of a linear sequence of integers
1361+
from 1 to 64. Its memref has the following 8x8 elements:
1362+
1363+
```mlir
1364+
[[1, 2, 3, 4, 5, 6, 7, 8],
1365+
[9, 10, 11, 12, 13, 14, 15, 16],
1366+
[17, 18, 19, 20, 21, 22, 23, 24],
1367+
[25, 26, 27, 28, 29, 30, 31, 32],
1368+
[33, 34, 35, 36, 37, 38, 39, 40],
1369+
[41, 42, 43, 44, 45, 46, 47, 48],
1370+
[49, 50, 51, 52, 53, 54, 55, 56],
1371+
[57, 58, 59, 60, 61, 62, 63, 64]]
1372+
```
1373+
1374+
Following the first `reinterpret_cast`, the strided memref elements
1375+
of `%result1` are:
1376+
1377+
```mlir
1378+
[[10, 12, 14, 16],
1379+
[26, 28, 30, 32],
1380+
[42, 44, 46, 48],
1381+
[58, 60, 62, 64]]
1382+
```
1383+
1384+
Note: The offset and strides are relative to the underlying memory of
1385+
`%arg0`.
1386+
1387+
The second `reinterpret_cast` results in the following strided memref
1388+
for `%result2`:
1389+
1390+
```mlir
1391+
[[1, 3],
1392+
[5, 7]]
1393+
```
1394+
1395+
Notice that it does not matter if you use %result1 or %arg0 as a source
1396+
for the second `reinterpret_cast` operation. Only the underlying memory
1397+
pointers will be reused.
1398+
1399+
The offset and stride are relative to the base underlying memory of the
1400+
memref, starting at 1, not at 10 as seen in the output of `%result1`.
1401+
This behavior contrasts with the `subview` operator, where values are
1402+
relative to the strided memref (refer to `subview` examples).
1403+
Consequently, the second `reinterpret_cast` behaves as if `%arg0` were
1404+
passed directly as its argument.
1405+
1406+
Example 2:
13351407
```mlir
13361408
memref.reinterpret_cast %ranked to
13371409
offset: [0],
@@ -1898,6 +1970,64 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
18981970

18991971
Example 1:
19001972

1973+
Consecutive `subview` operations on memref's with static dimensions.
1974+
1975+
We distinguish between *underlying memory* — the sequence of elements as
1976+
they appear in the contiguous memory of the memref — and the
1977+
*strided memref*, which refers to the underlying memory interpreted
1978+
according to specified offsets, sizes, and strides.
1979+
1980+
```mlir
1981+
%result1 = memref.subview %arg0[1, 1][4, 4][2, 2]
1982+
: memref<8x8xf32, strided<[8, 1], offset: 0>> to
1983+
memref<4x4xf32, strided<[16, 2], offset: 9>>
1984+
1985+
%result2 = memref.subview %result1[1, 1][2, 2][2, 2]
1986+
: memref<4x4xf32, strided<[16, 2], offset: 9>> to
1987+
memref<2x2xf32, strided<[32, 4], offset: 27>>
1988+
```
1989+
1990+
The underlying memory of `%arg0` consists of a linear sequence of integers
1991+
from 1 to 64. Its memref has the following 8x8 elements:
1992+
1993+
```mlir
1994+
[[1, 2, 3, 4, 5, 6, 7, 8],
1995+
[9, 10, 11, 12, 13, 14, 15, 16],
1996+
[17, 18, 19, 20, 21, 22, 23, 24],
1997+
[25, 26, 27, 28, 29, 30, 31, 32],
1998+
[33, 34, 35, 36, 37, 38, 39, 40],
1999+
[41, 42, 43, 44, 45, 46, 47, 48],
2000+
[49, 50, 51, 52, 53, 54, 55, 56],
2001+
[57, 58, 59, 60, 61, 62, 63, 64]]
2002+
```
2003+
2004+
Following the first `subview`, the strided memref elements of `%result1`
2005+
are:
2006+
2007+
```mlir
2008+
[[10, 12, 14, 16],
2009+
[26, 28, 30, 32],
2010+
[42, 44, 46, 48],
2011+
[58, 60, 62, 64]]
2012+
```
2013+
2014+
Note: The offset and strides are relative to the strided memref of `%arg0`
2015+
(compare to the corresponding `reinterpret_cast` example).
2016+
2017+
The second `subview` results in the following strided memref for
2018+
`%result2`:
2019+
2020+
```mlir
2021+
[[28, 32],
2022+
[60, 64]]
2023+
```
2024+
2025+
Unlike the `reinterpret_cast`, the values are relative to the strided
2026+
memref of the input (`%result1` in this case) and not its
2027+
underlying memory.
2028+
2029+
Example 2:
2030+
19012031
```mlir
19022032
// Subview of static memref with strided layout at static offsets, sizes
19032033
// and strides.
@@ -1906,7 +2036,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
19062036
memref<8x2xf32, strided<[21, 18], offset: 137>>
19072037
```
19082038

1909-
Example 2:
2039+
Example 3:
19102040

19112041
```mlir
19122042
// Subview of static memref with identity layout at dynamic offsets, sizes
@@ -1915,7 +2045,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
19152045
: memref<64x4xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
19162046
```
19172047

1918-
Example 3:
2048+
Example 4:
19192049

19202050
```mlir
19212051
// Subview of dynamic memref with strided layout at dynamic offsets and
@@ -1925,7 +2055,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
19252055
memref<4x4xf32, strided<[?, ?], offset: ?>>
19262056
```
19272057

1928-
Example 4:
2058+
Example 5:
19292059

19302060
```mlir
19312061
// Rank-reducing subviews.
@@ -1935,14 +2065,13 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
19352065
: memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
19362066
```
19372067

1938-
Example 5:
2068+
Example 6:
19392069

19402070
```mlir
19412071
// Identity subview. The subview is the full source memref.
19422072
%1 = memref.subview %0[0, 0, 0] [8, 16, 4] [1, 1, 1]
19432073
: memref<8x16x4xf32> to memref<8x16x4xf32>
19442074
```
1945-
19462075
}];
19472076

19482077
let arguments = (ins AnyMemRef:$source,

0 commit comments

Comments
 (0)