@@ -392,9 +392,11 @@ Operation *OpBuilder::createOperation(const OperationState &state) {
392
392
// / Note: This function does not erase the operation on a successful fold.
393
393
LogicalResult OpBuilder::tryFold (Operation *op,
394
394
SmallVectorImpl<Value> &results) {
395
- results.reserve (op->getNumResults ());
395
+ ResultRange opResults = op->getResults ();
396
+
397
+ results.reserve (opResults.size ());
396
398
auto cleanupFailure = [&] {
397
- results.assign (op-> result_begin (), op-> result_end ());
399
+ results.assign (opResults. begin (), opResults. end ());
398
400
return failure ();
399
401
};
400
402
@@ -405,7 +407,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
405
407
// Check to see if any operands to the operation is constant and whether
406
408
// the operation knows how to constant fold itself.
407
409
SmallVector<Attribute, 4 > constOperands (op->getNumOperands ());
408
- for (unsigned i = 0 , e = op-> getNumOperands (); i != e; ++i)
410
+ for (unsigned i = 0 , e = constOperands. size (); i != e; ++i)
409
411
matchPattern (op->getOperand (i), m_Constant (&constOperands[i]));
410
412
411
413
// Try to fold the operation.
@@ -419,9 +421,14 @@ LogicalResult OpBuilder::tryFold(Operation *op,
419
421
420
422
// Populate the results with the folded results.
421
423
Dialect *dialect = op->getDialect ();
422
- for (auto &it : llvm::enumerate (foldResults)) {
424
+ for (auto it : llvm::zip (foldResults, opResults.getTypes ())) {
425
+ Type expectedType = std::get<1 >(it);
426
+
423
427
// Normal values get pushed back directly.
424
- if (auto value = it.value ().dyn_cast <Value>()) {
428
+ if (auto value = std::get<0 >(it).dyn_cast <Value>()) {
429
+ if (value.getType () != expectedType)
430
+ return cleanupFailure ();
431
+
425
432
results.push_back (value);
426
433
continue ;
427
434
}
@@ -431,9 +438,9 @@ LogicalResult OpBuilder::tryFold(Operation *op,
431
438
return cleanupFailure ();
432
439
433
440
// Ask the dialect to materialize a constant operation for this value.
434
- Attribute attr = it. value ( ).get <Attribute>();
435
- auto *constOp = dialect->materializeConstant (
436
- cstBuilder, attr, op-> getResult (it. index ()). getType (), op->getLoc ());
441
+ Attribute attr = std::get< 0 >(it ).get <Attribute>();
442
+ auto *constOp = dialect->materializeConstant (cstBuilder, attr, expectedType,
443
+ op->getLoc ());
437
444
if (!constOp) {
438
445
// Erase any generated constants.
439
446
for (Operation *cst : generatedConstants)
0 commit comments