@@ -501,19 +501,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
501
501
reductionOp.getVector (),
502
502
rewriter.getI64ArrayAttr (0 ));
503
503
504
- if (Value acc = reductionOp.getAcc ()) {
505
- assert (reductionOp.getType ().isa <FloatType>());
506
- switch (reductionOp.getKind ()) {
507
- case CombiningKind::ADD:
508
- result = rewriter.create <arith::AddFOp>(loc, result, acc);
509
- break ;
510
- case CombiningKind::MUL:
511
- result = rewriter.create <arith::MulFOp>(loc, result, acc);
512
- break ;
513
- default :
514
- assert (false && " invalid op!" );
515
- }
516
- }
504
+ if (Value acc = reductionOp.getAcc ())
505
+ result = vector::makeArithReduction (rewriter, loc, reductionOp.getKind (),
506
+ result, acc);
517
507
518
508
rewriter.replaceOp (reductionOp, result);
519
509
return success ();
@@ -5007,6 +4997,56 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
5007
4997
verifyDistributedType (lhs, rhs, getWarpSize (), getOperation ()));
5008
4998
}
5009
4999
5000
+ Value mlir::vector::makeArithReduction (OpBuilder &b, Location loc,
5001
+ CombiningKind kind, Value v1, Value v2) {
5002
+ Type t1 = getElementTypeOrSelf (v1.getType ());
5003
+ Type t2 = getElementTypeOrSelf (v2.getType ());
5004
+ switch (kind) {
5005
+ case CombiningKind::ADD:
5006
+ if (t1.isIntOrIndex () && t2.isIntOrIndex ())
5007
+ return b.createOrFold <arith::AddIOp>(loc, v1, v2);
5008
+ else if (t1.isa <FloatType>() && t2.isa <FloatType>())
5009
+ return b.createOrFold <arith::AddFOp>(loc, v1, v2);
5010
+ llvm_unreachable (" invalid value types for ADD reduction" );
5011
+ case CombiningKind::AND:
5012
+ assert (t1.isIntOrIndex () && t2.isIntOrIndex () && " expected int values" );
5013
+ return b.createOrFold <arith::AndIOp>(loc, v1, v2);
5014
+ case CombiningKind::MAXF:
5015
+ assert (t1.isa <FloatType>() && t2.isa <FloatType>() &&
5016
+ " expected float values" );
5017
+ return b.createOrFold <arith::MaxFOp>(loc, v1, v2);
5018
+ case CombiningKind::MINF:
5019
+ assert (t1.isa <FloatType>() && t2.isa <FloatType>() &&
5020
+ " expected float values" );
5021
+ return b.createOrFold <arith::MinFOp>(loc, v1, v2);
5022
+ case CombiningKind::MAXSI:
5023
+ assert (t1.isIntOrIndex () && t2.isIntOrIndex () && " expected int values" );
5024
+ return b.createOrFold <arith::MaxSIOp>(loc, v1, v2);
5025
+ case CombiningKind::MINSI:
5026
+ assert (t1.isIntOrIndex () && t2.isIntOrIndex () && " expected int values" );
5027
+ return b.createOrFold <arith::MinSIOp>(loc, v1, v2);
5028
+ case CombiningKind::MAXUI:
5029
+ assert (t1.isIntOrIndex () && t2.isIntOrIndex () && " expected int values" );
5030
+ return b.createOrFold <arith::MaxUIOp>(loc, v1, v2);
5031
+ case CombiningKind::MINUI:
5032
+ assert (t1.isIntOrIndex () && t2.isIntOrIndex () && " expected int values" );
5033
+ return b.createOrFold <arith::MinUIOp>(loc, v1, v2);
5034
+ case CombiningKind::MUL:
5035
+ if (t1.isIntOrIndex () && t2.isIntOrIndex ())
5036
+ return b.createOrFold <arith::MulIOp>(loc, v1, v2);
5037
+ else if (t1.isa <FloatType>() && t2.isa <FloatType>())
5038
+ return b.createOrFold <arith::MulFOp>(loc, v1, v2);
5039
+ llvm_unreachable (" invalid value types for MUL reduction" );
5040
+ case CombiningKind::OR:
5041
+ assert (t1.isIntOrIndex () && t2.isIntOrIndex () && " expected int values" );
5042
+ return b.createOrFold <arith::OrIOp>(loc, v1, v2);
5043
+ case CombiningKind::XOR:
5044
+ assert (t1.isIntOrIndex () && t2.isIntOrIndex () && " expected int values" );
5045
+ return b.createOrFold <arith::XOrIOp>(loc, v1, v2);
5046
+ };
5047
+ llvm_unreachable (" unknown CombiningKind" );
5048
+ }
5049
+
5010
5050
// ===----------------------------------------------------------------------===//
5011
5051
// TableGen'd op method definitions
5012
5052
// ===----------------------------------------------------------------------===//
0 commit comments