Skip to content

Commit 33a9ce6

Browse files
[mlir][memref] Verify out-of-bounds access for memref.subview
1 parent 2fbfbf4 commit 33a9ce6

File tree

14 files changed

+311
-248
lines changed

14 files changed

+311
-248
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 40 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1859,11 +1859,11 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
18591859
]> {
18601860
let summary = "memref subview operation";
18611861
let description = [{
1862-
The "subview" operation converts a memref type to another memref type
1863-
which represents a reduced-size view of the original memref as specified by
1864-
the operation's offsets, sizes and strides arguments.
1862+
The `subview` operation converts a memref type to a memref type which
1863+
represents a reduced-size view of the original memref as specified by the
1864+
operation's offsets, sizes and strides arguments.
18651865

1866-
The SubView operation supports the following arguments:
1866+
The `subview` operation supports the following arguments:
18671867

18681868
* source: the "base" memref on which to create a "view" memref.
18691869
* offsets: memref-rank number of offsets into the "base" memref at which to
@@ -1876,118 +1876,65 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
18761876
The representation based on offsets, sizes and strides supports a
18771877
partially-static specification via attributes specified through the
18781878
`static_offsets`, `static_sizes` and `static_strides` arguments. A special
1879-
sentinel value ShapedType::kDynamic encodes that the corresponding entry has
1880-
a dynamic value.
1879+
sentinel value `ShapedType::kDynamic` encodes that the corresponding entry
1880+
has a dynamic value.
18811881

1882-
A subview operation may additionally reduce the rank of the resulting view
1883-
by removing dimensions that are statically known to be of size 1.
1882+
A `subview` operation may additionally reduce the rank of the resulting
1883+
view by removing dimensions that are statically known to be of size 1.
1884+
1885+
The offsets, sizes and strides must be in-bounds with respect to the source
1886+
memref. When possible, the static operation verifier will detect
1887+
out-of-bounds subviews. Subviews that cannot be confirmed to be in-bounds
1888+
or out-of-bounds based on compile-time information are valid. However,
1889+
performing an out-of-bounds subview at runtime is undefined behavior.
18841890

18851891
Example 1:
18861892

18871893
```mlir
1888-
%0 = memref.alloc() : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>
1889-
1890-
// Create a sub-view of "base" memref '%0' with offset arguments '%c0',
1891-
// dynamic sizes for each dimension, and stride arguments '%c1'.
1892-
%1 = memref.subview %0[%c0, %c0][%size0, %size1][%c1, %c1]
1893-
: memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to
1894-
memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)>>
1894+
// Subview of static memref with identity layout at dynamic offsets, sizes
1895+
// and strides.
1896+
%1 = memref.subview %0[%off0, %off1][%sz0, %sz1][%str0, %str1]
1897+
: memref<64x4xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
18951898
```
18961899

18971900
Example 2:
18981901

18991902
```mlir
1900-
%0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
1901-
1902-
// Create a sub-view of "base" memref '%0' with dynamic offsets, sizes,
1903+
// Subview of static memref with strided layout at static offsets, sizes
19031904
// and strides.
1904-
// Note that dynamic offsets are represented by the linearized dynamic
1905-
// offset symbol 's0' in the subview memref layout map, and that the
1906-
// dynamic strides operands, after being applied to the base memref
1907-
// strides in each dimension, are represented in the view memref layout
1908-
// map as symbols 's1', 's2' and 's3'.
1909-
%1 = memref.subview %0[%i, %j, %k][%size0, %size1, %size2][%x, %y, %z]
1910-
: memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
1911-
memref<?x?x?xf32,
1912-
affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
1905+
%1 = memref.subview %0[4, 2][8, 2][3, 2]
1906+
: memref<64x4xf32, strided<[7, 9], offset: 91>> to
1907+
memref<8x2xf32, strided<[21, 18], offset: 137>>
19131908
```
19141909

1915-
Example 3:
1910+
Example 3:
19161911

19171912
```mlir
1918-
%0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
1919-
1920-
// Subview with constant offsets, sizes and strides.
1921-
%1 = memref.subview %0[0, 2, 0][4, 4, 4][1, 1, 1]
1922-
: memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
1923-
memref<4x4x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>>
1913+
// Subview of dynamic memref with strided layout at dynamic offsets and
1914+
// strides, but static sizes.
1915+
%1 = memref.subview %0[%off0, %off1][4, 4][%str0, %str1]
1916+
: memref<?x?xf32, strided<[?, ?], offset: ?>> to
1917+
memref<4x4xf32, strided<[?, ?], offset: ?>>
19241918
```
19251919

1926-
Example 4:
1920+
Example 4:
19271921

19281922
```mlir
1929-
%0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
1930-
1931-
// Subview with constant size, but dynamic offsets and
1932-
// strides. The resulting memref has a static shape, but if the
1933-
// base memref has an affine map to describe the layout, the result
1934-
// memref also uses an affine map to describe the layout. The
1935-
// strides of the result memref is computed as follows:
1936-
//
1937-
// Let #map1 represents the layout of the base memref, and #map2
1938-
// represents the layout of the result memref. A #mapsubview can be
1939-
// constructed to map an index from the result memref to the base
1940-
// memref (note that the description below uses more convenient
1941-
// naming for symbols, while in affine maps, symbols are
1942-
// represented as unsigned numbers that identify that symbol in the
1943-
// given affine map.
1944-
//
1945-
// #mapsubview = (d0, d1)[o0, o1, t0, t1] -> (d0 * t0 + o0, d1 * t1 + o1)
1946-
//
1947-
// where, o0, o1, ... are offsets, and t0, t1, ... are strides. Then,
1948-
//
1949-
// #map2 = #map1.compose(#mapsubview)
1950-
//
1951-
// If the layout map is represented as
1952-
//
1953-
// #map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)
1954-
//
1955-
// then,
1956-
//
1957-
// #map2 = (d0, d1)[s0, s1, s2, o0, o1, t0, t1] ->
1958-
// (d0 * s1 * t0 + d1 * s2 * t1 + o0 * s1 + o1 * s2 + s0)
1959-
//
1960-
// Representing this canonically
1961-
//
1962-
// #map2 = (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)
1963-
//
1964-
// where, r0 = o0 * s1 + o1 * s2 + s0, r1 = s1 * t0, r2 = s2 * t1.
1965-
%1 = memref.subview %0[%i, %j][4, 4][%x, %y] :
1966-
: memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>> to
1967-
memref<4x4xf32, affine_map<(d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)>>
1968-
1969-
// Note that the subview op does not guarantee that the result
1970-
// memref is "inbounds" w.r.t to base memref. It is upto the client
1971-
// to ensure that the subview is accessed in a manner that is
1972-
// in-bounds.
1923+
// Rank-reducing subviews.
1924+
%1 = memref.subview %0[0, 0, 0][1, 16, 4][1, 1, 1]
1925+
: memref<8x16x4xf32> to memref<16x4xf32>
1926+
%3 = memref.subview %2[3, 4, 2][1, 6, 3][1, 1, 1]
1927+
: memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
19731928
```
19741929

1975-
Example 5:
1976-
1930+
Example 5:
1931+
19771932
```mlir
1978-
// Rank-reducing subview.
1979-
%1 = memref.subview %0[0, 0, 0][1, 16, 4][1, 1, 1] :
1980-
memref<8x16x4xf32> to memref<16x4xf32>
1981-
1982-
// Original layout:
1983-
// (d0, d1, d2) -> (64 * d0 + 16 * d1 + d2)
1984-
// Subviewed layout:
1985-
// (d0, d1, d2) -> (64 * (d0 + 3) + 4 * (d1 + 4) + d2 + 2) = (64 * d0 + 4 * d1 + d2 + 210)
1986-
// After rank reducing:
1987-
// (d0, d1) -> (4 * d0 + d1 + 210)
1988-
%3 = memref.subview %2[3, 4, 2][1, 6, 3][1, 1, 1] :
1989-
memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
1933+
// Identity subview. The subview is the full source memref.
1934+
%1 = memref.subview %0[0, 0, 0] [8, 16, 4] [1, 1, 1]
1935+
: memref<8x16x4xf32> to memref<8x16x4xf32>
19901936
```
1937+
19911938
}];
19921939

19931940
let arguments = (ins AnyMemRef:$source,

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,28 @@ unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
4545

4646
namespace mlir {
4747

48+
/// Result of a slice bounds verification.
49+
struct SliceBoundsVerificationResult {
50+
/// If set to "true", the slice bounds verification was successful.
51+
bool isValid;
52+
/// An error message that can be printed during op verification.
53+
std::string errorMessage;
54+
};
55+
56+
/// Verify that the offsets/sizes/strides-style access into the given shape
57+
/// is in-bounds. Only static values are verified. If `generateErrorMessage`
58+
/// is set to "true", an error message is produced that can be printed by the
59+
/// op verifier.
60+
SliceBoundsVerificationResult
61+
verifyInBoundsSlice(ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
62+
ArrayRef<int64_t> staticSizes,
63+
ArrayRef<int64_t> staticStrides,
64+
bool generateErrorMessage = false);
65+
SliceBoundsVerificationResult verifyInBoundsSlice(
66+
ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
67+
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
68+
bool generateErrorMessage = false);
69+
4870
/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
4971
/// constant arguments. This pattern assumes that the op has a suitable builder
5072
/// that takes a result type, a "source" operand and mixed offsets, sizes and
@@ -72,11 +94,20 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
7294
failed(foldDynamicIndexList(mixedStrides)))
7395
return failure();
7496

75-
// Create the new op in canonical form.
97+
// Pattern does not apply if the produced op would not verify.
98+
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
99+
cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
100+
mixedSizes, mixedStrides);
101+
if (!sliceResult.isValid)
102+
return failure();
103+
104+
// Compute the new result type.
76105
auto resultType =
77106
ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
78107
if (!resultType)
79108
return failure();
109+
110+
// Create the new op in canonical form.
80111
auto newOp =
81112
rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
82113
mixedOffsets, mixedSizes, mixedStrides);

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2977,6 +2977,9 @@ static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
29772977
LogicalResult SubViewOp::verify() {
29782978
MemRefType baseType = getSourceType();
29792979
MemRefType subViewType = getType();
2980+
ArrayRef<int64_t> staticOffsets = getStaticOffsets();
2981+
ArrayRef<int64_t> staticSizes = getStaticSizes();
2982+
ArrayRef<int64_t> staticStrides = getStaticStrides();
29802983

29812984
// The base memref and the view memref should be in the same memory space.
29822985
if (baseType.getMemorySpace() != subViewType.getMemorySpace())
@@ -2991,7 +2994,7 @@ LogicalResult SubViewOp::verify() {
29912994
// Compute the expected result type, assuming that there are no rank
29922995
// reductions.
29932996
MemRefType expectedType = SubViewOp::inferResultType(
2994-
baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
2997+
baseType, staticOffsets, staticSizes, staticStrides);
29952998

29962999
// Verify all properties of a shaped type: rank, element type and dimension
29973000
// sizes. This takes into account potential rank reductions.
@@ -3025,6 +3028,14 @@ LogicalResult SubViewOp::verify() {
30253028
return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
30263029
*this, expectedType);
30273030

3031+
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
3032+
// to the base memref.
3033+
SliceBoundsVerificationResult boundsResult =
3034+
verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
3035+
staticStrides, /*generateErrorMessage=*/true);
3036+
if (!boundsResult.isValid)
3037+
return getOperation()->emitError(boundsResult.errorMessage);
3038+
30283039
return success();
30293040
}
30303041

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Interfaces/InferIntRangeInterface.h"
2828
#include "mlir/Interfaces/LoopLikeInterface.h"
2929
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
30+
#include "mlir/Interfaces/ViewLikeInterface.h"
3031
#include "mlir/Support/LLVM.h"
3132
#include "llvm/ADT/DenseSet.h"
3233
#include "llvm/ADT/STLExtras.h"
@@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
23522353
}
23532354
}
23542355

2355-
/// Verify that the offsets/sizes/strides-style access into the given tensor
2356-
/// is in-bounds. Only static information is verified.
2357-
static LogicalResult verifyInBoundsSlice(Operation *op,
2358-
RankedTensorType tensorType,
2359-
ArrayRef<int64_t> staticOffsets,
2360-
ArrayRef<int64_t> staticSizes,
2361-
ArrayRef<int64_t> staticStrides) {
2362-
for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
2363-
// Nothing to verify for dynamic source dims.
2364-
if (tensorType.isDynamicDim(i))
2365-
continue;
2366-
// Nothing to verify if the offset is dynamic.
2367-
if (ShapedType::isDynamic(staticOffsets[i]))
2368-
continue;
2369-
if (staticOffsets[i] >= tensorType.getDimSize(i))
2370-
return op->emitOpError("offset ")
2371-
<< i << " is out-of-bounds: " << staticOffsets[i]
2372-
<< " >= " << tensorType.getDimSize(i);
2373-
if (ShapedType::isDynamic(staticSizes[i]) ||
2374-
ShapedType::isDynamic(staticStrides[i]))
2375-
continue;
2376-
int64_t lastPos =
2377-
staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
2378-
if (lastPos >= tensorType.getDimSize(i))
2379-
return op->emitOpError("slice along dimension ")
2380-
<< i << " runs out-of-bounds: " << lastPos
2381-
<< " >= " << tensorType.getDimSize(i);
2382-
}
2383-
return success();
2384-
}
2385-
23862356
/// Verifier for ExtractSliceOp.
23872357
LogicalResult ExtractSliceOp::verify() {
23882358
RankedTensorType sourceType = getSourceType();
@@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() {
23962366

23972367
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
23982368
// to the source tensor.
2399-
return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
2400-
getStaticSizes(), getStaticStrides());
2369+
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2370+
sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2371+
getStaticStrides(), /*generateErrorMessage=*/true);
2372+
if (!boundsResult.isValid)
2373+
return getOperation()->emitError(boundsResult.errorMessage);
2374+
2375+
return success();
24012376
}
24022377

24032378
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
@@ -2777,9 +2752,14 @@ LogicalResult InsertSliceOp::verify() {
27772752
return produceSliceErrorMsg(result, *this, expectedType);
27782753

27792754
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
2780-
// to the source tensor.
2781-
return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
2782-
getStaticSizes(), getStaticStrides());
2755+
// to the destination tensor.
2756+
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2757+
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2758+
getStaticStrides(), /*generateErrorMessage=*/true);
2759+
if (!boundsResult.isValid)
2760+
return getOperation()->emitError(boundsResult.errorMessage);
2761+
2762+
return success();
27832763
}
27842764

27852765
/// If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -2874,6 +2854,13 @@ class InsertSliceOpConstantArgumentFolder final
28742854
failed(foldDynamicStrideList(mixedStrides)))
28752855
return failure();
28762856

2857+
// Pattern does not apply if the produced op would not verify.
2858+
SliceBoundsVerificationResult sliceResult =
2859+
verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2860+
mixedOffsets, mixedSizes, mixedStrides);
2861+
if (!sliceResult.isValid)
2862+
return failure();
2863+
28772864
// Create the new op in canonical form.
28782865
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
28792866
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
@@ -3802,9 +3789,14 @@ LogicalResult ParallelInsertSliceOp::verify() {
38023789
return produceSliceErrorMsg(result, *this, expectedType);
38033790

38043791
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
3805-
// to the source tensor.
3806-
return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
3807-
getStaticSizes(), getStaticStrides());
3792+
// to the destination tensor.
3793+
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3794+
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3795+
getStaticStrides(), /*generateErrorMessage=*/true);
3796+
if (!boundsResult.isValid)
3797+
return getOperation()->emitError(boundsResult.errorMessage);
3798+
3799+
return success();
38083800
}
38093801

38103802
void ParallelInsertSliceOp::getCanonicalizationPatterns(

0 commit comments

Comments
 (0)