Skip to content

Commit e0f3a95

Browse files
committed
[mlir][vector] Disallow vector.fma over vectors of integers
This is to make `vector.fma` more consistent with the standard definition of `fma` that is defined only for flaoting point types. Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D141711
1 parent afc3756 commit e0f3a95

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -630,10 +630,10 @@ def Vector_FMAOp :
630630
Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
631631
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
632632
] # ElementwiseMappable.traits>,
633-
Arguments<(ins AnyVectorOfAnyRank:$lhs,
634-
AnyVectorOfAnyRank:$rhs,
635-
AnyVectorOfAnyRank:$acc)>,
636-
Results<(outs AnyVectorOfAnyRank:$result)> {
633+
Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
634+
VectorOfAnyRankOf<[AnyFloat]>:$rhs,
635+
VectorOfAnyRankOf<[AnyFloat]>:$acc)>,
636+
Results<(outs VectorOfAnyRankOf<[AnyFloat]>:$result)> {
637637
let summary = "vector fused multiply-add";
638638
let description = [{
639639
Multiply-add expressions operate on n-D vectors and compute a fused

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ func.func @broadcast_unknown(%arg0: memref<4x8xf32>) {
4444

4545
// -----
4646

47+
func.func @fma_vector_4xi32(%arg0: vector<4xi32>) {
48+
// expected-error@+1 {{'vector.fma' op operand #0 must be vector of floating-point value}}
49+
%1 = vector.fma %arg0, %arg0, %arg0 : vector<4xi32>
50+
}
51+
52+
// -----
53+
4754
func.func @shuffle_elt_type_mismatch(%arg0: vector<2xf32>, %arg1: vector<2xi32>) {
4855
// expected-error@+1 {{'vector.shuffle' op failed to verify that second operand v2 and result have same element type}}
4956
%1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xi32>

0 commit comments

Comments
 (0)