-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Fix invalid IR in ContractionOpLowering
#78130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Fix invalid IR in ContractionOpLowering
#78130
Conversation
@llvm/pr-subscribers-mlir-vector Author: Matthias Springer (matthias-springer) Changes: If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes
Note: Full diff: https://github.com/llvm/llvm-project/pull/78130.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6ff4c26763d247..5310b9689a3505 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -426,16 +426,8 @@ struct UnrolledOuterProductGenerator
}
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
- VectorType lhsType, int reductionDim,
+ VectorType lhsType, int reductionSize,
std::optional<Value> maybeMask = std::nullopt) {
- // Unrolling a scalable dimension would be incorrect - bail out.
- if (lhsType.getScalableDims()[reductionDim])
- return failure();
-
- int reductionSize = lhsType.getDimSize(reductionDim);
- assert(reductionSize > 0 &&
- "Reduction dim must be a known static size to allow unrolling");
-
// Incremental support for masking.
if (mask && !maybeMask.has_value())
return failure();
@@ -458,6 +450,17 @@ struct UnrolledOuterProductGenerator
return res;
}
+ std::optional<int64_t> getReductionSize(VectorType vecType,
+ int64_t reductionDim) {
+ // Cannot unroll scalable dimension.
+ if (vecType.getScalableDims()[reductionDim])
+ return std::nullopt;
+ int64_t reductionSize = vecType.getDimSize(reductionDim);
+ assert(reductionSize > 0 &&
+ "Reduction dim must be a known static size to allow unrolling");
+ return reductionSize;
+ }
+
/// Two outer parallel, one inner reduction (matmat flavor).
FailureOr<Value> matmat() {
if (!iters({Par(), Par(), Red()}))
@@ -465,42 +468,52 @@ struct UnrolledOuterProductGenerator
// Set up the parallel/reduction structure in the right form.
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
- Value transposedMask = t(mask, {2, 0, 1});
+
// Classical row-major matmul: Just permute the lhs.
if (layout({{m, k}, {k, n}, {m, n}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(t(lhs), rhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {m, n}})) {
- Value tlhs = t(lhs);
- return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1)) {
+ Value tlhs = t(lhs);
+ return outerProd(tlhs, t(rhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
+ }
}
// No need to permute anything.
if (layout({{k, m}, {k, n}, {m, n}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// Just permute the rhs.
if (layout({{k, m}, {n, k}, {m, n}}))
- return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, t(rhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs.
if (layout({{m, k}, {k, n}, {n, m}}))
- return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(rhs, t(lhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {n, m}})) {
- Value trhs = t(rhs);
- return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1)) {
+ Value trhs = t(rhs);
+ return outerProd(trhs, t(lhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
+ }
}
if (layout({{k, m}, {k, n}, {n, m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
if (layout({{k, m}, {n, k}, {n, m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(t(rhs), lhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
return failure();
}
@@ -514,24 +527,23 @@ struct UnrolledOuterProductGenerator
return failure();
AffineExpr m, k;
bindDims(rewriter.getContext(), m, k);
- Value transposedMask = t(mask);
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, t(mask));
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize, t(mask));
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, t(mask));
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize, t(mask));
return failure();
}
@@ -547,16 +559,20 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
return failure();
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index d86c6158bcdf2f..5c8527f77e3df0 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -320,8 +320,8 @@ func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
%x: vector<2xf32>,
%b: vector<4xf32>,
%mask: vector<4x2xi1>) -> vector<4xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[A]]
+ // CHECK-DAG: vector.transpose %[[MASK]]
+ // CHECK-DAG: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
vector.contract #matvec_trait_3 %x, %A, %b
@@ -339,8 +339,8 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
%x: vector<2xf32>,
%b: vector<[4]xf32>,
%mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[A]]
+ // CHECK-DAG: vector.transpose %[[MASK]]
+ // CHECK-DAG: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
%res = vector.mask %mask {
vector.contract #matvec_trait_3 %x, %A, %b
@@ -641,6 +641,18 @@ func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
return %res : vector<[4]xf32>
}
+// Unrolling scalable reduction dim is not supported - bail out
+// CHECK-LABEL: @masked_extract_contract2_scalable_reduction_dim(
+// CHECK: vector.contract {{.*}} : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+ %arg1: vector<[3]xf32>,
+ %arg2: vector<[2]xf32>,
+ %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+
// ============================================================================
// TD sequence
// ============================================================================
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
deleted file mode 100644
index 954aa13c3e77b3..00000000000000
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
-
-#matvec_accesses = [
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i)>
-]
-#matvec_trait = {
- indexing_maps = #matvec_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-
-// Unrolling scalable reduction dim is not supported - bail out
-
-// expected-error@below {{greedy pattern application failed}}
-func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
- %arg1: vector<[3]xf32>,
- %arg2: vector<[2]xf32>,
- %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
- : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
- return %0 : vector<[2]xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
- %f = transform.structured.match ops{["func.func"]} in %module_op
- : (!transform.any_op) -> !transform.any_op
-
- transform.apply_patterns to %f {
- transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- } : !transform.any_op
- transform.yield
- }
-}
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) Changes: If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes
Note: Full diff: https://github.com/llvm/llvm-project/pull/78130.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6ff4c26763d247..5310b9689a3505 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -426,16 +426,8 @@ struct UnrolledOuterProductGenerator
}
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
- VectorType lhsType, int reductionDim,
+ VectorType lhsType, int reductionSize,
std::optional<Value> maybeMask = std::nullopt) {
- // Unrolling a scalable dimension would be incorrect - bail out.
- if (lhsType.getScalableDims()[reductionDim])
- return failure();
-
- int reductionSize = lhsType.getDimSize(reductionDim);
- assert(reductionSize > 0 &&
- "Reduction dim must be a known static size to allow unrolling");
-
// Incremental support for masking.
if (mask && !maybeMask.has_value())
return failure();
@@ -458,6 +450,17 @@ struct UnrolledOuterProductGenerator
return res;
}
+ std::optional<int64_t> getReductionSize(VectorType vecType,
+ int64_t reductionDim) {
+ // Cannot unroll scalable dimension.
+ if (vecType.getScalableDims()[reductionDim])
+ return std::nullopt;
+ int64_t reductionSize = vecType.getDimSize(reductionDim);
+ assert(reductionSize > 0 &&
+ "Reduction dim must be a known static size to allow unrolling");
+ return reductionSize;
+ }
+
/// Two outer parallel, one inner reduction (matmat flavor).
FailureOr<Value> matmat() {
if (!iters({Par(), Par(), Red()}))
@@ -465,42 +468,52 @@ struct UnrolledOuterProductGenerator
// Set up the parallel/reduction structure in the right form.
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
- Value transposedMask = t(mask, {2, 0, 1});
+
// Classical row-major matmul: Just permute the lhs.
if (layout({{m, k}, {k, n}, {m, n}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(t(lhs), rhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {m, n}})) {
- Value tlhs = t(lhs);
- return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1)) {
+ Value tlhs = t(lhs);
+ return outerProd(tlhs, t(rhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
+ }
}
// No need to permute anything.
if (layout({{k, m}, {k, n}, {m, n}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// Just permute the rhs.
if (layout({{k, m}, {n, k}, {m, n}}))
- return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, t(rhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs.
if (layout({{m, k}, {k, n}, {n, m}}))
- return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(rhs, t(lhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {n, m}})) {
- Value trhs = t(rhs);
- return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1)) {
+ Value trhs = t(rhs);
+ return outerProd(trhs, t(lhs), res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
+ }
}
if (layout({{k, m}, {k, n}, {n, m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
if (layout({{k, m}, {n, k}, {n, m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(t(rhs), lhs, res, lhsType, *reductionSize,
+ t(mask, {2, 0, 1}));
return failure();
}
@@ -514,24 +527,23 @@ struct UnrolledOuterProductGenerator
return failure();
AffineExpr m, k;
bindDims(rewriter.getContext(), m, k);
- Value transposedMask = t(mask);
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, t(mask));
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize, t(mask));
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, t(mask));
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
- transposedMask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize, t(mask));
return failure();
}
@@ -547,16 +559,20 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 1))
+ return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
+ if (auto reductionSize = getReductionSize(lhsType, 0))
+ return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
return failure();
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index d86c6158bcdf2f..5c8527f77e3df0 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -320,8 +320,8 @@ func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
%x: vector<2xf32>,
%b: vector<4xf32>,
%mask: vector<4x2xi1>) -> vector<4xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[A]]
+ // CHECK-DAG: vector.transpose %[[MASK]]
+ // CHECK-DAG: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
vector.contract #matvec_trait_3 %x, %A, %b
@@ -339,8 +339,8 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
%x: vector<2xf32>,
%b: vector<[4]xf32>,
%mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[A]]
+ // CHECK-DAG: vector.transpose %[[MASK]]
+ // CHECK-DAG: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
%res = vector.mask %mask {
vector.contract #matvec_trait_3 %x, %A, %b
@@ -641,6 +641,18 @@ func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
return %res : vector<[4]xf32>
}
+// Unrolling scalable reduction dim is not supported - bail out
+// CHECK-LABEL: @masked_extract_contract2_scalable_reduction_dim(
+// CHECK: vector.contract {{.*}} : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+ %arg1: vector<[3]xf32>,
+ %arg2: vector<[2]xf32>,
+ %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+
// ============================================================================
// TD sequence
// ============================================================================
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
deleted file mode 100644
index 954aa13c3e77b3..00000000000000
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
-
-#matvec_accesses = [
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i)>
-]
-#matvec_trait = {
- indexing_maps = #matvec_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-
-// Unrolling scalable reduction dim is not supported - bail out
-
-// expected-error@below {{greedy pattern application failed}}
-func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
- %arg1: vector<[3]xf32>,
- %arg2: vector<[2]xf32>,
- %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
- : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
- return %0 : vector<[2]xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
- %f = transform.structured.match ops{["func.func"]} in %module_op
- : (!transform.any_op) -> !transform.any_op
-
- transform.apply_patterns to %f {
- transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- } : !transform.any_op
- transform.yield
- }
-}
|
/// Two outer parallel, one inner reduction (matmat flavor). | ||
FailureOr<Value> matmat() { | ||
if (!iters({Par(), Par(), Red()})) | ||
return failure(); | ||
// Set up the parallel/reduction structure in the right form. | ||
AffineExpr m, n, k; | ||
bindDims(rewriter.getContext(), m, n, k); | ||
Value transposedMask = t(mask, {2, 0, 1}); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think keeping the variable is quite useful in terms of readability. Can we keep it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can keep it, but it must be nested inside the if (auto reductionSize = getReductionSize(...))
check. (Because the implementation creates IR and the if
checks will determine if the pattern succeeds or fails.)
@@ -514,24 +527,23 @@ struct UnrolledOuterProductGenerator | |||
return failure(); | |||
AffineExpr m, k; | |||
bindDims(rewriter.getContext(), m, k); | |||
Value transposedMask = t(mask); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See above, must be nested.
transposedMask); | ||
if (auto reductionSize = getReductionSize(lhsType, 1)) { | ||
Value tlhs = t(lhs); | ||
return outerProd(tlhs, t(rhs), res, lhsType, *reductionSize, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These tests have now become non-deterministic across different systems as the order of the transposes is not specified.
The line just before was a way to remedy this, you'll need to calculate the mask (or t(rhs)) ahead of time in each case.
we could sprinkle CHECK-DAG in various places like you did but this does not scale, every test writer will need to be aware of this or they will end up writing new tests that work on their system and be very puzzled when build bots come back red.
d1d3e40
to
f24010d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing.
Can you please add a clear comment, before the first if
, for whomever is going to want to factor the mask construction out in the future that they should not ?
mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
Outdated
Show resolved
Hide resolved
If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes `Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`. ``` * Pattern (anonymous namespace)::ContractionOpToOuterProductOpLowering : 'vector.contract -> ()' { Trying to match "(anonymous namespace)::ContractionOpToOuterProductOpLowering" ** Insert : 'vector.transpose'(0x5625b3a8cb30) ** Insert : 'vector.transpose'(0x5625b3a8cbc0) "(anonymous namespace)::ContractionOpToOuterProductOpLowering" result 0 } -> failure : pattern failed to match } -> failure : pattern failed to match LLVM ERROR: pattern returned failure but IR did change ``` Note: `vector-contract-to-outerproduct-transforms-unsupported.mlir` is merged into `vector-contract-to-outerproduct-matvec-transforms.mlir`. The `greedy pattern application failed` error is no longer produced. This error indicates that the greedy pattern rewrite did not converge; it does not mean that a pattern could not be applied.
f24010d
to
503ec72
Compare
If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes `Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`. ``` * Pattern (anonymous namespace)::ContractionOpToOuterProductOpLowering : 'vector.contract -> ()' { Trying to match "(anonymous namespace)::ContractionOpToOuterProductOpLowering" ** Insert : 'vector.transpose'(0x5625b3a8cb30) ** Insert : 'vector.transpose'(0x5625b3a8cbc0) "(anonymous namespace)::ContractionOpToOuterProductOpLowering" result 0 } -> failure : pattern failed to match } -> failure : pattern failed to match LLVM ERROR: pattern returned failure but IR did change ``` Note: `vector-contract-to-outerproduct-transforms-unsupported.mlir` is merged into `vector-contract-to-outerproduct-matvec-transforms.mlir`. The `greedy pattern application failed` error is no longer produced. This error indicates that the greedy pattern rewrite did not converge; it does not mean that a pattern could not be applied.
If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes
`Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir`
when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.

Note: `vector-contract-to-outerproduct-transforms-unsupported.mlir`
is merged into `vector-contract-to-outerproduct-matvec-transforms.mlir`.
The `greedy pattern application failed`
error is no longer produced. This error indicates that the greedy pattern rewrite did not converge; it does not mean that a pattern could not be applied.