Skip to content

Commit a226542

Browse files
[MLIR][ArmNeon] Add an ArmNeon operation which maps to bfmmla (#145038)
1 parent 7a5af4f commit a226542

File tree

4 files changed

+93
-0
lines changed

4 files changed

+93
-0
lines changed

mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,34 @@ def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
222222
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
223223
}
224224

225+
def BfmmlaOp : ArmNeon_IntrOp<"bfmmla", [], [], 1, [
226+
Pure,
227+
AllTypesMatch<["src1", "src2"]>,
228+
AllTypesMatch<["acc", "res"]>,
229+
]> {
230+
let summary = "BFloat16 matrix multiply-accumulate to single-precision";
231+
let description = [{
232+
BFMMLA: BFloat16 matrix multiply-accumulate to single-precision.
233+
234+
The operation multiplies the 2x4 BFloat16 matrix in the first source vector
235+
with the 4x2 BFloat16 matrix in the second source vector, then accumulates
236+
this intermediate result with the 2x2 Float32 matrix in the accumulator
237+
vector, yielding the final 2x2 Float32 result.
238+
239+
Source:
240+
https://developer.arm.com/architectures/instruction-sets/intrinsics/vbfmmlaq_f32
241+
}];
242+
// Supports (vector<8xbf16>, vector<8xbf16>) -> (vector<4xf32>)
243+
let arguments = (ins
244+
NeonVectorOfLength<4, F32>:$acc,
245+
NeonVectorOfLength<8, BF16>:$src1,
246+
NeonVectorOfLength<8, BF16>:$src2
247+
);
248+
let results = (outs NeonVectorOfLength<4, F32>:$res);
249+
let assemblyFormat =
250+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
251+
}
252+
225253
class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
226254
: Op</*dialect=*/ArmNeon_Dialect,
227255
/*opName=*/"2d." # mnemonic,

mlir/test/Dialect/ArmNeon/invalid.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,43 @@ func.func @usmmla_invalid_dimensions(%a: vector<8xi32>,
9191
%0 = arm_neon.intr.usmmla %a, %b, %c : vector<32xi8> to vector<8xi32>
9292
return %0 : vector<8xi32>
9393
}
94+
95+
// -----
96+
97+
func.func @bfmmla_invalid_element_type_lhs_rhs(%acc: vector<4xf32>,
98+
%lhs: vector<8xf16>,
99+
%rhs: vector<8xf16>) -> vector<4xf32> {
100+
// expected-error@+1 {{operand #1 must be a vector with length 8 of bfloat16 type values, but got 'vector<8xf16>'}}
101+
%0 = arm_neon.intr.bfmmla %acc, %lhs, %rhs : vector<8xf16> to vector<4xf32>
102+
return %0 : vector<4xf32>
103+
}
104+
105+
// -----
106+
107+
func.func @bfmmla_invalid_dimension_lhs_rhs(%acc: vector<4xf32>,
108+
%lhs: vector<4xbf16>,
109+
%rhs: vector<4xbf16>) -> vector<4xf32> {
110+
// expected-error@+1 {{operand #1 must be a vector with length 8 of bfloat16 type values, but got 'vector<4xbf16>'}}
111+
%0 = arm_neon.intr.bfmmla %acc, %lhs, %rhs : vector<4xbf16> to vector<4xf32>
112+
return %0 : vector<4xf32>
113+
}
114+
115+
// -----
116+
117+
func.func @bfmmla_invalid_element_type_acc(%acc: vector<4xi32>,
118+
%lhs: vector<8xbf16>,
119+
%rhs: vector<8xbf16>) -> vector<4xi32> {
120+
// expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit float values, but got 'vector<4xi32>'}}
121+
%0 = arm_neon.intr.bfmmla %acc, %lhs, %rhs : vector<8xbf16> to vector<4xi32>
122+
return %0 : vector<4xi32>
123+
}
124+
125+
// -----
126+
127+
func.func @bfmmla_invalid_dimension_acc(%acc: vector<8xf32>,
128+
%lhs: vector<8xbf16>,
129+
%rhs: vector<8xbf16>) -> vector<8xf32> {
130+
// expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit float values, but got 'vector<8xf32>'}}
131+
%0 = arm_neon.intr.bfmmla %acc, %lhs, %rhs : vector<8xbf16> to vector<8xf32>
132+
return %0 : vector<8xf32>
133+
}

mlir/test/Dialect/ArmNeon/roundtrip.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,15 @@ func.func @arm_neon_usmmla(%a: vector<16xi8>,
6060
%0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
6161
return %0 : vector<4xi32>
6262
}
63+
64+
65+
// -----
66+
67+
// CHECK-LABEL: arm_neon_bfmmla
68+
func.func @arm_neon_bfmmla(%a: vector<8xbf16>,
69+
%b: vector<8xbf16>,
70+
%c: vector<4xf32>) -> vector<4xf32> {
71+
// CHECK: arm_neon.intr.bfmmla {{.*}}: vector<8xbf16> to vector<4xf32>
72+
%0 = arm_neon.intr.bfmmla %c, %a, %b : vector<8xbf16> to vector<4xf32>
73+
return %0 : vector<4xf32>
74+
}

mlir/test/Target/LLVMIR/arm-neon.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,16 @@ llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
8282
-> vector<4xi32>
8383
llvm.return %0 : vector<4xi32>
8484
}
85+
86+
// -----
87+
88+
// CHECK-LABEL: arm_neon_bfmmla
89+
llvm.func @arm_neon_bfmmla(%arg0: vector<8xbf16>,
90+
%arg1: vector<8xbf16>,
91+
%arg2: vector<4xf32>) -> vector<4xf32> {
92+
// CHECK: <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float
93+
%0 = "arm_neon.intr.bfmmla"(%arg2, %arg0, %arg1) :
94+
(vector<4xf32>, vector<8xbf16>, vector<8xbf16>)
95+
-> vector<4xf32>
96+
llvm.return %0 : vector<4xf32>
97+
}

0 commit comments

Comments
 (0)