[mlir][sve] Add an e2e for linalg.matmul with mixed types #73773


Merged

merged 1 commit into llvm:main from banach-space:andrzej/add_mixed_type_test on Nov 29, 2023

Conversation

banach-space
Contributor

Apart from the test itself, this patch also updates a few patterns to
fix how new VectorType(s) are created. Namely, it makes sure that
"scalability" is correctly propagated.

Regression tests will be updated separately while auditing Vector
dialect tests in the context of scalable vectors:
  * https://github.com/orgs/llvm/projects/23

@llvmbot
Member

llvmbot commented Nov 29, 2023

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-sve

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Apart from the test itself, this patch also updates a few patterns to
fix how new VectorType(s) are created. Namely, it makes sure that
"scalability" is correctly propagated.

Regression tests will be updated separately while auditing Vector
dialect tests in the context of scalable vectors:
  * https://github.com/orgs/llvm/projects/23


Full diff: https://github.com/llvm/llvm-project/pull/73773.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+8-8)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir (+42-28)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 582d627d1ce4ac0..96ec44fcd77677a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast
 
     Type castResTy = getElementTypeOrSelf(op->getResult(0));
     if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
-      castResTy = VectorType::get(vecTy.getShape(), castResTy);
+      castResTy = vecTy.clone(castResTy);
     auto *castOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                         bcastOp.getSource(), castResTy, op->getAttrs());
@@ -527,16 +527,14 @@ struct ReorderElementwiseOpsOnTranspose final
         srcValues.push_back(transposeOp.getVector());
       } else {
         // This is a constant. Create a reverse transpose op for it.
-        auto vectorType = VectorType::get(
-            srcType.getShape(),
-            cast<VectorType>(operand.getType()).getElementType());
+        auto vectorType =
+            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
         srcValues.push_back(rewriter.create<vector::TransposeOp>(
             operand.getLoc(), vectorType, operand, invOrder));
       }
     }
 
-    auto vectorType = VectorType::get(
-        srcType.getShape(),
+    auto vectorType = srcType.clone(
         cast<VectorType>(op->getResultTypes()[0]).getElementType());
     Operation *elementwiseOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
@@ -1314,8 +1312,10 @@ struct CanonicalizeContractMatmulToMMT final
         Value trans =
             rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
         VectorType newType =
-            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
-                            cast<VectorType>(mat.getType()).getElementType());
+            cast<VectorType>(trans.getType())
+                .clone(cast<VectorType>(mat.getType()).getElementType());
+        // VectorType::get(cast<VectorType>(trans.getType()).getShape(),
+        //                 cast<VectorType>(mat.getType()).getElementType());
         return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
       }
       if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
index 2024da2a585d99f..d771d32d548bbe2 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
@@ -1,8 +1,14 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule \
-// RUN:   -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
-// RUN:   -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm | \
-// RUN: %mcr_aarch64_cmd -e=matmul_f32 -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
-// RUN: FileCheck %s
+// DEFINE: %{compile} =  mlir-opt %s \
+// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:    -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE:    -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = matmul_f32
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
 
 func.func @matmul_f32() {
   // Matrix dimensions
@@ -40,29 +46,37 @@ func.func @matmul_f32() {
   return
 }
 
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  // Step 1: Tile
-  %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
-  %func_op = get_parent_op %matmul : (!transform.any_op) -> !transform.op<"func.func">
-  %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
-  // Step 2: Vectorize
-  %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops : (!transform.any_op) -> !transform.any_op
-  transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
-
-  // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
-  transform.apply_patterns to %func_op {
-    transform.apply_patterns.vector.reduction_to_contract
-    transform.apply_patterns.vector.transfer_permutation_patterns
-    transform.apply_patterns.vector.lower_masked_transfers
-  } : !transform.op<"func.func">
-
-  // Step 4: Lower vector.contract to vector.fma
-  transform.apply_patterns to %func_op {
-    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-    transform.apply_patterns.vector.lower_outerproduct
-  } : !transform.op<"func.func">
+module attributes {transform.with_named_sequence} {
+transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+      : (!transform.any_op) -> !transform.any_op
+
+    // Step 1: Tile
+    %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize
+    %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
+
+    // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
+    %func = transform.structured.match ops{["func.func"]} in %module
+      : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.lower_masked_transfers
+    } : !transform.op<"func.func">
+
+    // Step 4: Lower vector.contract to vector.fma
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_outerproduct
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
 }
 
 func.func private @printMemrefF32(%ptr : tensor<*xf32>)

@llvmbot
Member

llvmbot commented Nov 29, 2023

@llvm/pr-subscribers-mlir

@banach-space
Contributor Author

Depends on #73771 - please only review the top commit 🙏🏻 .

@banach-space banach-space force-pushed the andrzej/add_mixed_type_test branch from 23a0e7d to cd30cfe Compare November 29, 2023 10:02
@banach-space banach-space changed the title [mlir[[sve] Add an e2e for linalg.matmul with mixed types [mlir][sve] Add an e2e for linalg.matmul with mixed types Nov 29, 2023
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast

     Type castResTy = getElementTypeOrSelf(op->getResult(0));
     if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
-      castResTy = VectorType::get(vecTy.getShape(), castResTy);
+      castResTy = vecTy.clone(castResTy);
Contributor

We should think about a long term solution for this beyond these fixes.

Collaborator

@c-rhodes c-rhodes left a comment

LGTM cheers

@banach-space banach-space force-pushed the andrzej/add_mixed_type_test branch from cd30cfe to 4d199fb Compare November 29, 2023 15:43
Apart from the test itself, this patch also updates a few patterns to
fix how new VectorType(s) are created. Namely, it makes sure that
"scalability" is correctly propagated.

Regression tests will be updated separately while auditing Vector
dialect tests in the context of scalable vectors:
  * https://github.com/orgs/llvm/projects/23
@banach-space banach-space force-pushed the andrzej/add_mixed_type_test branch from 4d199fb to b66226c Compare November 29, 2023 15:46
@banach-space banach-space merged commit 4b2ba5a into llvm:main Nov 29, 2023
@banach-space banach-space deleted the andrzej/add_mixed_type_test branch March 16, 2024 16:47
6 participants