Skip to content

Commit 26f705d

Browse files
committed
Replace arm_sve.intr.zip1 with the target-agnostic llvm.intr.experimental.vector.interleave2 intrinsic
1 parent 601bbae commit 26f705d

File tree

4 files changed

+16
-40
lines changed

4 files changed

+16
-40
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -958,13 +958,8 @@ def FMopaWide2WayOp
958958
product as follows:
959959

960960
```mlir
961-
%undef = llvm.mlir.undef : vector<[8]xf16>
962-
%a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
963-
%a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
964-
%a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
965-
%b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
966-
%b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
967-
%b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
961+
%a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
962+
%b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
968963
%0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
969964
```
970965

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,8 @@ def OuterProductWidening
143143
Becomes:
144144

145145
```mlir
146-
%undef = llvm.mlir.undef : vector<[8]xf16>
147-
%a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
148-
%a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
149-
%a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
150-
%b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
151-
%b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
152-
%b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
146+
%a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
147+
%b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
153148
%0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
154149
```
155150

mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ namespace {
4848
//
4949
// Becomes:
5050
//
51-
// %a_packed = arm_sve.zip %a0, %a1 : vector<[8]xf16> to vector<[8]xf16>
52-
// %b_packed = arm_sve.zip %b0, %b1 : vector<[8]xf16> to vector<[8]xf16>
53-
// %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>,
54-
// vector<[4]xf32>
51+
// %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1)
52+
// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
53+
// %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1)
54+
// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
55+
// %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed
56+
// : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
5557
class OuterProduct2WayWidening
5658
: public OpRewritePattern<arm_sme::OuterProductOp> {
5759
public:
@@ -113,15 +115,9 @@ class OuterProduct2WayWidening
113115

114116
auto loc = op.getLoc();
115117

116-
// zip(lhs, rhs)
117118
auto packInputs = [&](VectorType type, Value lhs, Value rhs) {
118-
auto undef = rewriter.create<LLVM::UndefOp>(loc, type);
119-
auto insertLHS =
120-
rewriter.create<vector::ScalableInsertOp>(loc, lhs, undef, 0);
121-
auto insertRHS =
122-
rewriter.create<vector::ScalableInsertOp>(loc, rhs, undef, 0);
123-
return rewriter.create<arm_sve::Zip1IntrOp>(loc, type, insertLHS,
124-
insertRHS);
119+
return rewriter.create<LLVM::experimental_vector_interleave2>(loc, type,
120+
lhs, rhs);
125121
};
126122

127123
auto extOp = op.getLhs().getDefiningOp();

mlir/test/Dialect/ArmSME/outer-product-widening.mlir

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,10 @@
44
// CHECK-SAME: %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>,
55
// CHECK-SAME: %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1>
66
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<[4]x[4]xf32>
7-
// CHECK-DAG: %[[VEC_UNDEF:.*]] = llvm.mlir.undef : vector<[8]xf16>
8-
// CHECK-DAG: %[[A0_INSERT:.*]] = vector.scalable.insert %[[A0]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
9-
// CHECK-DAG: %[[B0_INSERT:.*]] = vector.scalable.insert %[[B0]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
10-
// CHECK-DAG: %[[A1_INSERT:.*]] = vector.scalable.insert %[[A1]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
11-
// CHECK-DAG: %[[B1_INSERT:.*]] = vector.scalable.insert %[[B1]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
12-
// CHECK-DAG: %[[LHS:.*]] = "arm_sve.intr.zip1"(%[[A0_INSERT]], %[[A1_INSERT]]) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
13-
// CHECK-DAG: %[[RHS:.*]] = "arm_sve.intr.zip1"(%[[B0_INSERT]], %[[B1_INSERT]]) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
14-
// CHECK-DAG: %[[MASK_UNDEF:.*]] = llvm.mlir.undef : vector<[8]xi1>
15-
// CHECK-DAG: %[[A0_MASK_INSERT:.*]] = vector.scalable.insert %[[A0_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
16-
// CHECK-DAG: %[[B0_MASK_INSERT:.*]] = vector.scalable.insert %[[B0_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
17-
// CHECK-DAG: %[[A1_MASK_INSERT:.*]] = vector.scalable.insert %[[A1_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
18-
// CHECK-DAG: %[[B1_MASK_INSERT:.*]] = vector.scalable.insert %[[B1_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
19-
// CHECK-DAG: %[[LHS_MASK:.*]] = "arm_sve.intr.zip1"(%[[A0_MASK_INSERT]], %[[A1_MASK_INSERT]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[8]xi1>
20-
// CHECK-DAG: %[[RHS_MASK:.*]] = "arm_sve.intr.zip1"(%[[B0_MASK_INSERT]], %[[B1_MASK_INSERT]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[8]xi1>
7+
// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0]], %[[A1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
8+
// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0]], %[[B1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
9+
// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0_MASK]], %[[A1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
10+
// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0_MASK]], %[[B1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
2111
// CHECK-DAG: arm_sme.fmopa_wide_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
2212
func.func @outerproduct_add_widening_2way_f16f16f32(
2313
%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,

0 commit comments

Comments
 (0)