@@ -52,12 +52,6 @@ using namespace mlir::linalg;
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
-// Forward declaration of Conv1DGenerator and its validator
-namespace {
-struct Conv1DGenerator;
-bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp);
-} // namespace
-
 /// Try to vectorize `convOp` as a convolution.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
@@ -1945,6 +1939,22 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
+  // We only support 1D convolutions, reject all other cases.
+  if (isa<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcFhwcOp,
+          linalg::Conv2DNchwFchwOp>(convOp)) {
+    LDBG("2D convolutions are not supported\n");
+    return failure();
+  }
+
+  if (isa<linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNcdhwFcdhwOp>(convOp)) {
+    LDBG("3D convolutions are not supported\n");
+    return failure();
+  }
+
+  return success();
+}
+
 static LogicalResult vectorizeLinalgOpPrecondition(
     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -1996,20 +2006,8 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. But we will still need stride/dilation attributes that will be
   // annoying to reverse-engineer...
-  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-    // Create a dummy rewriter first, a rewriter is not required for
-    // validation
-    IRRewriter dummyBuilder(linalgOp.getContext());
-    // Check if we can successfully construct a 1d convolution generator.
-    // For example, if it is 2d+ convolution, return failure because we don't
-    // support it. To use this pass on a 2d+ convolution, it should have already
-    // been decomposed to 1d convolution via
-    // DecomposeConvolutionToLowerDimOpsPass.
-    if (!validateConv1DGenerator(dummyBuilder, linalgOp))
-      return failure();
-
-    return success();
-  }
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+    return vectorizeConvOpPrecondition(linalgOp);
 
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
@@ -3918,34 +3916,28 @@ struct Conv1DGenerator
     }
   }
 };
-
-// Helper function to construct Conv1DGenerator
-bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) {
-  Conv1DGenerator conv1dGen(rewriter, linalgOp);
-  return conv1dGen.isValid();
-}
-
 } // namespace
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
 static FailureOr<Operation *> vectorizeConvolution(
     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
-  Conv1DGenerator e(rewriter, op);
-  auto res = e.generateNonChanneledConv();
+  Conv1DGenerator conv1dGen(rewriter, op);
+  assert(conv1dGen.isValid() && "Conv1DGenerator failed");
+  auto res = conv1dGen.generateNonChanneledConv();
   if (succeeded(res))
     return res;
-  res = e.generateNwcConv();
+  res = conv1dGen.generateNwcConv();
   if (succeeded(res))
     return res;
-  res = e.generateNcwConv();
+  res = conv1dGen.generateNcwConv();
   if (succeeded(res))
     return res;
-  res = e.generateNwcPooling();
+  res = conv1dGen.generateNwcPooling();
   if (succeeded(res))
     return res;
-  res = e.generateNcwPooling();
+  res = conv1dGen.generateNcwPooling();
   if (succeeded(res))
     return res;
@@ -3957,11 +3949,9 @@ static FailureOr<Operation *> vectorizeConvolution(
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
     // vector dims will be inferred from the Ops.
-    if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
-        !isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
-      return rewriter.notifyMatchFailure(
-          op, "Unexpected convolution: expected 1D depthwise conv");
-    }
+    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
+            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+           "Not a 1D depthwise conv!");
     size_t chDimIdx =
         TypeSwitch<Operation *, size_t>(op)
             .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
@@ -3970,8 +3960,8 @@ static FailureOr<Operation *> vectorizeConvolution(
     vecChDimSize = inputVecSizes[chDimIdx];
     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
   }
-  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
-                               flatten1DDepthwiseConv);
+  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
+                                       flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {