@@ -1939,19 +1939,123 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+namespace {
+bool isCastOfBlockArgument(Operation *op) {
+  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+         isa<BlockArgument>(op->getOperand(0));
+}
+
+// Returns true iff it is a valid conv/pooling op.
+// If the region has 2 ops (reduction + yield) or 3 ops (extension +
+// reduction + yield) and rhs is not used, then it is the body of a pooling
+// op. If conv, check for a single `mul` predecessor. The `mul` operands
+// must be block arguments or extensions of block arguments. Otherwise,
+// check for one or zero `ext` predecessors. The `ext` operands must be
+// block arguments or extensions of block arguments.
+enum OperKind { Conv, Pool };
+bool getOperKind(Operation *reduceOp, OperKind &oper) {
+  int numBlockArguments =
+      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
+
+  switch (numBlockArguments) {
+  case 1: {
+    // Will be convolution if feeder is a MulOp.
+    // A strength reduced version of MulOp for i1 type is AndOp, which is also
+    // supported. Otherwise, it can be pooling. This strength reduction logic
+    // is in the `buildBinaryFn` helper in the Linalg dialect.
+    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+                                       llvm::IsaPred<BlockArgument>);
+    Operation *feedOp = (*feedValIt).getDefiningOp();
+    if (isCastOfBlockArgument(feedOp)) {
+      oper = Pool;
+    } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+                  (isa<arith::AndIOp>(feedOp) &&
+                   feedOp->getResultTypes()[0].isInteger(1))) &&
+                 llvm::all_of(feedOp->getOperands(), [](Value v) {
+                   if (isa<BlockArgument>(v))
+                     return true;
+                   if (Operation *op = v.getDefiningOp())
+                     return isCastOfBlockArgument(op);
+                   return false;
+                 }))) {
+      return false;
+    }
+    return true;
+  }
+  case 2:
+    // Must be pooling.
+    oper = Pool;
+    return true;
+  default:
+    return false;
+  }
+}
+
+bool isSupportedPoolKind(vector::CombiningKind kind) {
+  switch (kind) {
+  case vector::CombiningKind::ADD:
+  case vector::CombiningKind::MAXNUMF:
+  case vector::CombiningKind::MAXIMUMF:
+  case vector::CombiningKind::MAXSI:
+  case vector::CombiningKind::MAXUI:
+  case vector::CombiningKind::MINNUMF:
+  case vector::CombiningKind::MINIMUMF:
+  case vector::CombiningKind::MINSI:
+  case vector::CombiningKind::MINUI:
+    return true;
+  default:
+    return false;
+  }
+}
+} // namespace
+
 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
-  // We only support 1D convolutions, reject all other cases.
-  if (isa<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcFhwcOp,
-          linalg::Conv2DNchwFchwOp>(convOp)) {
-    LDBG("2D convolutions are not supported\n");
+  if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
+    return failure();
+
+  auto lhsShaped = convOp.getDpsInputOperand(0)->get();
+  auto rhsShaped = convOp.getDpsInputOperand(1)->get();
+  auto resShaped = convOp.getDpsInitOperand(0)->get();
+  auto lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
+  auto rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
+  auto resShapedType = dyn_cast<ShapedType>(resShaped.getType());
+  if (!lhsShapedType || !rhsShapedType || !resShapedType)
+    return failure();
+  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
+  // (non-channeled convolution -> LHS and RHS both have single dimensions).
+  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
+      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
     return failure();
-  }
 
-  if (isa<linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNcdhwFcdhwOp>(convOp)) {
-    LDBG("3D convolutions are not supported\n");
+  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
+  if (!reduceOp)
+    return failure();
+
+  OperKind oper = Conv;
+  if (!getOperKind(reduceOp, oper))
+    return failure();
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  // Typically convolution will have an `Add` CombiningKind, but for i1 type it
+  // can get strength reduced to `OR`, which is also supported. This strength
+  // reduction logic is in the `buildBinaryFn` helper in the Linalg dialect.
+  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+                      *maybeKind != vector::CombiningKind::OR) &&
+                     (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
     return failure();
   }
 
+  auto rhsRank = rhsShapedType.getRank();
+  switch (oper) {
+  case Conv:
+    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
+      return failure();
+    break;
+  case Pool:
+    if (rhsRank != 1)
+      return failure();
+    break;
+  }
+
   return success();
 }
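For orientation, the classification that `getOperKind` performs can be read directly off the region body of the linalg op. Below is a minimal sketch, assuming the helpers above are in scope in the same file; the `classifyBody` wrapper and the IR fragments in its comments are illustrative, not part of the patch:

```cpp
#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

// Convolution body: the reduction has one block-argument operand, and the
// other operand is fed by a `mul` of (possibly extended) block arguments:
//   ^bb0(%in: f32, %flt: f32, %acc: f32):
//     %m = arith.mulf %in, %flt : f32
//     %s = arith.addf %acc, %m : f32
//     linalg.yield %s : f32
//
// Pooling body: the reduction consumes two block arguments directly (the
// rhs/window operand is unused in the body):
//   ^bb0(%in: f32, %win: f32, %acc: f32):
//     %s = arith.addf %acc, %in : f32
//     linalg.yield %s : f32
static void classifyBody(mlir::Operation *reduceOp) {
  OperKind oper = Conv;
  if (!getOperKind(reduceOp, oper)) {
    llvm::dbgs() << "unrecognized conv/pool body\n";
    return;
  }
  llvm::dbgs() << (oper == Pool ? "pooling\n" : "convolution\n");
}
```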
@@ -3084,28 +3189,6 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
 }
 
 namespace {
-bool isCastOfBlockArgument(Operation *op) {
-  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
-         isa<BlockArgument>(op->getOperand(0));
-}
-
-bool isSupportedPoolKind(vector::CombiningKind kind) {
-  switch (kind) {
-  case vector::CombiningKind::ADD:
-  case vector::CombiningKind::MAXNUMF:
-  case vector::CombiningKind::MAXIMUMF:
-  case vector::CombiningKind::MAXSI:
-  case vector::CombiningKind::MAXUI:
-  case vector::CombiningKind::MINNUMF:
-  case vector::CombiningKind::MINIMUMF:
-  case vector::CombiningKind::MINSI:
-  case vector::CombiningKind::MINUI:
-    return true;
-  default:
-    return false;
-  }
-}
-
 /// Generate a vector implementation for either:
 /// ```
 /// Op def: ( w, kw )
@@ -3144,53 +3227,22 @@ struct Conv1DGenerator
     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
   Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
       : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
-    // Determine whether `linalgOp` can be generated with this generator
-    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
-      return;
+
     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
     resShaped = linalgOp.getDpsInitOperand(0)->get();
     lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
     rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
     resShapedType = dyn_cast<ShapedType>(resShaped.getType());
-    if (!lhsShapedType || !rhsShapedType || !resShapedType)
-      return;
-    // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
-    // (non-channeled convolution -> LHS and RHS both have single dimensions).
-    if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
-        (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
-      return;
 
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
-    if (!reduceOp)
-      return;
     redOp = reduceOp->getName().getIdentifier();
 
-    if (!setOperKind(reduceOp))
-      return;
+    setOperKind(reduceOp);
+
     auto maybeKind = getCombinerOpKind(reduceOp);
-    // Typically convolution will have a `Add` CombiningKind but for i1 type it
-    // can get strength reduced to `OR` which is also supported. This strength
-    // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
-    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
-                        *maybeKind != vector::CombiningKind::OR) &&
-                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
-      return;
-    }
     reductionKind = maybeKind.value();
 
-    auto rhsRank = rhsShapedType.getRank();
-    switch (oper) {
-    case Conv:
-      if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
-        return;
-      break;
-    case Pool:
-      if (rhsRank != 1)
-        return;
-      break;
-    }
-
     // The ConvolutionOpInterface gives us guarantees of existence for
     // strides/dilations. However, we do not need to rely on those, we can
     // simply use them if present, otherwise use the default and let the generic
@@ -3199,13 +3251,8 @@ struct Conv1DGenerator
     auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
     strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
     dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-
-    // The op is now known to be valid.
-    valid = true;
   }
 
-  bool isValid() { return valid; }
-
   /// Generate a vector implementation for:
   /// ```
   /// Op def: ( w, kw )
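With the checks hoisted into `vectorizeConvOpPrecondition`, the constructor can no longer fail, so the `valid` flag and the `isValid()` accessor go away. A sketch of the resulting call-site contract, assuming the precondition runs before the generator is constructed (the failure message here is illustrative):

```cpp
// Callers reject unsupported ops up front, then construct unconditionally.
if (failed(vectorizeConvOpPrecondition(linalgOp)))
  return rewriter.notifyMatchFailure(
      linalgOp, "unsupported conv/pool for vectorization");
Conv1DGenerator conv1dGen(rewriter, linalgOp); // known-valid at this point
```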
@@ -3225,9 +3272,6 @@ struct Conv1DGenerator
   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
   /// > 1.
   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
-
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
     bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
@@ -3510,9 +3554,6 @@ struct Conv1DGenerator
   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                        bool channelDimScalableFlag,
                                        bool flatten) {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
-
     bool scalableChDim = false;
     bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
@@ -3857,8 +3898,6 @@ struct Conv1DGenerator
   }
 
 private:
-  enum OperKind { Conv, Pool };
-  bool valid = false;
   OperKind oper = Conv;
   StringAttr redOp;
   StringAttr poolExtOp;
@@ -3869,18 +3908,10 @@ struct Conv1DGenerator
   vector::CombiningKind reductionKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
-  // Returns true iff it is a valid conv/pooling op.
-  // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
-  // + yield) and rhs is not used) then it is the body of a pooling
-  // If conv, check for single `mul` predecessor. The `mul` operands must be
-  // block arguments or extension of block arguments.
-  // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
-  // must be block arguments or extension of block arguments.
-  bool setOperKind(Operation *reduceOp) {
+  void setOperKind(Operation *reduceOp) {
     int numBlockArguments =
         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
-    switch (numBlockArguments) {
-    case 1: {
+    if (numBlockArguments == 1) {
       // Will be convolution if feeder is a MulOp.
       // A strength reduced version of MulOp for i1 type is AndOp which is also
       // supported. Otherwise, it can be pooling. This strength reduction logic
@@ -3892,27 +3923,13 @@ struct Conv1DGenerator
         oper = Pool;
         isPoolExt = true;
         poolExtOp = feedOp->getName().getIdentifier();
-      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
-                    (isa<arith::AndIOp>(feedOp) &&
-                     feedOp->getResultTypes()[0].isInteger(1))) &&
-                   llvm::all_of(feedOp->getOperands(), [](Value v) {
-                     if (isa<BlockArgument>(v))
-                       return true;
-                     if (Operation *op = v.getDefiningOp())
-                       return isCastOfBlockArgument(op);
-                     return false;
-                   }))) {
-        return false;
+      } else {
+        oper = Conv;
       }
-      return true;
-    }
-    case 2:
-      // Must be pooling
+    } else {
+      // Pooling.
       oper = Pool;
       isPoolExt = false;
-      return true;
-    default:
-      return false;
     }
   }
 };
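Both matchers special-case i1: the Linalg dialect's `buildBinaryFn` strength-reduces `mul` to `arith.andi` and `add` to `arith.ori` for i1 element types, so an i1 convolution body arrives with an `AndIOp` feeder and an `OR` combiner. A standalone restatement of that acceptance test, mirroring the condition in `getOperKind` above (illustrative helper, not part of the patch):

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Operation.h"

// True iff `feedOp` is the i1 strength-reduced form of a multiply,
// i.e. an arith.andi producing an i1 result.
static bool isI1StrengthReducedMul(mlir::Operation *feedOp) {
  return mlir::isa<mlir::arith::AndIOp>(feedOp) &&
         feedOp->getResultTypes()[0].isInteger(1);
}
```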
@@ -3924,7 +3941,6 @@ static FailureOr<Operation *> vectorizeConvolution(
     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
   Conv1DGenerator conv1dGen(rewriter, op);
-  assert(conv1dGen.isValid() && "Conv1DGenerator failed");
   auto res = conv1dGen.generateNonChanneledConv();
   if (succeeded(res))
     return res;