Skip to content

Commit aeadd84

Browse files
committed
[mlir][tensor] Fix bug when having multiple result
For patterns where there are multiple results apart from dpsInits this fails. For eg: ``` %13:2 = iree_codegen.ukernel.generic "iree_uk_unpack" ins(%extracted_slice : tensor<?x1x16x16xf32>) outs(%11 : tensor<?x16xf32>) .. ``` The above op has results apart from dpsInit and hence fails. The PR assumes that the result has dpsInits followed by nondpsInits.
1 parent 8364659 commit aeadd84

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4531,17 +4531,17 @@ struct FoldTensorCastProducerOp
45314531
if (!hasTensorCastOperand)
45324532
return failure();
45334533

4534-
SmallVector<Type, 4> newResultTypes;
4535-
newResultTypes.reserve(op->getNumResults());
4534+
SmallVector<Type, 4> newResultTypes(op->getResultTypes());
45364535
SmallVector<Value, 4> newOperands;
45374536
newOperands.reserve(op->getNumOperands());
4537+
int64_t dpsInitIdx = 0;
45384538
for (OpOperand &opOperand : op->getOpOperands()) {
45394539
auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
45404540
bool fold = canFoldIntoConsumerOp(tensorCastOp);
45414541
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
45424542
if (op.isDpsInit(&opOperand) &&
45434543
!llvm::isa<MemRefType>(newOperands.back().getType()))
4544-
newResultTypes.push_back(newOperands.back().getType());
4544+
newResultTypes[dpsInitIdx++] = newOperands.back().getType();
45454545
}
45464546

45474547
// Clone op.

0 commit comments

Comments
 (0)