
Commit 1752740

[mlir][tensor] Fix FoldTensorCastProducerOp for multiple result operations (#93374)
FoldTensorCastProducerOp fails for ops whose results include more than the results tied to dpsInits, e.g.:

```mlir
%13:2 = iree_codegen.ukernel.generic "iree_uk_unpack"
    ins(%extracted_slice : tensor<?x1x16x16xf32>)
    outs(%11 : tensor<?x?xf32>) ... -> tensor<?x?xf32>, i32
```

The op above has a trailing i32 result that is not tied to a dpsInit, so the pattern built a result-type list that was too short and failed. The fix assumes that an op's results list the dpsInit-tied results first, followed by the non-DPS results, and overwrites only the leading dpsInit result types.
1 parent: 670fa2b · commit: 1752740
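To make the intended behavior concrete, here is a minimal before/after sketch of the canonicalization, drawn from the regression test added in this commit (`test.destination_style_op` is the test-dialect op used there; the folded form mirrors the test's CHECK lines):

```mlir
// Before: both operands are produced by tensor.cast ops that erase the
// static 2x2 shape.
%cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
%cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
%0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>)
    outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index

// After folding: the casts are absorbed, result #0 (tied to the dpsInit)
// takes the static tensor type, and result #1 (index, a non-DPS result)
// keeps its type instead of being dropped.
%1:2 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>)
    outs(%arg1 : tensor<2x2xf32>) -> tensor<2x2xf32>, index
```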

File tree (2 files changed: +19 −3 lines)

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir


mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 4 additions & 3 deletions
@@ -4531,17 +4531,18 @@ struct FoldTensorCastProducerOp
     if (!hasTensorCastOperand)
       return failure();
 
-    SmallVector<Type, 4> newResultTypes;
-    newResultTypes.reserve(op->getNumResults());
+    SmallVector<Type, 4> newResultTypes(op->getResultTypes());
     SmallVector<Value, 4> newOperands;
     newOperands.reserve(op->getNumOperands());
+    // Assumes that the result has dpsInits followed by nonDpsInits.
+    int64_t dpsInitIdx = 0;
     for (OpOperand &opOperand : op->getOpOperands()) {
       auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
       bool fold = canFoldIntoConsumerOp(tensorCastOp);
       newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
       if (op.isDpsInit(&opOperand) &&
           !llvm::isa<MemRefType>(newOperands.back().getType()))
-        newResultTypes.push_back(newOperands.back().getType());
+        newResultTypes[dpsInitIdx++] = newOperands.back().getType();
     }
 
     // Clone op.

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 15 additions & 0 deletions
@@ -2523,3 +2523,18 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
   %16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
   return %16 : vector<7xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_destination_multiple_result(
+// CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
+// CHECK: %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
+// CHECK-SAME: outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
+// CHECK: return %[[RES]]#1 : index
+func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
+  %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
+  %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
+  %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
+  return %0#1 : index
+}
