
Commit 74a8986

Addressing review feedback
1 parent 00c3a33 commit 74a8986

2 files changed: +65, −45 lines

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 30 additions & 40 deletions
@@ -52,12 +52,6 @@ using namespace mlir::linalg;
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
-// Forward declaration of Conv1DGenerator and its validator
-namespace {
-struct Conv1DGenerator;
-bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp);
-} // namespace
-
 /// Try to vectorize `convOp` as a convolution.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
@@ -1945,6 +1939,22 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
+  // We only support 1D convolutions, reject all other cases.
+  if (isa<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcFhwcOp,
+          linalg::Conv2DNchwFchwOp>(convOp)) {
+    LDBG("2D convolutions are not supported\n");
+    return failure();
+  }
+
+  if (isa<linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNcdhwFcdhwOp>(convOp)) {
+    LDBG("3D convolutions are not supported\n");
+    return failure();
+  }
+
+  return success();
+}
+
 static LogicalResult vectorizeLinalgOpPrecondition(
     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -1996,20 +2006,8 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. But we will still need stride/dilation attributes that will be
   // annoying to reverse-engineer...
-  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-    // Create a dummy rewriter first, a rewriter is not required for
-    // validation
-    IRRewriter dummyBuilder(linalgOp.getContext());
-    // Check if we can successfully construct a 1d convolution generator.
-    // For example, if it is 2d+ convolution, return failure because we don't
-    // support it. To use this pass on a 2d+ convolution, it should have already
-    // been decomposed to 1d convolution via
-    // DecomposeConvolutionToLowerDimOpsPass.
-    if (!validateConv1DGenerator(dummyBuilder, linalgOp))
-      return failure();
-
-    return success();
-  }
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+    return vectorizeConvOpPrecondition(linalgOp);
 
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
@@ -3918,34 +3916,28 @@ struct Conv1DGenerator
     }
   }
 };
-
-// Helper function to construct Conv1DGenerator
-bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) {
-  Conv1DGenerator conv1dGen(rewriter, linalgOp);
-  return conv1dGen.isValid();
-}
-
 } // namespace
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
 static FailureOr<Operation *> vectorizeConvolution(
     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
-  Conv1DGenerator e(rewriter, op);
-  auto res = e.generateNonChanneledConv();
+  Conv1DGenerator conv1dGen(rewriter, op);
+  assert(conv1dGen.isValid() && "Conv1DGenerator failed");
+  auto res = conv1dGen.generateNonChanneledConv();
   if (succeeded(res))
     return res;
-  res = e.generateNwcConv();
+  res = conv1dGen.generateNwcConv();
   if (succeeded(res))
     return res;
-  res = e.generateNcwConv();
+  res = conv1dGen.generateNcwConv();
   if (succeeded(res))
     return res;
-  res = e.generateNwcPooling();
+  res = conv1dGen.generateNwcPooling();
   if (succeeded(res))
     return res;
-  res = e.generateNcwPooling();
+  res = conv1dGen.generateNcwPooling();
   if (succeeded(res))
     return res;

@@ -3957,11 +3949,9 @@ static FailureOr<Operation *> vectorizeConvolution(
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
-    if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
-        !isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
-      return rewriter.notifyMatchFailure(
-          op, "Unexpected convolution: expected 1D depthwise conv");
-    }
+    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
+            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+           "Not a 1D depthwise conv!");
     size_t chDimIdx =
         TypeSwitch<Operation *, size_t>(op)
             .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
@@ -3970,8 +3960,8 @@ static FailureOr<Operation *> vectorizeConvolution(
     vecChDimSize = inputVecSizes[chDimIdx];
     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
   }
-  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
-                               flatten1DDepthwiseConv);
+  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
+                                       flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
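
Note: the new vectorizeConvOpPrecondition only rejects the named 2-D and 3-D convolutions listed above; 1-D convolutions stay on the vectorization path. A minimal sketch of an op that should still pass the precondition (the shapes are illustrative and not taken from this commit):

func.func @conv1d_nwc_wcf(%input: tensor<1x8x5xf32>, %filter: tensor<3x5x4xf32>, %output: tensor<1x6x4xf32>) -> tensor<1x6x4xf32> {
  // 1-D convolution in NWC/WCF layout; not one of the cases rejected by
  // vectorizeConvOpPrecondition, so vectorization is still attempted.
  %0 = linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
    ins(%input, %filter : tensor<1x8x5xf32>, tensor<3x5x4xf32>)
    outs(%output : tensor<1x6x4xf32>) -> tensor<1x6x4xf32>
  return %0 : tensor<1x6x4xf32>
}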

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 35 additions & 5 deletions
@@ -112,12 +112,9 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @conv2d(%3: tensor<1x64x58x58xf32>, %4: tensor<64x64x3x3xf32>) {
-  %cst = arith.constant 0.000000e+00 : f32
-  %5 = tensor.empty() : tensor<1x64x56x56xf32>
-  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+func.func @conv2d_nchw_fchw(%input: tensor<1x5x8x8xf32>, %filter: tensor<4x5x3x3xf32>, %output: tensor<1x4x6x6xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
-  %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%6 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+  linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x5x8x8xf32>, tensor<4x5x3x3xf32>) outs(%output : tensor<1x4x6x6xf32>) -> tensor<1x4x6x6xf32>
   return
 }

@@ -131,6 +128,39 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @conv2d_nhwc_fhwc(%input: tensor<1x8x8x5xf32>, %filter: tensor<4x3x3x5xf32>, %output: tensor<1x6x6x4xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x8x8x5xf32>, tensor<4x3x3x5xf32>) outs(%output : tensor<1x6x6x4xf32>) -> tensor<1x6x6x4xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @conv3d_ncdhw_fcdhw(%input: tensor<1x5x8x8x8xf32>, %filter: tensor<4x5x3x3x3xf32>, %output: tensor<1x4x6x6x6xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : vector<3xi64>, strides = dense<1> : vector<3xi64>} ins(%input, %filter : tensor<1x5x8x8x8xf32>, tensor<4x5x3x3x3xf32>) outs(%output : tensor<1x4x6x6x6xf32>) -> tensor<1x4x6x6x6xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_3d_ncdhw_fcdhw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
   %pad = arith.constant 0.000000e+00 : f32
   // expected-error @+1 {{Attempted to vectorize, but failed}}
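
Note: as the comment removed from vectorizeLinalgOpPrecondition suggested, a 2-D+ convolution is expected to be decomposed to a 1-D convolution before vectorizing. A hedged sketch of that flow with the transform dialect, assuming a unit-height convolution (the shape the decomposition patterns rewrite); the op names and shapes here are illustrative, not part of this commit:

func.func @conv2d_unit_height(%input: tensor<1x1x8x5xf32>, %filter: tensor<1x3x5x4xf32>, %output: tensor<1x1x6x4xf32>) -> tensor<1x1x6x4xf32> {
  // H = 1, so this 2-D convolution can be rewritten as a linalg.conv_1d_nwc_wcf.
  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
    ins(%input, %filter : tensor<1x1x8x5xf32>, tensor<1x3x5x4xf32>)
    outs(%output : tensor<1x1x6x4xf32>) -> tensor<1x1x6x4xf32>
  return %0 : tensor<1x1x6x4xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // Rewrite the unit-height 2-D convolution as a 1-D convolution first...
    %1 = transform.structured.decompose %0 : (!transform.any_op) -> !transform.any_op
    // ...then vectorize; the resulting 1-D op is not rejected by the new precondition.
    transform.structured.vectorize %1 : !transform.any_op
    transform.yield
  }
}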
