Skip to content

Commit cee4852

Browse files
CoTinkertomtor
authored andcommitted
[mlir][linalg] Add pure tensor check for winogradConv2DHelper (llvm#142299)
This PR adds pure tensor semantics check for `winogradConv2DHelper` to prevent a crash. Fixes llvm#141566.
1 parent ea70b75 commit cee4852

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,10 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) {
904904
static FailureOr<Operation *>
905905
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
906906
int64_t m, int64_t r) {
907+
if (!convOp.hasPureTensorSemantics())
908+
return rewriter.notifyMatchFailure(
909+
convOp, "expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");
910+
907911
Value input = convOp.getInputs()[0];
908912
Value filter = convOp.getInputs()[1];
909913
Value output = convOp.getOutputs()[0];

mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ module attributes {transform.with_named_sequence} {
6161

6262
// -----
6363

64+
func.func @conv2d_unsupported_type(%arg0: memref<2x10x10x5xf32>, %arg1: memref<2x3x3x5xf32>, %arg2: memref<2x8x8x2xf32>) {
65+
linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : memref<2x10x10x5xf32>, memref<2x3x3x5xf32>) outs(%arg2 : memref<2x8x8x2xf32>)
66+
return
67+
}
68+
69+
module attributes {transform.with_named_sequence} {
70+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
71+
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
72+
// expected-error @+1 {{apply Winograd Conv2D failed}}
73+
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
74+
transform.yield
75+
}
76+
}
77+
78+
// -----
79+
6480
func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
6581
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
6682
return %0 : tensor<2x?x?x2xf32>

0 commit comments

Comments
 (0)