Commit 2ef5555

Peel away conv1dgen validator into precondition check
1 parent 74a8986 commit 2ef5555
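
In short: the checks that Conv1DGenerator's constructor used to run (recording the result in a `valid` flag consulted by `isValid()`, the `conv`/`depthwiseConv` entry points, and an assert at the call site) now run up front in `vectorizeConvOpPrecondition`, so the generator is only ever built for ops that already passed validation. A minimal self-contained sketch of the before/after shape of this refactoring, using hypothetical stand-in types rather than the real MLIR API:

#include <cstdio>

// Hypothetical stand-ins for mlir::LogicalResult and a linalg conv op;
// illustration only, not the MLIR API.
struct LogicalResult { bool ok; };
static LogicalResult success() { return {true}; }
static LogicalResult failure() { return {false}; }
struct ConvOp { int numInputs, numInits; };

// After the change: the "can we vectorize this?" checks live in a
// precondition that runs before the generator is constructed.
static LogicalResult vectorizeConvOpPrecondition(const ConvOp &op) {
  if (op.numInputs != 2 || op.numInits != 1)
    return failure();
  return success();
}

struct Conv1DGenerator {
  // Before the change the constructor re-ran these checks and set a
  // `valid` flag; now it may assume the precondition already passed.
  explicit Conv1DGenerator(const ConvOp &op) : op(op) {}
  ConvOp op;
};

int main() {
  ConvOp op{2, 1};
  if (vectorizeConvOpPrecondition(op).ok) {
    Conv1DGenerator gen(op); // built only for already-validated ops
    std::printf("vectorizing %d-input op\n", gen.op.numInputs);
  }
  return 0;
}

The payoff is a single failure path: vectorization is rejected through the normal precondition mechanism rather than through a constructor that can half-fail, an `isValid()` flag, and an assert at the call site.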

1 file changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 121 additions & 105 deletions
@@ -1939,19 +1939,124 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+namespace {
+bool isCastOfBlockArgument(Operation *op) {
+  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+         isa<BlockArgument>(op->getOperand(0));
+}
+
+// Returns true iff it is a valid conv/pooling op.
+// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
+// + yield) and rhs is not used) then it is the body of a pooling op.
+// If conv, check for a single `mul` predecessor. The `mul` operands must be
+// block arguments or extensions of block arguments.
+// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
+// must be block arguments or extensions of block arguments.
+enum OperKind { Conv, Pool };
+bool getOperKind(Operation *reduceOp, OperKind &oper) {
+  int numBlockArguments =
+      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
+
+  switch (numBlockArguments) {
+  case 1: {
+    // Will be convolution if feeder is a MulOp.
+    // A strength reduced version of MulOp for i1 type is AndOp, which is also
+    // supported. Otherwise, it can be pooling. This strength reduction logic
+    // is in the `buildBinaryFn` helper in the Linalg dialect.
+    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+                                       llvm::IsaPred<BlockArgument>);
+    Operation *feedOp = (*feedValIt).getDefiningOp();
+    // llvm::outs() << "feedOp: " << *feedOp << "\n";
+    if (isCastOfBlockArgument(feedOp)) {
+      oper = Pool;
+    } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+                  (isa<arith::AndIOp>(feedOp) &&
+                   feedOp->getResultTypes()[0].isInteger(1))) &&
+                 llvm::all_of(feedOp->getOperands(), [](Value v) {
+                   if (isa<BlockArgument>(v))
+                     return true;
+                   if (Operation *op = v.getDefiningOp())
+                     return isCastOfBlockArgument(op);
+                   return false;
+                 }))) {
+      return false;
+    }
+    return true;
+  }
+  case 2:
+    // Must be pooling.
+    oper = Pool;
+    return true;
+  default:
+    return false;
+  }
+}
+
+bool isSupportedPoolKind(vector::CombiningKind kind) {
+  switch (kind) {
+  case vector::CombiningKind::ADD:
+  case vector::CombiningKind::MAXNUMF:
+  case vector::CombiningKind::MAXIMUMF:
+  case vector::CombiningKind::MAXSI:
+  case vector::CombiningKind::MAXUI:
+  case vector::CombiningKind::MINNUMF:
+  case vector::CombiningKind::MINIMUMF:
+  case vector::CombiningKind::MINSI:
+  case vector::CombiningKind::MINUI:
+    return true;
+  default:
+    return false;
+  }
+}
+} // namespace
+
 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
-  // We only support 1D convolutions, reject all other cases.
-  if (isa<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcFhwcOp,
-          linalg::Conv2DNchwFchwOp>(convOp)) {
-    LDBG("2D convolutions are not supported\n");
+  if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
+    return failure();
+
+  auto lhsShaped = convOp.getDpsInputOperand(0)->get();
+  auto rhsShaped = convOp.getDpsInputOperand(1)->get();
+  auto resShaped = convOp.getDpsInitOperand(0)->get();
+  auto lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
+  auto rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
+  auto resShapedType = dyn_cast<ShapedType>(resShaped.getType());
+  if (!lhsShapedType || !rhsShapedType || !resShapedType)
+    return failure();
+  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
+  // (non-channeled convolution -> LHS and RHS both have single dimensions).
+  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
+      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
     return failure();
-  }
 
-  if (isa<linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNcdhwFcdhwOp>(convOp)) {
-    LDBG("3D convolutions are not supported\n");
+  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
+  if (!reduceOp)
+    return failure();
+
+  OperKind oper = Conv;
+  if (!getOperKind(reduceOp, oper))
+    return failure();
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  // Typically convolution will have an `Add` CombiningKind, but for i1 type it
+  // can get strength reduced to `OR`, which is also supported. This strength
+  // reduction logic is in the `buildBinaryFn` helper in the Linalg dialect.
+  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+                      *maybeKind != vector::CombiningKind::OR) &&
+                     (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
     return failure();
   }
 
+  auto rhsRank = rhsShapedType.getRank();
+  switch (oper) {
+  case Conv:
+    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
+      return failure();
+    break;
+  case Pool:
+    if (rhsRank != 1)
+      return failure();
+    break;
+  }
+
   return success();
 }

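As an aside, the classification rule that getOperKind() encodes (and that setOperKind() below simplifies) can be seen in isolation in the following self-contained sketch; the Value/classify names are illustrative stand-ins, not the MLIR API. The combining op feeding linalg.yield sees one block-argument operand for a convolution (the other operand comes from a `mul`, or from a cast when pooling with extension) and two block-argument operands for a plain pooling op:

#include <cstdio>
#include <string>
#include <vector>

// Illustrative stand-in for an SSA value: either a region block argument
// or the result of an op with the given name.
struct Value {
  bool isBlockArg;
  std::string definingOp; // empty when isBlockArg is true
};

enum class Kind { Conv, Pool, Invalid };

// Mirrors getOperKind(): count how many operands of the combining op are
// block arguments of the linalg region, then inspect the feeder.
static Kind classify(const std::vector<Value> &reduceOperands) {
  int numBlockArgs = 0;
  for (const Value &v : reduceOperands)
    numBlockArgs += v.isBlockArg ? 1 : 0;

  if (numBlockArgs == 2)
    return Kind::Pool; // e.g. max(acc, in): both operands are block args

  if (numBlockArgs == 1) {
    for (const Value &v : reduceOperands) {
      if (v.isBlockArg)
        continue;
      if (v.definingOp == "cast") // extension of a block argument
        return Kind::Pool;
      if (v.definingOp == "mul") // mul feeder (and-of-i1 also qualifies)
        return Kind::Conv;
    }
  }
  return Kind::Invalid;
}

int main() {
  // conv body: yield add(acc, mul(in, filter)) -> one block-arg operand
  std::printf("%d\n", static_cast<int>(classify({{true, ""}, {false, "mul"}})));
  // pool body: yield max(acc, in) -> two block-arg operands
  std::printf("%d\n", static_cast<int>(classify({{true, ""}, {true, ""}})));
}
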
@@ -3084,28 +3189,6 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
 }
 
 namespace {
-bool isCastOfBlockArgument(Operation *op) {
-  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
-         isa<BlockArgument>(op->getOperand(0));
-}
-
-bool isSupportedPoolKind(vector::CombiningKind kind) {
-  switch (kind) {
-  case vector::CombiningKind::ADD:
-  case vector::CombiningKind::MAXNUMF:
-  case vector::CombiningKind::MAXIMUMF:
-  case vector::CombiningKind::MAXSI:
-  case vector::CombiningKind::MAXUI:
-  case vector::CombiningKind::MINNUMF:
-  case vector::CombiningKind::MINIMUMF:
-  case vector::CombiningKind::MINSI:
-  case vector::CombiningKind::MINUI:
-    return true;
-  default:
-    return false;
-  }
-}
-
 /// Generate a vector implementation for either:
 /// ```
 /// Op def: ( w, kw )
@@ -3144,53 +3227,22 @@ struct Conv1DGenerator
     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
   Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
       : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
-    // Determine whether `linalgOp` can be generated with this generator
-    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
-      return;
+
     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
     resShaped = linalgOp.getDpsInitOperand(0)->get();
     lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
     rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
     resShapedType = dyn_cast<ShapedType>(resShaped.getType());
-    if (!lhsShapedType || !rhsShapedType || !resShapedType)
-      return;
-    // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
-    // (non-channeled convolution -> LHS and RHS both have single dimensions).
-    if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
-        (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
-      return;
 
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
-    if (!reduceOp)
-      return;
     redOp = reduceOp->getName().getIdentifier();
 
-    if (!setOperKind(reduceOp))
-      return;
+    setOperKind(reduceOp);
+
     auto maybeKind = getCombinerOpKind(reduceOp);
-    // Typically convolution will have a `Add` CombiningKind but for i1 type it
-    // can get strength reduced to `OR` which is also supported. This strength
-    // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
-    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
-                        *maybeKind != vector::CombiningKind::OR) &&
-                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
-      return;
-    }
     reductionKind = maybeKind.value();
 
-    auto rhsRank = rhsShapedType.getRank();
-    switch (oper) {
-    case Conv:
-      if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
-        return;
-      break;
-    case Pool:
-      if (rhsRank != 1)
-        return;
-      break;
-    }
-
     // The ConvolutionOpInterface gives us guarantees of existence for
     // strides/dilations. However, we do not need to rely on those, we can
     // simply use them if present, otherwise use the default and let the generic
@@ -3199,13 +3251,8 @@ struct Conv1DGenerator
     auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
     strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
     dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-
-    // The op is now known to be valid.
-    valid = true;
   }
 
-  bool isValid() { return valid; }
-
   /// Generate a vector implementation for:
   /// ```
   /// Op def: ( w, kw )
@@ -3225,9 +3272,6 @@ struct Conv1DGenerator
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
-
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
     bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
@@ -3510,9 +3554,6 @@ struct Conv1DGenerator
   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                        bool channelDimScalableFlag,
                                        bool flatten) {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
-
     bool scalableChDim = false;
     bool useMasking = false;
     int64_t nSize, wSize, cSize, kwSize;
@@ -3857,8 +3898,6 @@ struct Conv1DGenerator
   }
 
 private:
-  enum OperKind { Conv, Pool };
-  bool valid = false;
   OperKind oper = Conv;
   StringAttr redOp;
   StringAttr poolExtOp;
@@ -3869,18 +3908,10 @@ struct Conv1DGenerator
   vector::CombiningKind reductionKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
-  // Returns true iff it is a valid conv/pooling op.
-  // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
-  // + yield) and rhs is not used) then it is the body of a pooling
-  // If conv, check for single `mul` predecessor. The `mul` operands must be
-  // block arguments or extension of block arguments.
-  // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
-  // must be block arguments or extension of block arguments.
-  bool setOperKind(Operation *reduceOp) {
+  void setOperKind(Operation *reduceOp) {
     int numBlockArguments =
         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
-    switch (numBlockArguments) {
-    case 1: {
+    if (numBlockArguments == 1) {
       // Will be convolution if feeder is a MulOp.
       // A strength reduced version of MulOp for i1 type is AndOp which is also
       // supported. Otherwise, it can be pooling. This strength reduction logic
@@ -3892,27 +3923,13 @@ struct Conv1DGenerator
         oper = Pool;
         isPoolExt = true;
         poolExtOp = feedOp->getName().getIdentifier();
-      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
-                    (isa<arith::AndIOp>(feedOp) &&
-                     feedOp->getResultTypes()[0].isInteger(1))) &&
-                   llvm::all_of(feedOp->getOperands(), [](Value v) {
-                     if (isa<BlockArgument>(v))
-                       return true;
-                     if (Operation *op = v.getDefiningOp())
-                       return isCastOfBlockArgument(op);
-                     return false;
-                   }))) {
-        return false;
+      } else {
+        oper = Conv;
       }
-      return true;
-    }
-    case 2:
-      // Must be pooling
+    } else {
+      // Pooling.
       oper = Pool;
       isPoolExt = false;
-      return true;
-    default:
-      return false;
     }
   }
 };
@@ -3924,7 +3941,6 @@ static FailureOr<Operation *> vectorizeConvolution(
     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
   Conv1DGenerator conv1dGen(rewriter, op);
-  assert(conv1dGen.isValid() && "Conv1DGenerator failed");
   auto res = conv1dGen.generateNonChanneledConv();
   if (succeeded(res))
     return res;
