Skip to content

[mlir][vector] Fix invalid IR in ContractionOpLowering #78130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 16, 2024

Conversation

matthias-springer
Copy link
Member

If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir when running with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.

  * Pattern (anonymous namespace)::ContractionOpToOuterProductOpLowering : 'vector.contract -> ()' {
Trying to match "(anonymous namespace)::ContractionOpToOuterProductOpLowering"
    ** Insert  : 'vector.transpose'(0x5625b3a8cb30)
    ** Insert  : 'vector.transpose'(0x5625b3a8cbc0)
"(anonymous namespace)::ContractionOpToOuterProductOpLowering" result 0
  } -> failure : pattern failed to match
} -> failure : pattern failed to match

LLVM ERROR: pattern returned failure but IR did change

Note: vector-contract-to-outerproduct-transforms-unsupported.mlir is merged into vector-contract-to-outerproduct-matvec-transforms.mlir. The greedy pattern application failed error is not longer produced. This error indicates that the greedy pattern rewrite did not convergence; it does not mean that a pattern could not be applied.

@llvmbot
Copy link
Member

llvmbot commented Jan 15, 2024

@llvm/pr-subscribers-mlir-vector

Author: Matthias Springer (matthias-springer)

Changes

If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir when running with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.

  * Pattern (anonymous namespace)::ContractionOpToOuterProductOpLowering : 'vector.contract -> ()' {
Trying to match "(anonymous namespace)::ContractionOpToOuterProductOpLowering"
    ** Insert  : 'vector.transpose'(0x5625b3a8cb30)
    ** Insert  : 'vector.transpose'(0x5625b3a8cbc0)
"(anonymous namespace)::ContractionOpToOuterProductOpLowering" result 0
  } -> failure : pattern failed to match
} -> failure : pattern failed to match

LLVM ERROR: pattern returned failure but IR did change

Note: vector-contract-to-outerproduct-transforms-unsupported.mlir is merged into vector-contract-to-outerproduct-matvec-transforms.mlir. The greedy pattern application failed error is not longer produced. This error indicates that the greedy pattern rewrite did not convergence; it does not mean that a pattern could not be applied.


Full diff: https://github.com/llvm/llvm-project/pull/78130.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+57-41)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir (+16-4)
  • (removed) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir (-35)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6ff4c26763d247..5310b9689a3505 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -426,16 +426,8 @@ struct UnrolledOuterProductGenerator
   }
 
   FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
-                             VectorType lhsType, int reductionDim,
+                             VectorType lhsType, int reductionSize,
                              std::optional<Value> maybeMask = std::nullopt) {
-    // Unrolling a scalable dimension would be incorrect - bail out.
-    if (lhsType.getScalableDims()[reductionDim])
-      return failure();
-
-    int reductionSize = lhsType.getDimSize(reductionDim);
-    assert(reductionSize > 0 &&
-           "Reduction dim must be a known static size to allow unrolling");
-
     // Incremental support for masking.
     if (mask && !maybeMask.has_value())
       return failure();
@@ -458,6 +450,17 @@ struct UnrolledOuterProductGenerator
     return res;
   }
 
+  std::optional<int64_t> getReductionSize(VectorType vecType,
+                                          int64_t reductionDim) {
+    // Cannot unroll scalable dimension.
+    if (vecType.getScalableDims()[reductionDim])
+      return std::nullopt;
+    int64_t reductionSize = vecType.getDimSize(reductionDim);
+    assert(reductionSize > 0 &&
+           "Reduction dim must be a known static size to allow unrolling");
+    return reductionSize;
+  }
+
   /// Two outer parallel, one inner reduction (matmat flavor).
   FailureOr<Value> matmat() {
     if (!iters({Par(), Par(), Red()}))
@@ -465,42 +468,52 @@ struct UnrolledOuterProductGenerator
     // Set up the parallel/reduction structure in the right form.
     AffineExpr m, n, k;
     bindDims(rewriter.getContext(), m, n, k);
-    Value transposedMask = t(mask, {2, 0, 1});
+
     // Classical row-major matmul:  Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1))
+        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {m, n}})) {
-      Value tlhs = t(lhs);
-      return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1)) {
+        Value tlhs = t(lhs);
+        return outerProd(tlhs, t(rhs), res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
+      }
     }
     // No need to permute anything.
     if (layout({{k, m}, {k, n}, {m, n}}))
-      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(lhs, rhs, res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     // Just permute the rhs.
     if (layout({{k, m}, {n, k}, {m, n}}))
-      return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(lhs, t(rhs), res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     // Transposed output: swap RHS and LHS.
     // Classical row-major matmul: permute the lhs.
     if (layout({{m, k}, {k, n}, {n, m}}))
-      return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1))
+        return outerProd(rhs, t(lhs), res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {n, m}})) {
-      Value trhs = t(rhs);
-      return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1)) {
+        Value trhs = t(rhs);
+        return outerProd(trhs, t(lhs), res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
+      }
     }
     if (layout({{k, m}, {k, n}, {n, m}}))
-      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(rhs, lhs, res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     if (layout({{k, m}, {n, k}, {n, m}}))
-      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     return failure();
   }
 
@@ -514,24 +527,23 @@ struct UnrolledOuterProductGenerator
       return failure();
     AffineExpr m, k;
     bindDims(rewriter.getContext(), m, k);
-    Value transposedMask = t(mask);
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1))
+        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, t(mask));
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(lhs, rhs, res, lhsType, *reductionSize, t(mask));
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, t(mask));
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(rhs, lhs, res, lhsType, *reductionSize, t(mask));
     return failure();
   }
 
@@ -547,16 +559,20 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 1))
+        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
     return failure();
   }
 
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index d86c6158bcdf2f..5c8527f77e3df0 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -320,8 +320,8 @@ func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
                                 %x: vector<2xf32>,
                                 %b: vector<4xf32>,
                                 %mask: vector<4x2xi1>) -> vector<4xf32> {
-  // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[A]]
+  // CHECK-DAG:     vector.transpose %[[MASK]]
+  // CHECK-DAG:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
       vector.contract #matvec_trait_3 %x, %A, %b
@@ -339,8 +339,8 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
                                                       %x: vector<2xf32>,
                                                       %b: vector<[4]xf32>,
                                                       %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
-  // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[A]]
+  // CHECK-DAG:     vector.transpose %[[MASK]]
+  // CHECK-DAG:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
       vector.contract #matvec_trait_3 %x, %A, %b
@@ -641,6 +641,18 @@ func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
   return %res : vector<[4]xf32>
 }
 
+// Unrolling scalable reduction dim is not supported - bail out
+// CHECK-LABEL: @masked_extract_contract2_scalable_reduction_dim(
+// CHECK:         vector.contract {{.*}} : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+                                    %arg1: vector<[3]xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
 // ============================================================================
 //  TD sequence
 // ============================================================================
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
deleted file mode 100644
index 954aa13c3e77b3..00000000000000
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
-
-#matvec_accesses = [
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (i)>
-]
-#matvec_trait = {
-  indexing_maps = #matvec_accesses,
-  iterator_types = ["parallel", "reduction"]
-}
-
-// Unrolling scalable reduction dim is not supported - bail out
-
-// expected-error@below {{greedy pattern application failed}}
-func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
-                                    %arg1: vector<[3]xf32>,
-                                    %arg2: vector<[2]xf32>,
-                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
-          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
-  return %0 : vector<[2]xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %f = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
-
-    transform.apply_patterns to %f {
-      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-    } : !transform.any_op
-    transform.yield
-  }
-}

@llvmbot
Copy link
Member

llvmbot commented Jan 15, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir when running with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.

  * Pattern (anonymous namespace)::ContractionOpToOuterProductOpLowering : 'vector.contract -&gt; ()' {
Trying to match "(anonymous namespace)::ContractionOpToOuterProductOpLowering"
    ** Insert  : 'vector.transpose'(0x5625b3a8cb30)
    ** Insert  : 'vector.transpose'(0x5625b3a8cbc0)
"(anonymous namespace)::ContractionOpToOuterProductOpLowering" result 0
  } -&gt; failure : pattern failed to match
} -&gt; failure : pattern failed to match

LLVM ERROR: pattern returned failure but IR did change

Note: vector-contract-to-outerproduct-transforms-unsupported.mlir is merged into vector-contract-to-outerproduct-matvec-transforms.mlir. The greedy pattern application failed error is not longer produced. This error indicates that the greedy pattern rewrite did not convergence; it does not mean that a pattern could not be applied.


Full diff: https://github.com/llvm/llvm-project/pull/78130.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+57-41)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir (+16-4)
  • (removed) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir (-35)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6ff4c26763d247..5310b9689a3505 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -426,16 +426,8 @@ struct UnrolledOuterProductGenerator
   }
 
   FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
-                             VectorType lhsType, int reductionDim,
+                             VectorType lhsType, int reductionSize,
                              std::optional<Value> maybeMask = std::nullopt) {
-    // Unrolling a scalable dimension would be incorrect - bail out.
-    if (lhsType.getScalableDims()[reductionDim])
-      return failure();
-
-    int reductionSize = lhsType.getDimSize(reductionDim);
-    assert(reductionSize > 0 &&
-           "Reduction dim must be a known static size to allow unrolling");
-
     // Incremental support for masking.
     if (mask && !maybeMask.has_value())
       return failure();
@@ -458,6 +450,17 @@ struct UnrolledOuterProductGenerator
     return res;
   }
 
+  std::optional<int64_t> getReductionSize(VectorType vecType,
+                                          int64_t reductionDim) {
+    // Cannot unroll scalable dimension.
+    if (vecType.getScalableDims()[reductionDim])
+      return std::nullopt;
+    int64_t reductionSize = vecType.getDimSize(reductionDim);
+    assert(reductionSize > 0 &&
+           "Reduction dim must be a known static size to allow unrolling");
+    return reductionSize;
+  }
+
   /// Two outer parallel, one inner reduction (matmat flavor).
   FailureOr<Value> matmat() {
     if (!iters({Par(), Par(), Red()}))
@@ -465,42 +468,52 @@ struct UnrolledOuterProductGenerator
     // Set up the parallel/reduction structure in the right form.
     AffineExpr m, n, k;
     bindDims(rewriter.getContext(), m, n, k);
-    Value transposedMask = t(mask, {2, 0, 1});
+
     // Classical row-major matmul:  Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1))
+        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {m, n}})) {
-      Value tlhs = t(lhs);
-      return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1)) {
+        Value tlhs = t(lhs);
+        return outerProd(tlhs, t(rhs), res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
+      }
     }
     // No need to permute anything.
     if (layout({{k, m}, {k, n}, {m, n}}))
-      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(lhs, rhs, res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     // Just permute the rhs.
     if (layout({{k, m}, {n, k}, {m, n}}))
-      return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(lhs, t(rhs), res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     // Transposed output: swap RHS and LHS.
     // Classical row-major matmul: permute the lhs.
     if (layout({{m, k}, {k, n}, {n, m}}))
-      return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1))
+        return outerProd(rhs, t(lhs), res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {n, m}})) {
-      Value trhs = t(rhs);
-      return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1)) {
+        Value trhs = t(rhs);
+        return outerProd(trhs, t(lhs), res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
+      }
     }
     if (layout({{k, m}, {k, n}, {n, m}}))
-      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(rhs, lhs, res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     if (layout({{k, m}, {n, k}, {n, m}}))
-      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize,
+                         t(mask, {2, 0, 1}));
     return failure();
   }
 
@@ -514,24 +527,23 @@ struct UnrolledOuterProductGenerator
       return failure();
     AffineExpr m, k;
     bindDims(rewriter.getContext(), m, k);
-    Value transposedMask = t(mask);
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1))
+        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, t(mask));
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(lhs, rhs, res, lhsType, *reductionSize, t(mask));
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, t(mask));
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(rhs, lhs, res, lhsType, *reductionSize, t(mask));
     return failure();
   }
 
@@ -547,16 +559,20 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 1))
+        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
     return failure();
   }
 
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index d86c6158bcdf2f..5c8527f77e3df0 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -320,8 +320,8 @@ func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
                                 %x: vector<2xf32>,
                                 %b: vector<4xf32>,
                                 %mask: vector<4x2xi1>) -> vector<4xf32> {
-  // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[A]]
+  // CHECK-DAG:     vector.transpose %[[MASK]]
+  // CHECK-DAG:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
       vector.contract #matvec_trait_3 %x, %A, %b
@@ -339,8 +339,8 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
                                                       %x: vector<2xf32>,
                                                       %b: vector<[4]xf32>,
                                                       %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
-  // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[A]]
+  // CHECK-DAG:     vector.transpose %[[MASK]]
+  // CHECK-DAG:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
       vector.contract #matvec_trait_3 %x, %A, %b
@@ -641,6 +641,18 @@ func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
   return %res : vector<[4]xf32>
 }
 
+// Unrolling scalable reduction dim is not supported - bail out
+// CHECK-LABEL: @masked_extract_contract2_scalable_reduction_dim(
+// CHECK:         vector.contract {{.*}} : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+                                    %arg1: vector<[3]xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
 // ============================================================================
 //  TD sequence
 // ============================================================================
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
deleted file mode 100644
index 954aa13c3e77b3..00000000000000
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
-
-#matvec_accesses = [
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (i)>
-]
-#matvec_trait = {
-  indexing_maps = #matvec_accesses,
-  iterator_types = ["parallel", "reduction"]
-}
-
-// Unrolling scalable reduction dim is not supported - bail out
-
-// expected-error@below {{greedy pattern application failed}}
-func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
-                                    %arg1: vector<[3]xf32>,
-                                    %arg2: vector<[2]xf32>,
-                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
-          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
-  return %0 : vector<[2]xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %f = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
-
-    transform.apply_patterns to %f {
-      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-    } : !transform.any_op
-    transform.yield
-  }
-}

/// Two outer parallel, one inner reduction (matmat flavor).
FailureOr<Value> matmat() {
if (!iters({Par(), Par(), Red()}))
return failure();
// Set up the parallel/reduction structure in the right form.
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
Value transposedMask = t(mask, {2, 0, 1});

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think keeping the variable is quite useful in terms of readability. Can we keep it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can keep it, but it must be nested inside the if (auto reductionSize = getReductionSize(...)) check. (Because the implementation creates IR and the if checks will determine if the pattern succeeds or fails.)

@@ -514,24 +527,23 @@ struct UnrolledOuterProductGenerator
return failure();
AffineExpr m, k;
bindDims(rewriter.getContext(), m, k);
Value transposedMask = t(mask);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above, must be nested.

@hanhanW hanhanW requested a review from dcaballe January 16, 2024 00:52
transposedMask);
if (auto reductionSize = getReductionSize(lhsType, 1)) {
Value tlhs = t(lhs);
return outerProd(tlhs, t(rhs), res, lhsType, *reductionSize,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests have now become non-deterministic across different systems as the order of the commas is not specified.

The line just before was a way to remedy this, you'll need to calculate the mask (or t(rhs)) ahead of time in each case.

we could sprinkle CHECK-DAG in various places like you did but this does not scale, every test writer will need to be aware of this or they will end up writing new tests that work on their system and be very puzzled when build bots come back red.

Copy link
Contributor

@nicolasvasilache nicolasvasilache left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing.

Can you please add a clear comment, before the first if, for whomever is going to want to factor the mask construction out in the future that they should not ?

If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes `Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.

```
  * Pattern (anonymous namespace)::ContractionOpToOuterProductOpLowering : 'vector.contract -> ()' {
Trying to match "(anonymous namespace)::ContractionOpToOuterProductOpLowering"
    ** Insert  : 'vector.transpose'(0x5625b3a8cb30)
    ** Insert  : 'vector.transpose'(0x5625b3a8cbc0)
"(anonymous namespace)::ContractionOpToOuterProductOpLowering" result 0
  } -> failure : pattern failed to match
} -> failure : pattern failed to match

LLVM ERROR: pattern returned failure but IR did change
```

Note: `vector-contract-to-outerproduct-transforms-unsupported.mlir` is merged into `vector-contract-to-outerproduct-matvec-transforms.mlir`. The `greedy pattern application failed` error is not longer produced. This error indicates that the greedy pattern rewrite did not convergence; it does not mean that a pattern could not be applied.
@matthias-springer matthias-springer merged commit c0a354d into llvm:main Jan 16, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
If a rewrite pattern returns "failure", it must not have modified the
IR. This commit fixes
`Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir`
when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.

```
  * Pattern (anonymous namespace)::ContractionOpToOuterProductOpLowering : 'vector.contract -> ()' {
Trying to match "(anonymous namespace)::ContractionOpToOuterProductOpLowering"
    ** Insert  : 'vector.transpose'(0x5625b3a8cb30)
    ** Insert  : 'vector.transpose'(0x5625b3a8cbc0)
"(anonymous namespace)::ContractionOpToOuterProductOpLowering" result 0
  } -> failure : pattern failed to match
} -> failure : pattern failed to match

LLVM ERROR: pattern returned failure but IR did change
```

Note: `vector-contract-to-outerproduct-transforms-unsupported.mlir` is
merged into `vector-contract-to-outerproduct-matvec-transforms.mlir`.
The `greedy pattern application failed` error is not longer produced.
This error indicates that the greedy pattern rewrite did not
convergence; it does not mean that a pattern could not be applied.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants