Skip to content

Commit bd5e0b1

Browse files
committed
[vector][mlir] Add required comments and test in vectorUnroll.cpp and invalid.mlir respectively
1 parent 5fd68e5 commit bd5e0b1

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,9 @@ struct UnrollElementwisePattern : public RewritePattern {
437437
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
438438
SmallVector<int64_t> originalSize =
439439
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
440+
// Bail-out if rank(source) != rank(target). The main limitation here is the
441+
// fact that `ExtractStridedSlice` requires the rank for the input and
442+
// output to match. If needed, we can relax this later.
440443
if (originalSize.size() != targetShape->size())
441444
return failure();
442445
Location loc = op->getLoc();

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,16 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
766766
%1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
767767
}
768768

769+
// -----
770+
771+
func.func @extract_strided_slice() -> () {
772+
// expected-error@+1 {{expected input vector rank to match target shape rank}}
773+
%0 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
774+
%1 = vector.extract_strided_slice %0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
775+
vector<24x2x2xf32> to vector<2x2xf32>
776+
return
777+
}
778+
769779
// -----
770780

771781
#contraction_accesses = [

0 commit comments

Comments
 (0)