-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Update OuterProductFusion to account for recent changes
#102125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
- Use `vector.interleave` rather than the LLVM intrinsic
- Remove dependency on the LLVM dialect
- Remove manual outerproduct erases (these are now trivially dead)
- Remove comment explaining issues with the previous tile allocator
- Update pipeline in `multi-tile-matmul-mixed-types.mlir`
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme
Author: Benjamin Maxwell (MacDue)
Changes:
Full diff: https://github.com/llvm/llvm-project/pull/102125.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 9178655f010c9..3f1776f57e4c7 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -910,11 +910,11 @@ def FMopa2WayOp
The 2 outer products in the example above can be fused into a single outer
product as follows:
- ```mlir
- %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
- %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+ ```mlir
+ %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
+ %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
%0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
- ```
+ ```
This is implemented in the `-arm-sme-outer-product-fusion` pass.
@@ -1167,13 +1167,13 @@ def SMopa4WayOp
product as follows:
```mlir
- %lhs0 = "llvm.intr.experimental.vector.interleave2"(%a0, %a2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
- %lhs1 = "llvm.intr.experimental.vector.interleave2"(%a1, %a3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
- %lhs = "llvm.intr.experimental.vector.interleave2"(%lhs0, %lhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+ %lhs0 = vector.interleave %a0, %a2 : vector<[4]xi8> -> vector<[8]xi8>
+ %lhs1 = vector.interleave %a1, %a3 : vector<[4]xi8> -> vector<[8]xi8>
+ %lhs = vector.interleave %lhs0, %lhs1 : vector<[8]xi8> -> vector<[16]xi8>
- %rhs0 = "llvm.intr.experimental.vector.interleave2"(%b0, %b2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
- %rhs1 = "llvm.intr.experimental.vector.interleave2"(%b1, %b3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
- %rhs = "llvm.intr.experimental.vector.interleave2"(%rhs0, %rhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+ %rhs0 = vector.interleave %b0, %b2 : vector<[4]xi8> -> vector<[8]xi8>
+ %rhs1 = vector.interleave %b1, %b3 : vector<[4]xi8> -> vector<[8]xi8>
+ %rhs = vector.interleave %rhs0, %rhs1 : vector<[8]xi8> -> vector<[16]xi8>
%0 = arm_sme.smopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 921234daad1f1..45efabf5fe1b4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -180,7 +180,7 @@ def OuterProductFusion
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_4way-arm_smesmopa_4wayop
}];
let constructor = "mlir::arm_sme::createOuterProductFusionPass()";
- let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "LLVM::LLVMDialect"];
+ let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
}
def VectorLegalization
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 8f9b5080e82db..a29624468ba2d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -14,7 +14,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms
MLIRPass
MLIRArmSMEDialect
MLIRFuncDialect
- MLIRLLVMCommonConversion
MLIRVectorDialect
MLIRIndexDialect
MLIRSCFDialect
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 1e711678dc9ab..ee1e374b25b04 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -15,7 +15,6 @@
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -80,15 +79,6 @@ static LogicalResult isCompatible(PatternRewriter &rewriter,
return success();
}
-// Create 'llvm.experimental.vector.interleave2' intrinsic from `lhs` and `rhs`.
-static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc,
- Value lhs, Value rhs) {
- auto inputType = cast<VectorType>(lhs.getType());
- VectorType inputTypeX2 =
- VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
- return rewriter.create<LLVM::vector_interleave2>(loc, inputTypeX2, lhs, rhs);
-}
-
// Fuse two 'arm_sme.outerproduct' operations that are chained via the
// accumulator into 2-way outer product operation.
//
@@ -106,10 +96,8 @@ static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc,
//
// Becomes:
//
-// %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1)
-// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-// %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1)
-// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+// %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
+// %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
// %0 = arm_sme.fmopa_2way %a_packed, %b_packed
// : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
class OuterProductFusion2Way
@@ -135,28 +123,7 @@ class OuterProductFusion2Way
if (!op1->hasOneUse()) {
// If the first outer product has uses other than as the input to another
- // outer product, it can't be erased after fusion. This is a problem when
- // it also has an accumulator as this will be used as the root for tile
- // allocation and since the widening outer product uses the same
- // accumulator it will get assigned the same tile ID, resulting in 3
- // outer products accumulating to the same tile and incorrect results.
- //
- // Example:
- //
- // %acc = arith.constant dense<0.0> ; root for tile allocation
- // %0 = arm_sme.outerproduct %a0, %b0 acc(%acc)
- // vector.print %0 ; intermediary use, can't erase %0
- // %1 = arm_sme.outerproduct %a1, %b1 acc(%0)
- //
- // After fusion and tile allocation
- //
- // %0 = arm_sme.zero {tile_id = 0 : i32}
- // %1 = arm_sme.outerproduct %a0, %b0 acc(%0) {tile_id = 0 : i32}
- // vector.print %1
- // %2 = arm_sme.fmopa_2way %a, %b acc(%0) {tile_id = 0 : i32}
- //
- // No accumulator would be ok, but it's simpler to prevent this
- // altogether, since it has no benefit.
+ // outer product, it can't be erased after fusion.
return rewriter.notifyMatchFailure(op,
kMatchFailureOuterProductNotSingleUse);
}
@@ -169,7 +136,7 @@ class OuterProductFusion2Way
auto loc = op.getLoc();
auto packInputs = [&](Value lhs, Value rhs) {
- return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
+ return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
};
auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -226,8 +193,6 @@ class OuterProductFusion2Way
llvm_unreachable("unexpected arm_sme::CombiningKind!");
}
- rewriter.eraseOp(op1);
-
return success();
}
@@ -319,7 +284,7 @@ class OuterProductFusion4Way
auto loc = op.getLoc();
auto packInputs = [&](Value lhs, Value rhs) {
- return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
+ return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
};
auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -400,10 +365,6 @@ class OuterProductFusion4Way
llvm_unreachable("unexpected arm_sme::CombiningKind!");
}
- rewriter.eraseOp(op3);
- rewriter.eraseOp(op2);
- rewriter.eraseOp(op1);
-
return success();
}
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 4887d611643fb..9000551783576 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -4,10 +4,10 @@
// CHECK-SAME: %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>,
// CHECK-SAME: %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1>
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<[4]x[4]xf32>
-// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.vector.interleave2"(%[[A0]], %[[A1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.vector.interleave2"(%[[B0]], %[[B1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A0_MASK]], %[[A1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B0_MASK]], %[[B1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS:.*]] = vector.interleave %[[A0]], %[[A1]] : vector<[4]xf16> -> vector<[8]xf16>
+// CHECK-DAG: %[[RHS:.*]] = vector.interleave %[[B0]], %[[B1]] : vector<[4]xf16> -> vector<[8]xf16>
+// CHECK-DAG: %[[LHS_MASK:.*]] = vector.interleave %[[A0_MASK]], %[[A1_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS_MASK:.*]] = vector.interleave %[[B0_MASK]], %[[B1_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
// CHECK-DAG: arm_sme.fmopa_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
func.func @outerproduct_add_widening_2way_f16f16f32(
%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
@@ -225,18 +225,18 @@ func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
// CHECK-SAME: %[[A2_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B2_MASK:[a-z0-9]+]]: vector<[4]xi1>,
// CHECK-SAME: %[[A3_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B3_MASK:[a-z0-9]+]]: vector<[4]xi1>
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0> : vector<[4]x[4]xi32>
-// CHECK-DAG: %[[LHS0:.*]] = "llvm.intr.vector.interleave2"(%[[A0]], %[[A2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[LHS1:.*]] = "llvm.intr.vector.interleave2"(%[[A1]], %[[A3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[RHS0:.*]] = "llvm.intr.vector.interleave2"(%[[B0]], %[[B2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[RHS1:.*]] = "llvm.intr.vector.interleave2"(%[[B1]], %[[B3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.vector.interleave2"(%[[LHS0]], %[[LHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
-// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.vector.interleave2"(%[[RHS0]], %[[RHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
-// CHECK-DAG: %[[LHS0_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A0_MASK]], %[[A2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[LHS1_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A1_MASK]], %[[A3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[RHS0_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B0_MASK]], %[[B2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[RHS1_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B1_MASK]], %[[B3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[LHS0_MASK]], %[[LHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
-// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[RHS0_MASK]], %[[RHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
+// CHECK-DAG: %[[LHS0:.*]] = vector.interleave %[[A0]], %[[A2]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS1:.*]] = vector.interleave %[[A1]], %[[A3]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS0:.*]] = vector.interleave %[[B0]], %[[B2]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS1:.*]] = vector.interleave %[[B1]], %[[B3]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS:.*]] = vector.interleave %[[LHS0]], %[[LHS1]] : vector<[8]xi8> -> vector<[16]xi8>
+// CHECK-DAG: %[[RHS:.*]] = vector.interleave %[[RHS0]], %[[RHS1]] : vector<[8]xi8> -> vector<[16]xi8>
+// CHECK-DAG: %[[LHS0_MASK:.*]] = vector.interleave %[[A0_MASK]], %[[A2_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS1_MASK:.*]] = vector.interleave %[[A1_MASK]], %[[A3_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS0_MASK:.*]] = vector.interleave %[[B0_MASK]], %[[B2_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS1_MASK:.*]] = vector.interleave %[[B1_MASK]], %[[B3_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS_MASK:.*]] = vector.interleave %[[LHS0_MASK]], %[[LHS1_MASK]] : vector<[8]xi1> -> vector<[16]xi1>
+// CHECK-DAG: %[[RHS_MASK:.*]] = vector.interleave %[[RHS0_MASK]], %[[RHS1_MASK]] : vector<[8]xi1> -> vector<[16]xi1>
// CHECK-DAG: arm_sme.smopa_4way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
func.func @outerproduct_add_widening_4way_signed_i8i8i32(
%a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
index aabd9d2ce788e..5784ecbbe4014 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
@@ -1,11 +1,7 @@
// RUN: mlir-opt %s \
// RUN: -transform-interpreter -test-transform-dialect-erase-schedule \
// RUN: -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN: -arm-sme-vector-legalization -canonicalize -cse \
-// RUN: -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \
-// RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \
-// RUN: -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
+// RUN: -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
// RUN: -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=main -entry-point-result=void \
|
Could you list these changes? Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left one comment; otherwise LGTM, cheers.
mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
Show resolved
Hide resolved
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/41/builds/1112 Here is the relevant piece of the build log for reference:
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/17/builds/1801 Here is the relevant piece of the build log for reference:
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/143/builds/1277 Here is the relevant piece of the build log for reference:
|
multi-tile-matmul-mixed-types.mlir
Recent changes: #90448, #80965