-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][sve] Add an e2e for linalg.matmul with mixed types #73773
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][sve] Add an e2e for linalg.matmul with mixed types #73773
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) ChangesApart from the test itself, this patch also updates a few patterns to Regression tests will be updated separately while auditing Vector Full diff: https://github.com/llvm/llvm-project/pull/73773.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 582d627d1ce4ac0..96ec44fcd77677a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast
Type castResTy = getElementTypeOrSelf(op->getResult(0));
if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
- castResTy = VectorType::get(vecTy.getShape(), castResTy);
+ castResTy = vecTy.clone(castResTy);
auto *castOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
bcastOp.getSource(), castResTy, op->getAttrs());
@@ -527,16 +527,14 @@ struct ReorderElementwiseOpsOnTranspose final
srcValues.push_back(transposeOp.getVector());
} else {
// This is a constant. Create a reverse transpose op for it.
- auto vectorType = VectorType::get(
- srcType.getShape(),
- cast<VectorType>(operand.getType()).getElementType());
+ auto vectorType =
+ srcType.clone(cast<VectorType>(operand.getType()).getElementType());
srcValues.push_back(rewriter.create<vector::TransposeOp>(
operand.getLoc(), vectorType, operand, invOrder));
}
}
- auto vectorType = VectorType::get(
- srcType.getShape(),
+ auto vectorType = srcType.clone(
cast<VectorType>(op->getResultTypes()[0]).getElementType());
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
@@ -1314,8 +1312,10 @@ struct CanonicalizeContractMatmulToMMT final
Value trans =
rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
VectorType newType =
- VectorType::get(cast<VectorType>(trans.getType()).getShape(),
- cast<VectorType>(mat.getType()).getElementType());
+ cast<VectorType>(trans.getType())
+ .clone(cast<VectorType>(mat.getType()).getElementType());
+ // VectorType::get(cast<VectorType>(trans.getType()).getShape(),
+ // cast<VectorType>(mat.getType()).getElementType());
return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
}
if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
index 2024da2a585d99f..d771d32d548bbe2 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
@@ -1,8 +1,14 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule \
-// RUN: -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
-// RUN: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm | \
-// RUN: %mcr_aarch64_cmd -e=matmul_f32 -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
-// RUN: FileCheck %s
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE: -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = matmul_f32
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
func.func @matmul_f32() {
// Matrix dimensions
@@ -40,29 +46,37 @@ func.func @matmul_f32() {
return
}
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
- // Step 1: Tile
- %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
- %func_op = get_parent_op %matmul : (!transform.any_op) -> !transform.op<"func.func">
- %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
- // Step 2: Vectorize
- %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
-
- // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
- transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.reduction_to_contract
- transform.apply_patterns.vector.transfer_permutation_patterns
- transform.apply_patterns.vector.lower_masked_transfers
- } : !transform.op<"func.func">
-
- // Step 4: Lower vector.contract to vector.fma
- transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- transform.apply_patterns.vector.lower_outerproduct
- } : !transform.op<"func.func">
+module attributes {transform.with_named_sequence} {
+transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+ : (!transform.any_op) -> !transform.any_op
+
+ // Step 1: Tile
+ %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+
+ // Step 2: Vectorize
+ %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
+
+ // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
+ %func = transform.structured.match ops{["func.func"]} in %module
+ : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.reduction_to_contract
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ transform.apply_patterns.vector.lower_masked_transfers
+ } : !transform.op<"func.func">
+
+ // Step 4: Lower vector.contract to vector.fma
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ transform.apply_patterns.vector.lower_outerproduct
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
}
func.func private @printMemrefF32(%ptr : tensor<*xf32>)
|
@llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesApart from the test itself, this patch also updates a few patterns to Regression tests will be updated separately while auditing Vector Full diff: https://github.com/llvm/llvm-project/pull/73773.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 582d627d1ce4ac0..96ec44fcd77677a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast
Type castResTy = getElementTypeOrSelf(op->getResult(0));
if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
- castResTy = VectorType::get(vecTy.getShape(), castResTy);
+ castResTy = vecTy.clone(castResTy);
auto *castOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
bcastOp.getSource(), castResTy, op->getAttrs());
@@ -527,16 +527,14 @@ struct ReorderElementwiseOpsOnTranspose final
srcValues.push_back(transposeOp.getVector());
} else {
// This is a constant. Create a reverse transpose op for it.
- auto vectorType = VectorType::get(
- srcType.getShape(),
- cast<VectorType>(operand.getType()).getElementType());
+ auto vectorType =
+ srcType.clone(cast<VectorType>(operand.getType()).getElementType());
srcValues.push_back(rewriter.create<vector::TransposeOp>(
operand.getLoc(), vectorType, operand, invOrder));
}
}
- auto vectorType = VectorType::get(
- srcType.getShape(),
+ auto vectorType = srcType.clone(
cast<VectorType>(op->getResultTypes()[0]).getElementType());
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
@@ -1314,8 +1312,10 @@ struct CanonicalizeContractMatmulToMMT final
Value trans =
rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
VectorType newType =
- VectorType::get(cast<VectorType>(trans.getType()).getShape(),
- cast<VectorType>(mat.getType()).getElementType());
+ cast<VectorType>(trans.getType())
+ .clone(cast<VectorType>(mat.getType()).getElementType());
+ // VectorType::get(cast<VectorType>(trans.getType()).getShape(),
+ // cast<VectorType>(mat.getType()).getElementType());
return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
}
if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
index 2024da2a585d99f..d771d32d548bbe2 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
@@ -1,8 +1,14 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule \
-// RUN: -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
-// RUN: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm | \
-// RUN: %mcr_aarch64_cmd -e=matmul_f32 -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
-// RUN: FileCheck %s
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE: -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = matmul_f32
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
func.func @matmul_f32() {
// Matrix dimensions
@@ -40,29 +46,37 @@ func.func @matmul_f32() {
return
}
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
- // Step 1: Tile
- %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
- %func_op = get_parent_op %matmul : (!transform.any_op) -> !transform.op<"func.func">
- %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
- // Step 2: Vectorize
- %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
-
- // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
- transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.reduction_to_contract
- transform.apply_patterns.vector.transfer_permutation_patterns
- transform.apply_patterns.vector.lower_masked_transfers
- } : !transform.op<"func.func">
-
- // Step 4: Lower vector.contract to vector.fma
- transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- transform.apply_patterns.vector.lower_outerproduct
- } : !transform.op<"func.func">
+module attributes {transform.with_named_sequence} {
+transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+ : (!transform.any_op) -> !transform.any_op
+
+ // Step 1: Tile
+ %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+
+ // Step 2: Vectorize
+ %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
+
+ // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
+ %func = transform.structured.match ops{["func.func"]} in %module
+ : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.reduction_to_contract
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ transform.apply_patterns.vector.lower_masked_transfers
+ } : !transform.op<"func.func">
+
+ // Step 4: Lower vector.contract to vector.fma
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ transform.apply_patterns.vector.lower_outerproduct
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
}
func.func private @printMemrefF32(%ptr : tensor<*xf32>)
|
Depends on #73771 - please only review the top commit 🙏🏻 . |
23a0e7d
to
cd30cfe
Compare
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast | |||
|
|||
Type castResTy = getElementTypeOrSelf(op->getResult(0)); | |||
if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType())) | |||
castResTy = VectorType::get(vecTy.getShape(), castResTy); | |||
castResTy = vecTy.clone(castResTy); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should think about a long term solution for this beyond these fixes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
cd30cfe
to
4d199fb
Compare
Apart from the test itself, this patch also updates a few patterns to fix how new VectorType(s) are created. Namely, it makes sure that "scalability" is correctly propagated. Regression tests will be updated separately while auditing Vector dialect tests in the context of scalable vectors: * https://github.com/orgs/llvm/projects/23
4d199fb
to
b66226c
Compare
Apart from the test itself, this patch also updates a few patterns to
fix how new VectorType(s) are created. Namely, it makes sure that
"scalability" is correctly propagated.
Regression tests will be updated separately while auditing Vector
dialect tests in the context of scalable vectors: