@@ -1940,7 +1940,10 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
 }
 
 namespace {
-bool isCastOfBlockArgument(Operation *op) {
+enum class ConvOperationKind { Conv, Pool };
+} // namespace
+
+static bool isCastOfBlockArgument(Operation *op) {
   return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
          isa<BlockArgument>(op->getOperand(0));
 }
@@ -1952,8 +1955,8 @@ bool isCastOfBlockArgument(Operation *op) {
 // block arguments or extension of block arguments.
 // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
 // must be block arguments or extension of block arguments.
-enum OperKind { Conv, Pool };
-bool getOperKind(Operation *reduceOp, OperKind &oper) {
+static std::optional<ConvOperationKind>
+getConvOperationKind(Operation *reduceOp) {
   int numBlockArguments =
       llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
 
@@ -1966,33 +1969,34 @@ bool getOperKind(Operation *reduceOp, OperKind &oper) {
     auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                        llvm::IsaPred<BlockArgument>);
     Operation *feedOp = (*feedValIt).getDefiningOp();
-    // llvm::outs() << "feedOp: " << *feedOp << "\n";
     if (isCastOfBlockArgument(feedOp)) {
-      oper = Pool;
-    } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
-                  (isa<arith::AndIOp>(feedOp) &&
-                   feedOp->getResultTypes()[0].isInteger(1))) &&
-                 llvm::all_of(feedOp->getOperands(), [](Value v) {
-                   if (isa<BlockArgument>(v))
-                     return true;
-                   if (Operation *op = v.getDefiningOp())
-                     return isCastOfBlockArgument(op);
-                   return false;
-                 }))) {
-      return false;
+      return ConvOperationKind::Pool;
     }
-    return true;
+
+    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+           (isa<arith::AndIOp>(feedOp) &&
+            feedOp->getResultTypes()[0].isInteger(1))) &&
+          llvm::all_of(feedOp->getOperands(), [](Value v) {
+            if (isa<BlockArgument>(v))
+              return true;
+            if (Operation *op = v.getDefiningOp())
+              return isCastOfBlockArgument(op);
+            return false;
+          }))) {
+      return std::nullopt;
+    }
+
+    return ConvOperationKind::Conv;
   }
   case 2:
     // Must be pooling
-    oper = Pool;
-    return true;
+    return ConvOperationKind::Pool;
   default:
-    return false;
+    return std::nullopt;
   }
 }
 
-bool isSupportedPoolKind(vector::CombiningKind kind) {
+static bool isSupportedPoolKind(vector::CombiningKind kind) {
   switch (kind) {
   case vector::CombiningKind::ADD:
   case vector::CombiningKind::MAXNUMF:
@@ -2008,7 +2012,6 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
     return false;
   }
 }
-} // namespace
 
 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
   if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
@@ -2032,29 +2035,28 @@ static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
   if (!reduceOp)
     return failure();
 
-  OperKind oper = Conv;
-  if (!getOperKind(reduceOp, oper))
+  auto maybeOper = getConvOperationKind(reduceOp);
+  if (!maybeOper.has_value())
     return failure();
+
   auto maybeKind = getCombinerOpKind(reduceOp);
   // Typically convolution will have a `Add` CombiningKind but for i1 type it
   // can get strength reduced to `OR` which is also supported. This strength
   // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
   if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                       *maybeKind != vector::CombiningKind::OR) &&
-                     (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
+                     (*maybeOper != ConvOperationKind::Pool ||
+                      !isSupportedPoolKind(*maybeKind)))) {
     return failure();
   }
 
   auto rhsRank = rhsShapedType.getRank();
-  switch (oper) {
-  case Conv:
-    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
-      return failure();
-    break;
-  case Pool:
+  if (*maybeOper == ConvOperationKind::Pool) {
     if (rhsRank != 1)
       return failure();
-    break;
+  } else {
+    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
+      return failure();
   }
 
   return success();
@@ -3238,7 +3240,7 @@ struct Conv1DGenerator
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
     redOp = reduceOp->getName().getIdentifier();
 
-    setOperKind(reduceOp);
+    setConvOperationKind(reduceOp);
 
     auto maybeKind = getCombinerOpKind(reduceOp);
     reductionKind = maybeKind.value();
@@ -3293,11 +3295,11 @@ struct Conv1DGenerator
       // out{n, w, f}
       bindShapeDims(resShapedType, nSize, wSize, fSize);
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         // kernel{kw, c, f}
         bindShapeDims(rhsShapedType, kwSize, cSize);
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         // kernel{kw}
         bindShapeDims(rhsShapedType, kwSize);
         cSize = fSize;
@@ -3311,10 +3313,10 @@ struct Conv1DGenerator
                       1,
                   cSize};
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         rhsShape = {kwSize, cSize, fSize};
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         rhsShape = {kwSize};
         break;
       }
@@ -3324,11 +3326,11 @@ struct Conv1DGenerator
       // out{n, f, w}
       bindShapeDims(resShapedType, nSize, fSize, wSize);
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         // kernel{f, c, kw}
         bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         // kernel{kw}
         bindShapeDims(rhsShapedType, kwSize);
         cSize = fSize;
@@ -3341,10 +3343,10 @@ struct Conv1DGenerator
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         rhsShape = {fSize, cSize, kwSize};
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         rhsShape = {kwSize};
         break;
       }
@@ -3376,7 +3378,7 @@ struct Conv1DGenerator
                                                         lhsPadding);
     // This is needed only for Conv.
     Value rhs = nullptr;
-    if (oper == Conv)
+    if (oper == ConvOperationKind::Conv)
       rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                     rhsPadding);
     Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
@@ -3399,7 +3401,7 @@ struct Conv1DGenerator
       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
 
       // This is needed only for Conv.
-      if (oper == Conv)
+      if (oper == ConvOperationKind::Conv)
         rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
       // nfw -> nwf
       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
@@ -3417,7 +3419,7 @@ struct Conv1DGenerator
                                      kwSize, strideW, dilationW, wSizeStep,
                                      isSingleChanneled);
     // Do not do for pooling.
-    if (oper == Conv)
+    if (oper == ConvOperationKind::Conv)
       rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
     resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                       wSizeStep, isSingleChanneled);
@@ -3432,7 +3434,7 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         switch (oper) {
-        case Conv:
+        case ConvOperationKind::Conv:
           if (isSingleChanneled) {
             resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                    lhsVals[linearIndex(kw, w)],
@@ -3443,7 +3445,7 @@ struct Conv1DGenerator
                                                   rhsVals[kw], resVals[w]);
           }
           break;
-        case Pool:
+        case ConvOperationKind::Pool:
           resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                    resVals[w]);
           break;
@@ -3898,7 +3900,7 @@ struct Conv1DGenerator
   }
 
 private:
-  OperKind oper = Conv;
+  ConvOperationKind oper = ConvOperationKind::Conv;
   StringAttr redOp;
   StringAttr poolExtOp;
   bool isPoolExt = false;
@@ -3908,7 +3910,7 @@ struct Conv1DGenerator
   vector::CombiningKind reductionKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
-  void setOperKind(Operation *reduceOp) {
+  void setConvOperationKind(Operation *reduceOp) {
     int numBlockArguments =
         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
     if (numBlockArguments == 1) {
@@ -3920,17 +3922,17 @@ struct Conv1DGenerator
                                          llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
       if (isCastOfBlockArgument(feedOp)) {
-        oper = Pool;
+        oper = ConvOperationKind::Pool;
         isPoolExt = true;
         poolExtOp = feedOp->getName().getIdentifier();
-      } else {
-        oper = Conv;
+        return;
       }
-    } else {
-      // Pooling.
-      oper = Pool;
-      isPoolExt = false;
+      oper = ConvOperationKind::Conv;
+      return;
     }
+    // numBlockArugments == 2 and this is a pooling op.
+    oper = ConvOperationKind::Pool;
+    isPoolExt = false;
   }
 };
 } // namespace
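
For readers skimming the diff: the core API change replaces `bool getOperKind(Operation *, OperKind &)` with `std::optional<ConvOperationKind> getConvOperationKind(Operation *)`, so callers test the returned optional instead of threading an out-parameter. Below is a minimal standalone sketch of that calling pattern; `FakeReduceOp` and `classify` are hypothetical stand-ins and only `ConvOperationKind` and the case structure come from the patch.

```cpp
#include <cstdio>
#include <optional>

enum class ConvOperationKind { Conv, Pool };

// Hypothetical stand-in for the matched reduction; the real code inspects the
// operands of the reduction op inside the linalg body.
struct FakeReduceOp {
  int numBlockArguments;
  bool feederIsCastOfBlockArgument;
};

// Mirrors the shape of the new getConvOperationKind: return a kind, or
// std::nullopt when the body is not a recognized conv/pool pattern.
static std::optional<ConvOperationKind> classify(const FakeReduceOp &op) {
  switch (op.numBlockArguments) {
  case 1:
    return op.feederIsCastOfBlockArgument ? ConvOperationKind::Pool
                                          : ConvOperationKind::Conv;
  case 2:
    // Must be pooling.
    return ConvOperationKind::Pool;
  default:
    return std::nullopt;
  }
}

int main() {
  FakeReduceOp op{/*numBlockArguments=*/1,
                  /*feederIsCastOfBlockArgument=*/false};
  if (std::optional<ConvOperationKind> kind = classify(op))
    std::printf("kind = %s\n",
                *kind == ConvOperationKind::Conv ? "Conv" : "Pool");
  else
    std::printf("not a recognized conv/pool reduction\n");
  return 0;
}
```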