Skip to content

Add examples for reinterpret_cast and subview operators to show their behavior in relation to their input memref underlying memory and view #135244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 28, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 101 additions & 1 deletion mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ def MemRef_ReinterpretCastOp
let description = [{
Modify offset, sizes and strides of an unranked/ranked memref.

Example:
Example 1:
```mlir
memref.reinterpret_cast %ranked to
offset: [0],
Expand Down Expand Up @@ -1363,6 +1363,58 @@ def MemRef_ReinterpretCastOp
%dst.sizes = %sizes
%dst.strides = %strides
```

Example 2:

Consecutive `reinterpret_cast` operations on memref's with static dimensions.

We distinguish between *underlying memory* — the sequence of elements as
they appear in the contiguous memory of the memref — and the *view*, which refers to
the underlying memory interpreted according to specified offsets, sizes, and strides.

```mlir
%result1 = memref.reinterpret_cast %arg0 to offset: [9], sizes: [4, 4], strides: [16, 2] : memref<8x8xf32, strided<[8, 1], offset: 0>> to memref<4x4xf32, strided<[16, 2], offset: 9>>

%result2 = memref.reinterpret_cast %result1 to offset: [0], sizes: [2, 2], strides: [4, 2] : memref<4x4xf32, strided<[16, 2], offset: 9>> to memref<2x2xf32, strided<[4, 2], offset: 0>>
```

The input memref `%arg0` has the following view. The underlying memory consists
of a linear sequence of integers from 1 to 64:

```mlir
[[1, 2, 3, 4, 5, 6, 7, 8],
[9, 10, 11, 12, 13, 14, 15, 16],
[17, 18, 19, 20, 21, 22, 23, 24],
[25, 26, 27, 28, 29, 30, 31, 32],
[33, 34, 35, 36, 37, 38, 39, 40],
[41, 42, 43, 44, 45, 46, 47, 48],
[49, 50, 51, 52, 53, 54, 55, 56],
[57, 58, 59, 60, 61, 62, 63, 64]]
```

Following the first `reinterpret_cast`, the view of `%result1` is:

```mlir
[[10, 12, 14, 16],
[26, 28, 30, 32],
[42, 44, 46, 48],
[58, 60, 62, 64]]
```

Note: The offset and strides are relative to the underlying memory of `%arg0`.

The second `reinterpret_cast` results in the following view for `%result2`:

```mlir
[[1, 3],
[5, 7]]
```

It is important to observe that the offset and stride are relative to the base underlying
memory of the memref, starting at 1, not at 10 as seen in the output of `%result1`.
This behavior contrasts with the `subview` operator, where values are relative to the view of
the memref (refer to `subview` examples). Consequently, the second `reinterpret_cast` behaves
as if `%arg0` were passed directly as its argument.
}];

let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "", []>:$source,
Expand Down Expand Up @@ -1942,7 +1994,55 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
%1 = memref.subview %0[0, 0, 0] [8, 16, 4] [1, 1, 1]
: memref<8x16x4xf32> to memref<8x16x4xf32>
```
Example 6:

Consecutive `subview` operations on memref's with static dimensions.

We distinguish between *underlying memory* — the sequence of elements as
they appear in the contiguous memory of the memref — and the *view*, which refers to
the underlying memory interpreted according to specified offsets, sizes, and strides.

```mlir
%result1 = memref.subview %arg0[1, 1][4, 4][2, 2] : memref<8x8xf32, strided<[8, 1], offset: 0>> to memref<4x4xf32, strided<[16, 2], offset: 9>>

%result2 = memref.subview %result1[1, 1][2, 2][2, 2] : memref<4x4xf32, strided<[16, 2], offset: 9>> to memref<2x2xf32, strided<[32, 4], offset: 27>>
```

The input memref `%arg0` has the following view. The underlying memory for this input
memref is a linear sequence of integers from 1 to 64:

```mlir
[[1, 2, 3, 4, 5, 6, 7, 8],
[9, 10, 11, 12, 13, 14, 15, 16],
[17, 18, 19, 20, 21, 22, 23, 24],
[25, 26, 27, 28, 29, 30, 31, 32],
[33, 34, 35, 36, 37, 38, 39, 40],
[41, 42, 43, 44, 45, 46, 47, 48],
[49, 50, 51, 52, 53, 54, 55, 56],
[57, 58, 59, 60, 61, 62, 63, 64]]
```

Following the first `subview`, the view of `%result1` is:

```mlir
[[10, 12, 14, 16],
[26, 28, 30, 32],
[42, 44, 46, 48],
[58, 60, 62, 64]]
```

Note: The offset and strides are relative to the memref view of `%arg0` (compare to the
corresponding `reinterpret_cast` example).

The second `subview` results in the following view for `%result2`:

```mlir
[[28, 32],
[60, 64]]
```

Unlike the `reinterpret_cast`, the values are relative to the view of the input memref
(`%result1` in this case) and not its underlying memory.
}];

let arguments = (ins AnyMemRef:$source,
Expand Down