Skip to content

Commit 8030aa9

Browse files
committed
shape_cast
1 parent d415b6a commit 8030aa9

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2210,7 +2210,9 @@ def Vector_CompressStoreOp :
22102210
}
22112211

22122212
def Vector_ShapeCastOp :
2213-
Vector_Op<"shape_cast", [Pure]>,
2213+
Vector_Op<"shape_cast", [Pure,
2214+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
2215+
]>,
22142216
Arguments<(ins AnyVectorOfAnyRank:$source)>,
22152217
Results<(outs AnyVectorOfAnyRank:$result)> {
22162218
let summary = "shape_cast casts between vector shapes";

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5302,6 +5302,11 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
53025302
// ShapeCastOp
53035303
//===----------------------------------------------------------------------===//
53045304

5305+
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5306+
SetIntRangeFn setResultRanges) {
5307+
setResultRanges(getResult(), argRanges.front());
5308+
}
5309+
53055310
/// Returns true if each element of 'a' is equal to the product of a contiguous
53065311
/// sequence of the elements of 'b'. Returns false otherwise.
53075312
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {

mlir/test/Dialect/Vector/int-range-interface.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@ func.func @vector_broadcast() -> vector<4x16xindex> {
3636
func.return %2 : vector<4x16xindex>
3737
}
3838

39+
// CHECK-LABEL: func @vector_shape_cast
40+
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
41+
func.func @vector_shape_cast() -> vector<4x4xindex> {
42+
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<16xindex>
43+
%1 = vector.shape_cast %0 : vector<16xindex> to vector<4x4xindex>
44+
%2 = test.reflect_bounds %1 : vector<4x4xindex>
45+
func.return %2 : vector<4x4xindex>
46+
}
47+
3948
// CHECK-LABEL: func @vector_extract
4049
// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
4150
func.func @vector_extract() -> index {

0 commit comments

Comments
 (0)