Skip to content

Commit 13f5183

Browse files
committed
Addressing review feedbacks
1 parent 2ef5555 commit 13f5183

File tree

1 file changed

+57
-55
lines changed

1 file changed

+57
-55
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 57 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1940,7 +1940,10 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
19401940
}
19411941

19421942
namespace {
1943-
bool isCastOfBlockArgument(Operation *op) {
1943+
enum class ConvOperationKind { Conv, Pool };
1944+
} // namespace
1945+
1946+
static bool isCastOfBlockArgument(Operation *op) {
19441947
return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
19451948
isa<BlockArgument>(op->getOperand(0));
19461949
}
@@ -1952,8 +1955,8 @@ bool isCastOfBlockArgument(Operation *op) {
19521955
// block arguments or extension of block arguments.
19531956
// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
19541957
// must be block arguments or extension of block arguments.
1955-
enum OperKind { Conv, Pool };
1956-
bool getOperKind(Operation *reduceOp, OperKind &oper) {
1958+
static std::optional<ConvOperationKind>
1959+
getConvOperationKind(Operation *reduceOp) {
19571960
int numBlockArguments =
19581961
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
19591962

@@ -1966,33 +1969,34 @@ bool getOperKind(Operation *reduceOp, OperKind &oper) {
19661969
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
19671970
llvm::IsaPred<BlockArgument>);
19681971
Operation *feedOp = (*feedValIt).getDefiningOp();
1969-
// llvm::outs() << "feedOp: " << *feedOp << "\n";
19701972
if (isCastOfBlockArgument(feedOp)) {
1971-
oper = Pool;
1972-
} else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
1973-
(isa<arith::AndIOp>(feedOp) &&
1974-
feedOp->getResultTypes()[0].isInteger(1))) &&
1975-
llvm::all_of(feedOp->getOperands(), [](Value v) {
1976-
if (isa<BlockArgument>(v))
1977-
return true;
1978-
if (Operation *op = v.getDefiningOp())
1979-
return isCastOfBlockArgument(op);
1980-
return false;
1981-
}))) {
1982-
return false;
1973+
return ConvOperationKind::Pool;
19831974
}
1984-
return true;
1975+
1976+
if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
1977+
(isa<arith::AndIOp>(feedOp) &&
1978+
feedOp->getResultTypes()[0].isInteger(1))) &&
1979+
llvm::all_of(feedOp->getOperands(), [](Value v) {
1980+
if (isa<BlockArgument>(v))
1981+
return true;
1982+
if (Operation *op = v.getDefiningOp())
1983+
return isCastOfBlockArgument(op);
1984+
return false;
1985+
}))) {
1986+
return std::nullopt;
1987+
}
1988+
1989+
return ConvOperationKind::Conv;
19851990
}
19861991
case 2:
19871992
// Must be pooling
1988-
oper = Pool;
1989-
return true;
1993+
return ConvOperationKind::Pool;
19901994
default:
1991-
return false;
1995+
return std::nullopt;
19921996
}
19931997
}
19941998

1995-
bool isSupportedPoolKind(vector::CombiningKind kind) {
1999+
static bool isSupportedPoolKind(vector::CombiningKind kind) {
19962000
switch (kind) {
19972001
case vector::CombiningKind::ADD:
19982002
case vector::CombiningKind::MAXNUMF:
@@ -2008,7 +2012,6 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
20082012
return false;
20092013
}
20102014
}
2011-
} // namespace
20122015

20132016
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
20142017
if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
@@ -2032,29 +2035,28 @@ static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
20322035
if (!reduceOp)
20332036
return failure();
20342037

2035-
OperKind oper = Conv;
2036-
if (!getOperKind(reduceOp, oper))
2038+
auto maybeOper = getConvOperationKind(reduceOp);
2039+
if (!maybeOper.has_value())
20372040
return failure();
2041+
20382042
auto maybeKind = getCombinerOpKind(reduceOp);
20392043
// Typically convolution will have a `Add` CombiningKind but for i1 type it
20402044
// can get strength reduced to `OR` which is also supported. This strength
20412045
// reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
20422046
if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
20432047
*maybeKind != vector::CombiningKind::OR) &&
2044-
(oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2048+
(*maybeOper != ConvOperationKind::Pool ||
2049+
!isSupportedPoolKind(*maybeKind)))) {
20452050
return failure();
20462051
}
20472052

20482053
auto rhsRank = rhsShapedType.getRank();
2049-
switch (oper) {
2050-
case Conv:
2051-
if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2052-
return failure();
2053-
break;
2054-
case Pool:
2054+
if (*maybeOper == ConvOperationKind::Pool) {
20552055
if (rhsRank != 1)
20562056
return failure();
2057-
break;
2057+
} else {
2058+
if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2059+
return failure();
20582060
}
20592061

20602062
return success();
@@ -3238,7 +3240,7 @@ struct Conv1DGenerator
32383240
Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
32393241
redOp = reduceOp->getName().getIdentifier();
32403242

3241-
setOperKind(reduceOp);
3243+
setConvOperationKind(reduceOp);
32423244

32433245
auto maybeKind = getCombinerOpKind(reduceOp);
32443246
reductionKind = maybeKind.value();
@@ -3293,11 +3295,11 @@ struct Conv1DGenerator
32933295
// out{n, w, f}
32943296
bindShapeDims(resShapedType, nSize, wSize, fSize);
32953297
switch (oper) {
3296-
case Conv:
3298+
case ConvOperationKind::Conv:
32973299
// kernel{kw, c, f}
32983300
bindShapeDims(rhsShapedType, kwSize, cSize);
32993301
break;
3300-
case Pool:
3302+
case ConvOperationKind::Pool:
33013303
// kernel{kw}
33023304
bindShapeDims(rhsShapedType, kwSize);
33033305
cSize = fSize;
@@ -3311,10 +3313,10 @@ struct Conv1DGenerator
33113313
1,
33123314
cSize};
33133315
switch (oper) {
3314-
case Conv:
3316+
case ConvOperationKind::Conv:
33153317
rhsShape = {kwSize, cSize, fSize};
33163318
break;
3317-
case Pool:
3319+
case ConvOperationKind::Pool:
33183320
rhsShape = {kwSize};
33193321
break;
33203322
}
@@ -3324,11 +3326,11 @@ struct Conv1DGenerator
33243326
// out{n, f, w}
33253327
bindShapeDims(resShapedType, nSize, fSize, wSize);
33263328
switch (oper) {
3327-
case Conv:
3329+
case ConvOperationKind::Conv:
33283330
// kernel{f, c, kw}
33293331
bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
33303332
break;
3331-
case Pool:
3333+
case ConvOperationKind::Pool:
33323334
// kernel{kw}
33333335
bindShapeDims(rhsShapedType, kwSize);
33343336
cSize = fSize;
@@ -3341,10 +3343,10 @@ struct Conv1DGenerator
33413343
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
33423344
1};
33433345
switch (oper) {
3344-
case Conv:
3346+
case ConvOperationKind::Conv:
33453347
rhsShape = {fSize, cSize, kwSize};
33463348
break;
3347-
case Pool:
3349+
case ConvOperationKind::Pool:
33483350
rhsShape = {kwSize};
33493351
break;
33503352
}
@@ -3376,7 +3378,7 @@ struct Conv1DGenerator
33763378
lhsPadding);
33773379
// This is needed only for Conv.
33783380
Value rhs = nullptr;
3379-
if (oper == Conv)
3381+
if (oper == ConvOperationKind::Conv)
33803382
rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
33813383
rhsPadding);
33823384
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
@@ -3399,7 +3401,7 @@ struct Conv1DGenerator
33993401
static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
34003402

34013403
// This is needed only for Conv.
3402-
if (oper == Conv)
3404+
if (oper == ConvOperationKind::Conv)
34033405
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
34043406
// nfw -> nwf
34053407
static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
@@ -3417,7 +3419,7 @@ struct Conv1DGenerator
34173419
kwSize, strideW, dilationW, wSizeStep,
34183420
isSingleChanneled);
34193421
// Do not do for pooling.
3420-
if (oper == Conv)
3422+
if (oper == ConvOperationKind::Conv)
34213423
rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
34223424
resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
34233425
wSizeStep, isSingleChanneled);
@@ -3432,7 +3434,7 @@ struct Conv1DGenerator
34323434
for (int64_t kw = 0; kw < kwSize; ++kw) {
34333435
for (int64_t w = 0; w < wSize; w += wSizeStep) {
34343436
switch (oper) {
3435-
case Conv:
3437+
case ConvOperationKind::Conv:
34363438
if (isSingleChanneled) {
34373439
resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
34383440
lhsVals[linearIndex(kw, w)],
@@ -3443,7 +3445,7 @@ struct Conv1DGenerator
34433445
rhsVals[kw], resVals[w]);
34443446
}
34453447
break;
3446-
case Pool:
3448+
case ConvOperationKind::Pool:
34473449
resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
34483450
resVals[w]);
34493451
break;
@@ -3898,7 +3900,7 @@ struct Conv1DGenerator
38983900
}
38993901

39003902
private:
3901-
OperKind oper = Conv;
3903+
ConvOperationKind oper = ConvOperationKind::Conv;
39023904
StringAttr redOp;
39033905
StringAttr poolExtOp;
39043906
bool isPoolExt = false;
@@ -3908,7 +3910,7 @@ struct Conv1DGenerator
39083910
vector::CombiningKind reductionKind;
39093911

39103912
// Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3911-
void setOperKind(Operation *reduceOp) {
3913+
void setConvOperationKind(Operation *reduceOp) {
39123914
int numBlockArguments =
39133915
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
39143916
if (numBlockArguments == 1) {
@@ -3920,17 +3922,17 @@ struct Conv1DGenerator
39203922
llvm::IsaPred<BlockArgument>);
39213923
Operation *feedOp = (*feedValIt).getDefiningOp();
39223924
if (isCastOfBlockArgument(feedOp)) {
3923-
oper = Pool;
3925+
oper = ConvOperationKind::Pool;
39243926
isPoolExt = true;
39253927
poolExtOp = feedOp->getName().getIdentifier();
3926-
} else {
3927-
oper = Conv;
3928+
return;
39283929
}
3929-
} else {
3930-
// Pooling.
3931-
oper = Pool;
3932-
isPoolExt = false;
3930+
oper = ConvOperationKind::Conv;
3931+
return;
39333932
}
3933+
// numBlockArugments == 2 and this is a pooling op.
3934+
oper = ConvOperationKind::Pool;
3935+
isPoolExt = false;
39343936
}
39353937
};
39363938
} // namespace

0 commit comments

Comments
 (0)