Skip to content

Commit 5f8cefe

Browse files
committed
[mlir][vector] Fix crash in vector.reduction canonicalization
since vector.reduce support accumulator in all the cases remove the assert assuming old definition. Differential Revision: https://reviews.llvm.org/D129602
1 parent cc7d966 commit 5f8cefe

File tree

5 files changed

+70
-68
lines changed

5 files changed

+70
-68
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,11 @@ bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
182182
/// memory.
183183
bool isDisjointTransferSet(VectorTransferOpInterface transferA,
184184
VectorTransferOpInterface transferB);
185+
186+
/// Return the result value of reducing two scalar/vector values with the
187+
/// corresponding arith operation.
188+
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
189+
Value v1, Value v2);
185190
} // namespace vector
186191
} // namespace mlir
187192

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,6 @@ namespace vector {
3434
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
3535
/// the type of `source`.
3636
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
37-
38-
/// Return the result value of reducing two scalar/vector values with the
39-
/// corresponding arith operation.
40-
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
41-
Value v1, Value v2);
4237
} // namespace vector
4338

4439
/// Return the number of elements of basis, `0` if empty.

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -501,19 +501,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
501501
reductionOp.getVector(),
502502
rewriter.getI64ArrayAttr(0));
503503

504-
if (Value acc = reductionOp.getAcc()) {
505-
assert(reductionOp.getType().isa<FloatType>());
506-
switch (reductionOp.getKind()) {
507-
case CombiningKind::ADD:
508-
result = rewriter.create<arith::AddFOp>(loc, result, acc);
509-
break;
510-
case CombiningKind::MUL:
511-
result = rewriter.create<arith::MulFOp>(loc, result, acc);
512-
break;
513-
default:
514-
assert(false && "invalid op!");
515-
}
516-
}
504+
if (Value acc = reductionOp.getAcc())
505+
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
506+
result, acc);
517507

518508
rewriter.replaceOp(reductionOp, result);
519509
return success();
@@ -5007,6 +4997,56 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
50074997
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
50084998
}
50094999

5000+
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
5001+
CombiningKind kind, Value v1, Value v2) {
5002+
Type t1 = getElementTypeOrSelf(v1.getType());
5003+
Type t2 = getElementTypeOrSelf(v2.getType());
5004+
switch (kind) {
5005+
case CombiningKind::ADD:
5006+
if (t1.isIntOrIndex() && t2.isIntOrIndex())
5007+
return b.createOrFold<arith::AddIOp>(loc, v1, v2);
5008+
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
5009+
return b.createOrFold<arith::AddFOp>(loc, v1, v2);
5010+
llvm_unreachable("invalid value types for ADD reduction");
5011+
case CombiningKind::AND:
5012+
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5013+
return b.createOrFold<arith::AndIOp>(loc, v1, v2);
5014+
case CombiningKind::MAXF:
5015+
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
5016+
"expected float values");
5017+
return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
5018+
case CombiningKind::MINF:
5019+
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
5020+
"expected float values");
5021+
return b.createOrFold<arith::MinFOp>(loc, v1, v2);
5022+
case CombiningKind::MAXSI:
5023+
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5024+
return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
5025+
case CombiningKind::MINSI:
5026+
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5027+
return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
5028+
case CombiningKind::MAXUI:
5029+
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5030+
return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
5031+
case CombiningKind::MINUI:
5032+
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5033+
return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
5034+
case CombiningKind::MUL:
5035+
if (t1.isIntOrIndex() && t2.isIntOrIndex())
5036+
return b.createOrFold<arith::MulIOp>(loc, v1, v2);
5037+
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
5038+
return b.createOrFold<arith::MulFOp>(loc, v1, v2);
5039+
llvm_unreachable("invalid value types for MUL reduction");
5040+
case CombiningKind::OR:
5041+
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5042+
return b.createOrFold<arith::OrIOp>(loc, v1, v2);
5043+
case CombiningKind::XOR:
5044+
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5045+
return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
5046+
};
5047+
llvm_unreachable("unknown CombiningKind");
5048+
}
5049+
50105050
//===----------------------------------------------------------------------===//
50115051
// TableGen'd op method definitions
50125052
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -43,56 +43,6 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
4343
llvm_unreachable("Expected MemRefType or TensorType");
4444
}
4545

46-
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
47-
CombiningKind kind, Value v1, Value v2) {
48-
Type t1 = getElementTypeOrSelf(v1.getType());
49-
Type t2 = getElementTypeOrSelf(v2.getType());
50-
switch (kind) {
51-
case CombiningKind::ADD:
52-
if (t1.isIntOrIndex() && t2.isIntOrIndex())
53-
return b.createOrFold<arith::AddIOp>(loc, v1, v2);
54-
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
55-
return b.createOrFold<arith::AddFOp>(loc, v1, v2);
56-
llvm_unreachable("invalid value types for ADD reduction");
57-
case CombiningKind::AND:
58-
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
59-
return b.createOrFold<arith::AndIOp>(loc, v1, v2);
60-
case CombiningKind::MAXF:
61-
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
62-
"expected float values");
63-
return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
64-
case CombiningKind::MINF:
65-
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
66-
"expected float values");
67-
return b.createOrFold<arith::MinFOp>(loc, v1, v2);
68-
case CombiningKind::MAXSI:
69-
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
70-
return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
71-
case CombiningKind::MINSI:
72-
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
73-
return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
74-
case CombiningKind::MAXUI:
75-
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
76-
return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
77-
case CombiningKind::MINUI:
78-
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
79-
return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
80-
case CombiningKind::MUL:
81-
if (t1.isIntOrIndex() && t2.isIntOrIndex())
82-
return b.createOrFold<arith::MulIOp>(loc, v1, v2);
83-
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
84-
return b.createOrFold<arith::MulFOp>(loc, v1, v2);
85-
llvm_unreachable("invalid value types for MUL reduction");
86-
case CombiningKind::OR:
87-
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
88-
return b.createOrFold<arith::OrIOp>(loc, v1, v2);
89-
case CombiningKind::XOR:
90-
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
91-
return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
92-
};
93-
llvm_unreachable("unknown CombiningKind");
94-
}
95-
9646
/// Return the number of elements of basis, `0` if empty.
9747
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
9848
if (basis.empty())

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,6 +1619,18 @@ func.func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 {
16191619

16201620
// -----
16211621

1622+
// CHECK-LABEL: func @reduce_one_element_vector_maxf
1623+
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
1624+
// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
1625+
// CHECK: %[[S:.+]] = arith.maxf %[[A]], %[[B]] : f32
1626+
// CHECK: return %[[S]]
1627+
func.func @reduce_one_element_vector_maxf(%a : vector<1xf32>, %b: f32) -> f32 {
1628+
%s = vector.reduction <maxf>, %a, %b : vector<1xf32> into f32
1629+
return %s : f32
1630+
}
1631+
1632+
// -----
1633+
16221634
// CHECK-LABEL: func @bitcast(
16231635
// CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>) -> vector<4x16xi16> {
16241636
// CHECK: vector.bitcast %[[ARG:.*]] : vector<4x8xf32> to vector<4x16xi16>

0 commit comments

Comments
 (0)