Skip to content

Commit a528cee

Browse files
authored
[mlir][vector] Improve makeArithReduction expansion (#75846)
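Propagate fast math flags to the floating-point arith ops created by `makeArithReduction`. Distinguish `minf`/`maxf` (now expanded to `arith.minnumf`/`arith.maxnumf`) from `minimumf`/`maximumf` (expanded to `arith.minimumf`/`arith.maximumf`). Required for future patterns in #75727.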
1 parent de5c49f commit a528cee

File tree

6 files changed: 46 additions (+46) and 20 deletions (-20).

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,12 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
123123
VectorTransferOpInterface transferB,
124124
bool testDynamicValueUsingBounds = false);
125125

126-
/// Return the result value of reducing two scalar/vector values with the
126+
/// Returns the result value of reducing two scalar/vector values with the
127127
/// corresponding arith operation.
128128
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
129-
Value v1, Value acc, Value mask = Value());
129+
Value v1, Value acc,
130+
arith::FastMathFlagsAttr fastmath = nullptr,
131+
Value mask = nullptr);
130132

131133
/// Returns true if `attr` has "parallel" iterator type semantics.
132134
inline bool isParallelIterator(Attribute attr) {

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -507,8 +507,9 @@ struct ElideUnitDimsInMultiDimReduction
507507
zeroIdx);
508508
}
509509

510-
Value result = vector::makeArithReduction(
511-
rewriter, loc, reductionOp.getKind(), acc, cast, mask);
510+
Value result =
511+
vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
512+
cast, /*fastmath=*/nullptr, mask);
512513
rewriter.replaceOp(rootOp, result);
513514
return success();
514515
}
@@ -650,7 +651,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
650651

651652
if (Value acc = reductionOp.getAcc())
652653
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
653-
result, acc, mask);
654+
result, acc,
655+
reductionOp.getFastmathAttr(), mask);
654656

655657
rewriter.replaceOp(rootOp, result);
656658
return success();
@@ -6212,6 +6214,7 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
62126214

62136215
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
62146216
CombiningKind kind, Value v1, Value acc,
6217+
arith::FastMathFlagsAttr fastmath,
62156218
Value mask) {
62166219
Type t1 = getElementTypeOrSelf(v1.getType());
62176220
Type tAcc = getElementTypeOrSelf(acc.getType());
@@ -6222,7 +6225,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
62226225
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
62236226
result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
62246227
else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6225-
result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
6228+
result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
62266229
else
62276230
llvm_unreachable("invalid value types for ADD reduction");
62286231
break;
@@ -6231,16 +6234,24 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
62316234
result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
62326235
break;
62336236
case CombiningKind::MAXF:
6237+
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6238+
"expected float values");
6239+
result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6240+
break;
62346241
case CombiningKind::MAXIMUMF:
62356242
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
62366243
"expected float values");
6237-
result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc);
6244+
result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
62386245
break;
62396246
case CombiningKind::MINF:
6247+
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6248+
"expected float values");
6249+
result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6250+
break;
62406251
case CombiningKind::MINIMUMF:
62416252
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
62426253
"expected float values");
6243-
result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc);
6254+
result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
62446255
break;
62456256
case CombiningKind::MAXSI:
62466257
assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
@@ -6262,7 +6273,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
62626273
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
62636274
result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
62646275
else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6265-
result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
6276+
result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
62666277
else
62676278
llvm_unreachable("invalid value types for MUL reduction");
62686279
break;

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
167167
if (!acc)
168168
return std::optional<Value>(mul);
169169

170-
return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
170+
return makeArithReduction(rewriter, loc, kind, mul, acc,
171+
/*fastmath=*/nullptr, mask);
171172
}
172173

173174
/// Return the positions of the reductions in the given map.

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
450450
// CHECK-LABEL: func.func @masked_float_max_outerprod(
451451
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
452452
// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
453-
// CHECK: %[[VAL_9:.*]] = arith.maximumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
453+
// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
454454
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
455455

456456
// -----
@@ -463,7 +463,7 @@ func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
463463
// CHECK-LABEL: func.func @masked_float_min_outerprod(
464464
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
465465
// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
466-
// CHECK: %[[VAL_9:.*]] = arith.minimumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
466+
// CHECK: %[[VAL_9:.*]] = arith.minnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
467467
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
468468

469469
// -----

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2172,6 +2172,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
21722172

21732173
// -----
21742174

2175+
// CHECK-LABEL: func @reduce_one_element_vector_addf_fastmath
2176+
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
2177+
// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
2178+
// CHECK: %[[S:.+]] = arith.addf %[[A]], %arg1 fastmath<nnan,ninf> : f32
2179+
// CHECK: return %[[S]]
2180+
func.func @reduce_one_element_vector_addf_fastmath(%a : vector<1xf32>, %b: f32) -> f32 {
2181+
%s = vector.reduction <add>, %a, %b fastmath<nnan,ninf> : vector<1xf32> into f32
2182+
return %s : f32
2183+
}
2184+
2185+
// -----
2186+
21752187
// CHECK-LABEL: func @masked_reduce_one_element_vector_addf
21762188
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
21772189
// CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)

mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32
2727
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
2828
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
2929
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
30-
// CHECK: %[[RV0:.+]] = arith.minimumf %[[V0]], %[[ACC]] : vector<2xf32>
30+
// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
3131
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
32-
// CHECK: %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32>
32+
// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
3333
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
34-
// CHECK: %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32>
34+
// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
3535
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
36-
// CHECK: %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32>
36+
// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
3737
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
3838

3939
func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
@@ -45,13 +45,13 @@ func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32
4545
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
4646
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
4747
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
48-
// CHECK: %[[RV0:.+]] = arith.maximumf %[[V0]], %[[ACC]] : vector<2xf32>
48+
// CHECK: %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
4949
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
50-
// CHECK: %[[RV01:.+]] = arith.maximumf %[[V1]], %[[RV0]] : vector<2xf32>
50+
// CHECK: %[[RV01:.+]] = arith.maxnumf %[[V1]], %[[RV0]] : vector<2xf32>
5151
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
52-
// CHECK: %[[RV012:.+]] = arith.maximumf %[[V2]], %[[RV01]] : vector<2xf32>
52+
// CHECK: %[[RV012:.+]] = arith.maxnumf %[[V2]], %[[RV01]] : vector<2xf32>
5353
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
54-
// CHECK: %[[RESULT_VEC:.+]] = arith.maximumf %[[V3]], %[[RV012]] : vector<2xf32>
54+
// CHECK: %[[RESULT_VEC:.+]] = arith.maxnumf %[[V3]], %[[RV012]] : vector<2xf32>
5555
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
5656

5757
func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {

0 commit comments

Comments
 (0)