Skip to content

[mlir][vector] Add a check to ensure input vector rank equals target shape rank #127706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 26, 2025
Merged
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,12 @@ struct UnrollElementwisePattern : public RewritePattern {
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
// Bail-out if rank(source) != rank(target). The main limitation here is the
// fact that `ExtractStridedSlice` requires the rank for the input and
// output to match. If needed, we can relax this later.
if (originalSize.size() != targetShape->size())
return rewriter.notifyMatchFailure(
op, "expected input vector rank to match target shape rank");
Location loc = op->getLoc();
// Prepare the result vector.
Value result = rewriter.create<arith::ConstantOp>(
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Vector/vector-unroll-options.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>

// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
%0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
return %0 : vector<3x2x2xf32>
}
// CHECK-LABEL: func @negative_vector_fma_3d
// CHECK-NOT: vector.extract_strided_slice
// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
// CHECK: return

func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
return %0 : vector<4xf32>
Expand Down