Skip to content

Commit 534638e

Browse files
[mlir][linalg] Fix crash in canonicalization pattern
This crash was due to incorrect usage of `hasTensorSemantics`, which has changed recently with DestinationStyleOpInterface. An op has tensor semantics if all of its inits and inputs are tensors. Previously, only inits needed to be tensors. Differential Revision: https://reviews.llvm.org/D137243
1 parent dd927f4 commit 534638e

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -935,10 +935,9 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
935935
// Create the new op with the body being empty.
936936
Location loc = genericOp.getLoc();
937937
SmallVector<Type> newResultTypes;
938-
if (genericOp.hasTensorSemantics()) {
939-
newResultTypes = llvm::to_vector(llvm::map_range(
940-
newOutputOperands, [](Value v) { return v.getType(); }));
941-
}
938+
for (Value v : newOutputOperands)
939+
if (v.getType().isa<TensorType>())
940+
newResultTypes.push_back(v.getType());
942941
auto newOp = rewriter.create<GenericOp>(
943942
loc, newResultTypes, newInputOperands, newOutputOperands,
944943
rewriter.getAffineMapArrayAttr(newIndexingMaps),

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,3 +846,27 @@ func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
846846
// CHECK-SAME: iterator_types = ["parallel"]
847847
// CHECK-SAME: } ins(%[[ARG1]] : tensor<?xf32>)
848848
// CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {
849+
850+
// -----
851+
852+
// Just make sure that we don't crash.
853+
854+
// CHECK-LABEL: func @dedeplicate_regression_test
855+
func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
856+
%36 = linalg.generic
857+
{indexing_maps = [affine_map<(d0) -> (d0)>,
858+
affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
859+
iterator_types = ["parallel"]}
860+
ins(%1, %1 : memref<4xf32>, memref<4xf32>)
861+
outs(%0 : tensor<4xf32>) {
862+
^bb0(%in: f32, %in_24: f32, %out: f32):
863+
linalg.yield %in : f32
864+
} -> tensor<4xf32>
865+
%53 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
866+
iterator_types = ["parallel"]}
867+
outs(%36 : tensor<4xf32>) {
868+
^bb0(%out: f32):
869+
linalg.yield %out : f32
870+
} -> tensor<4xf32>
871+
return
872+
}

0 commit comments

Comments
 (0)