Skip to content

Commit 8a5f33f

Browse files
authored
[mlir][ArmSME] Update OuterProductFusion to account for recent changes (#102125)
- Use vector.interleave rather than the LLVM intrinsic
- Remove dependency on LLVM dialect
- Remove manual outerproduct erases (these are now trivially dead)
- Remove comment explaining issues with previous tile allocator
- Update pipeline in `multi-tile-matmul-mixed-types.mlir`

Recent changes: #90448, #80965
1 parent b74182e commit 8a5f33f

File tree

6 files changed

+33
-77
lines changed

6 files changed

+33
-77
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -910,11 +910,11 @@ def FMopa2WayOp
910910
The 2 outer products in the example above can be fused into a single outer
911911
product as follows:
912912

913-
```mlir
914-
%a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
915-
%b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
913+
```mlir
914+
%a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
915+
%b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
916916
%0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
917-
```
917+
```
918918

919919
This is implemented in the `-arm-sme-outer-product-fusion` pass.
920920

@@ -1167,13 +1167,13 @@ def SMopa4WayOp
11671167
product as follows:
11681168

11691169
```mlir
1170-
%lhs0 = "llvm.intr.experimental.vector.interleave2"(%a0, %a2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
1171-
%lhs1 = "llvm.intr.experimental.vector.interleave2"(%a1, %a3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
1172-
%lhs = "llvm.intr.experimental.vector.interleave2"(%lhs0, %lhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
1170+
%lhs0 = vector.interleave %a0, %a2 : vector<[4]xi8> -> vector<[8]xi8>
1171+
%lhs1 = vector.interleave %a1, %a3 : vector<[4]xi8> -> vector<[8]xi8>
1172+
%lhs = vector.interleave %lhs0, %lhs1 : vector<[8]xi8> -> vector<[16]xi8>
11731173

1174-
%rhs0 = "llvm.intr.experimental.vector.interleave2"(%b0, %b2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
1175-
%rhs1 = "llvm.intr.experimental.vector.interleave2"(%b1, %b3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
1176-
%rhs = "llvm.intr.experimental.vector.interleave2"(%rhs0, %rhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
1174+
%rhs0 = vector.interleave %b0, %b2 : vector<[4]xi8> -> vector<[8]xi8>
1175+
%rhs1 = vector.interleave %b1, %b3 : vector<[4]xi8> -> vector<[8]xi8>
1176+
%rhs = vector.interleave %rhs0, %rhs1 : vector<[8]xi8> -> vector<[16]xi8>
11771177

11781178
%0 = arm_sme.smopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
11791179
```

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def OuterProductFusion
180180
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_4way-arm_smesmopa_4wayop
181181
}];
182182
let constructor = "mlir::arm_sme::createOuterProductFusionPass()";
183-
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "LLVM::LLVMDialect"];
183+
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
184184
}
185185

186186
def VectorLegalization

mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms
1414
MLIRPass
1515
MLIRArmSMEDialect
1616
MLIRFuncDialect
17-
MLIRLLVMCommonConversion
1817
MLIRVectorDialect
1918
MLIRIndexDialect
2019
MLIRSCFDialect

mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp

Lines changed: 5 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
1616
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
1717
#include "mlir/Dialect/Func/IR/FuncOps.h"
18-
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1918
#include "mlir/IR/PatternMatch.h"
2019
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2120
#include "llvm/ADT/TypeSwitch.h"
@@ -80,15 +79,6 @@ static LogicalResult isCompatible(PatternRewriter &rewriter,
8079
return success();
8180
}
8281

83-
// Create 'llvm.experimental.vector.interleave2' intrinsic from `lhs` and `rhs`.
84-
static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc,
85-
Value lhs, Value rhs) {
86-
auto inputType = cast<VectorType>(lhs.getType());
87-
VectorType inputTypeX2 =
88-
VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
89-
return rewriter.create<LLVM::vector_interleave2>(loc, inputTypeX2, lhs, rhs);
90-
}
91-
9282
// Fuse two 'arm_sme.outerproduct' operations that are chained via the
9383
// accumulator into 2-way outer product operation.
9484
//
@@ -106,10 +96,8 @@ static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc,
10696
//
10797
// Becomes:
10898
//
109-
// %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1)
110-
// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
111-
// %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1)
112-
// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
99+
// %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
100+
// %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
113101
// %0 = arm_sme.fmopa_2way %a_packed, %b_packed
114102
// : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
115103
class OuterProductFusion2Way
@@ -135,28 +123,7 @@ class OuterProductFusion2Way
135123

136124
if (!op1->hasOneUse()) {
137125
// If the first outer product has uses other than as the input to another
138-
// outer product, it can't be erased after fusion. This is a problem when
139-
// it also has an accumulator as this will be used as the root for tile
140-
// allocation and since the widening outer product uses the same
141-
// accumulator it will get assigned the same tile ID, resulting in 3
142-
// outer products accumulating to the same tile and incorrect results.
143-
//
144-
// Example:
145-
//
146-
// %acc = arith.constant dense<0.0> ; root for tile allocation
147-
// %0 = arm_sme.outerproduct %a0, %b0 acc(%acc)
148-
// vector.print %0 ; intermediary use, can't erase %0
149-
// %1 = arm_sme.outerproduct %a1, %b1 acc(%0)
150-
//
151-
// After fusion and tile allocation
152-
//
153-
// %0 = arm_sme.zero {tile_id = 0 : i32}
154-
// %1 = arm_sme.outerproduct %a0, %b0 acc(%0) {tile_id = 0 : i32}
155-
// vector.print %1
156-
// %2 = arm_sme.fmopa_2way %a, %b acc(%0) {tile_id = 0 : i32}
157-
//
158-
// No accumulator would be ok, but it's simpler to prevent this
159-
// altogether, since it has no benefit.
126+
// outer product, it can't be erased after fusion.
160127
return rewriter.notifyMatchFailure(op,
161128
kMatchFailureOuterProductNotSingleUse);
162129
}
@@ -169,7 +136,7 @@ class OuterProductFusion2Way
169136

170137
auto loc = op.getLoc();
171138
auto packInputs = [&](Value lhs, Value rhs) {
172-
return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
139+
return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
173140
};
174141

175142
auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -226,8 +193,6 @@ class OuterProductFusion2Way
226193
llvm_unreachable("unexpected arm_sme::CombiningKind!");
227194
}
228195

229-
rewriter.eraseOp(op1);
230-
231196
return success();
232197
}
233198

@@ -319,7 +284,7 @@ class OuterProductFusion4Way
319284

320285
auto loc = op.getLoc();
321286
auto packInputs = [&](Value lhs, Value rhs) {
322-
return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
287+
return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
323288
};
324289

325290
auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -400,10 +365,6 @@ class OuterProductFusion4Way
400365
llvm_unreachable("unexpected arm_sme::CombiningKind!");
401366
}
402367

403-
rewriter.eraseOp(op3);
404-
rewriter.eraseOp(op2);
405-
rewriter.eraseOp(op1);
406-
407368
return success();
408369
}
409370

mlir/test/Dialect/ArmSME/outer-product-fusion.mlir

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
// CHECK-SAME: %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>,
55
// CHECK-SAME: %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1>
66
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<[4]x[4]xf32>
7-
// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.vector.interleave2"(%[[A0]], %[[A1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
8-
// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.vector.interleave2"(%[[B0]], %[[B1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
9-
// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A0_MASK]], %[[A1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
10-
// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B0_MASK]], %[[B1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
7+
// CHECK-DAG: %[[LHS:.*]] = vector.interleave %[[A0]], %[[A1]] : vector<[4]xf16> -> vector<[8]xf16>
8+
// CHECK-DAG: %[[RHS:.*]] = vector.interleave %[[B0]], %[[B1]] : vector<[4]xf16> -> vector<[8]xf16>
9+
// CHECK-DAG: %[[LHS_MASK:.*]] = vector.interleave %[[A0_MASK]], %[[A1_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
10+
// CHECK-DAG: %[[RHS_MASK:.*]] = vector.interleave %[[B0_MASK]], %[[B1_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
1111
// CHECK-DAG: arm_sme.fmopa_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
1212
func.func @outerproduct_add_widening_2way_f16f16f32(
1313
%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
@@ -225,18 +225,18 @@ func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
225225
// CHECK-SAME: %[[A2_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B2_MASK:[a-z0-9]+]]: vector<[4]xi1>,
226226
// CHECK-SAME: %[[A3_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B3_MASK:[a-z0-9]+]]: vector<[4]xi1>
227227
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0> : vector<[4]x[4]xi32>
228-
// CHECK-DAG: %[[LHS0:.*]] = "llvm.intr.vector.interleave2"(%[[A0]], %[[A2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
229-
// CHECK-DAG: %[[LHS1:.*]] = "llvm.intr.vector.interleave2"(%[[A1]], %[[A3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
230-
// CHECK-DAG: %[[RHS0:.*]] = "llvm.intr.vector.interleave2"(%[[B0]], %[[B2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
231-
// CHECK-DAG: %[[RHS1:.*]] = "llvm.intr.vector.interleave2"(%[[B1]], %[[B3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
232-
// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.vector.interleave2"(%[[LHS0]], %[[LHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
233-
// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.vector.interleave2"(%[[RHS0]], %[[RHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
234-
// CHECK-DAG: %[[LHS0_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A0_MASK]], %[[A2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
235-
// CHECK-DAG: %[[LHS1_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A1_MASK]], %[[A3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
236-
// CHECK-DAG: %[[RHS0_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B0_MASK]], %[[B2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
237-
// CHECK-DAG: %[[RHS1_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B1_MASK]], %[[B3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
238-
// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[LHS0_MASK]], %[[LHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
239-
// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[RHS0_MASK]], %[[RHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
228+
// CHECK-DAG: %[[LHS0:.*]] = vector.interleave %[[A0]], %[[A2]] : vector<[4]xi8> -> vector<[8]xi8>
229+
// CHECK-DAG: %[[LHS1:.*]] = vector.interleave %[[A1]], %[[A3]] : vector<[4]xi8> -> vector<[8]xi8>
230+
// CHECK-DAG: %[[RHS0:.*]] = vector.interleave %[[B0]], %[[B2]] : vector<[4]xi8> -> vector<[8]xi8>
231+
// CHECK-DAG: %[[RHS1:.*]] = vector.interleave %[[B1]], %[[B3]] : vector<[4]xi8> -> vector<[8]xi8>
232+
// CHECK-DAG: %[[LHS:.*]] = vector.interleave %[[LHS0]], %[[LHS1]] : vector<[8]xi8> -> vector<[16]xi8>
233+
// CHECK-DAG: %[[RHS:.*]] = vector.interleave %[[RHS0]], %[[RHS1]] : vector<[8]xi8> -> vector<[16]xi8>
234+
// CHECK-DAG: %[[LHS0_MASK:.*]] = vector.interleave %[[A0_MASK]], %[[A2_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
235+
// CHECK-DAG: %[[LHS1_MASK:.*]] = vector.interleave %[[A1_MASK]], %[[A3_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
236+
// CHECK-DAG: %[[RHS0_MASK:.*]] = vector.interleave %[[B0_MASK]], %[[B2_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
237+
// CHECK-DAG: %[[RHS1_MASK:.*]] = vector.interleave %[[B1_MASK]], %[[B3_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
238+
// CHECK-DAG: %[[LHS_MASK:.*]] = vector.interleave %[[LHS0_MASK]], %[[LHS1_MASK]] : vector<[8]xi1> -> vector<[16]xi1>
239+
// CHECK-DAG: %[[RHS_MASK:.*]] = vector.interleave %[[RHS0_MASK]], %[[RHS1_MASK]] : vector<[8]xi1> -> vector<[16]xi1>
240240
// CHECK-DAG: arm_sme.smopa_4way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
241241
func.func @outerproduct_add_widening_4way_signed_i8i8i32(
242242
%a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,

mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
// RUN: mlir-opt %s \
22
// RUN: -transform-interpreter -test-transform-dialect-erase-schedule \
33
// RUN: -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
4-
// RUN: -arm-sme-vector-legalization -canonicalize -cse \
5-
// RUN: -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \
6-
// RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
7-
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \
8-
// RUN: -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
4+
// RUN: -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
95
// RUN: -test-lower-to-llvm | \
106
// RUN: %mcr_aarch64_cmd \
117
// RUN: -e=main -entry-point-result=void \

0 commit comments

Comments
 (0)