Skip to content

[MLIR] Refactor to create vectorization convOp precondition check #130181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 150 additions & 114 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1939,6 +1939,127 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
return success();
}

namespace {
bool isCastOfBlockArgument(Operation *op) {
return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
isa<BlockArgument>(op->getOperand(0));
}

// Returns true iff it is a valid conv/pooling op.
// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
// + yield) and rhs is not used) then it is the body of a pooling
// If conv, check for single `mul` predecessor. The `mul` operands must be
// block arguments or extension of block arguments.
// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
// must be block arguments or extension of block arguments.
enum OperKind { Conv, Pool };
bool getOperKind(Operation *reduceOp, OperKind &oper) {
int numBlockArguments =
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

switch (numBlockArguments) {
case 1: {
// Will be convolution if feeder is a MulOp.
// A strength reduced version of MulOp for i1 type is AndOp which is also
// supported. Otherwise, it can be pooling. This strength reduction logic
// is in `buildBinaryFn` helper in the Linalg dialect.
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
llvm::IsaPred<BlockArgument>);
Operation *feedOp = (*feedValIt).getDefiningOp();
// llvm::outs() << "feedOp: " << *feedOp << "\n";
if (isCastOfBlockArgument(feedOp)) {
oper = Pool;
} else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
(isa<arith::AndIOp>(feedOp) &&
feedOp->getResultTypes()[0].isInteger(1))) &&
llvm::all_of(feedOp->getOperands(), [](Value v) {
if (isa<BlockArgument>(v))
return true;
if (Operation *op = v.getDefiningOp())
return isCastOfBlockArgument(op);
return false;
}))) {
return false;
}
return true;
}
case 2:
// Must be pooling
oper = Pool;
return true;
default:
return false;
}
}

bool isSupportedPoolKind(vector::CombiningKind kind) {
switch (kind) {
case vector::CombiningKind::ADD:
case vector::CombiningKind::MAXNUMF:
case vector::CombiningKind::MAXIMUMF:
case vector::CombiningKind::MAXSI:
case vector::CombiningKind::MAXUI:
case vector::CombiningKind::MINNUMF:
case vector::CombiningKind::MINIMUMF:
case vector::CombiningKind::MINSI:
case vector::CombiningKind::MINUI:
return true;
default:
return false;
}
}
} // namespace

static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
return failure();

auto lhsShaped = convOp.getDpsInputOperand(0)->get();
auto rhsShaped = convOp.getDpsInputOperand(1)->get();
auto resShaped = convOp.getDpsInitOperand(0)->get();
auto lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
auto rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
auto resShapedType = dyn_cast<ShapedType>(resShaped.getType());
if (!lhsShapedType || !rhsShapedType || !resShapedType)
return failure();
// (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
// (non-channeled convolution -> LHS and RHS both have single dimensions).
if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
(lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
return failure();

Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
if (!reduceOp)
return failure();

OperKind oper = Conv;
if (!getOperKind(reduceOp, oper))
return failure();
auto maybeKind = getCombinerOpKind(reduceOp);
// Typically convolution will have a `Add` CombiningKind but for i1 type it
// can get strength reduced to `OR` which is also supported. This strength
// reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
*maybeKind != vector::CombiningKind::OR) &&
(oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
return failure();
}

auto rhsRank = rhsShapedType.getRank();
switch (oper) {
case Conv:
if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
return failure();
break;
case Pool:
if (rhsRank != 1)
return failure();
break;
}

return success();
}

static LogicalResult vectorizeLinalgOpPrecondition(
LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
Expand Down Expand Up @@ -1991,7 +2112,8 @@ static LogicalResult vectorizeLinalgOpPrecondition(
// features. But we will still need stride/dilation attributes that will be
// annoying to reverse-engineer...
if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
return success();
return vectorizeConvOpPrecondition(linalgOp);

// TODO: the common vector shape is equal to the static loop sizes only when
// all indexing maps are projected permutations. For convs and stencils the
// logic will need to evolve.
Expand Down Expand Up @@ -3067,28 +3189,6 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
}

namespace {
bool isCastOfBlockArgument(Operation *op) {
return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
isa<BlockArgument>(op->getOperand(0));
}

bool isSupportedPoolKind(vector::CombiningKind kind) {
switch (kind) {
case vector::CombiningKind::ADD:
case vector::CombiningKind::MAXNUMF:
case vector::CombiningKind::MAXIMUMF:
case vector::CombiningKind::MAXSI:
case vector::CombiningKind::MAXUI:
case vector::CombiningKind::MINNUMF:
case vector::CombiningKind::MINIMUMF:
case vector::CombiningKind::MINSI:
case vector::CombiningKind::MINUI:
return true;
default:
return false;
}
}

/// Generate a vector implementation for either:
/// ```
/// Op def: ( w, kw )
Expand Down Expand Up @@ -3125,58 +3225,32 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
/// kw is unrolled, w is unrolled iff dilationW > 1.
struct Conv1DGenerator
: public StructuredGenerator<LinalgOp, utils::IteratorType> {
Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
int dilationW)
: StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
strideW(strideW), dilationW(dilationW) {
// Determine whether `linalgOp` can be generated with this generator
if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
return;
Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
: StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {

lhsShaped = linalgOp.getDpsInputOperand(0)->get();
rhsShaped = linalgOp.getDpsInputOperand(1)->get();
resShaped = linalgOp.getDpsInitOperand(0)->get();
lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
resShapedType = dyn_cast<ShapedType>(resShaped.getType());
if (!lhsShapedType || !rhsShapedType || !resShapedType)
return;
// (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
// (non-channeled convolution -> LHS and RHS both have single dimensions).
if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
(lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
return;

Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
if (!reduceOp)
return;
redOp = reduceOp->getName().getIdentifier();

if (!setOperKind(reduceOp))
return;
setOperKind(reduceOp);

auto maybeKind = getCombinerOpKind(reduceOp);
// Typically convolution will have a `Add` CombiningKind but for i1 type it
// can get strength reduced to `OR` which is also supported. This strength
// reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
*maybeKind != vector::CombiningKind::OR) &&
(oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
return;
}
reductionKind = maybeKind.value();

auto rhsRank = rhsShapedType.getRank();
switch (oper) {
case Conv:
if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
return;
break;
case Pool:
if (rhsRank != 1)
return;
break;
}
// The op is now known to be valid.
valid = true;
// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can
// simply use them if present, otherwise use the default and let the generic
// conv. matcher in the ConvGenerator succeed or fail.
auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
}

/// Generate a vector implementation for:
Expand All @@ -3198,9 +3272,6 @@ struct Conv1DGenerator
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
if (!valid)
return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");

int64_t nSize, wSize, cSize, kwSize, fSize;
SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
Expand Down Expand Up @@ -3483,9 +3554,6 @@ struct Conv1DGenerator
FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
bool channelDimScalableFlag,
bool flatten) {
if (!valid)
return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");

bool scalableChDim = false;
bool useMasking = false;
int64_t nSize, wSize, cSize, kwSize;
Expand Down Expand Up @@ -3830,8 +3898,6 @@ struct Conv1DGenerator
}

private:
enum OperKind { Conv, Pool };
bool valid = false;
OperKind oper = Conv;
StringAttr redOp;
StringAttr poolExtOp;
Expand All @@ -3842,18 +3908,10 @@ struct Conv1DGenerator
vector::CombiningKind reductionKind;

// Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
// Returns true iff it is a valid conv/pooling op.
// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
// + yield) and rhs is not used) then it is the body of a pooling
// If conv, check for single `mul` predecessor. The `mul` operands must be
// block arguments or extension of block arguments.
// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
// must be block arguments or extension of block arguments.
bool setOperKind(Operation *reduceOp) {
void setOperKind(Operation *reduceOp) {
int numBlockArguments =
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
switch (numBlockArguments) {
case 1: {
if (numBlockArguments == 1) {
// Will be convolution if feeder is a MulOp.
// A strength reduced version of MulOp for i1 type is AndOp which is also
// supported. Otherwise, it can be pooling. This strength reduction logic
Expand All @@ -3865,27 +3923,13 @@ struct Conv1DGenerator
oper = Pool;
isPoolExt = true;
poolExtOp = feedOp->getName().getIdentifier();
} else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
(isa<arith::AndIOp>(feedOp) &&
feedOp->getResultTypes()[0].isInteger(1))) &&
llvm::all_of(feedOp->getOperands(), [](Value v) {
if (isa<BlockArgument>(v))
return true;
if (Operation *op = v.getDefiningOp())
return isCastOfBlockArgument(op);
return false;
}))) {
return false;
} else {
oper = Conv;
}
return true;
}
case 2:
// Must be pooling
} else {
// Pooling.
oper = Pool;
isPoolExt = false;
return true;
default:
return false;
}
}
};
Expand All @@ -3896,28 +3940,20 @@ struct Conv1DGenerator
static FailureOr<Operation *> vectorizeConvolution(
RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can
// simply use them if present, otherwise use the default and let the generic
// conv. matcher in the ConvGenerator succeed or fail.
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
Conv1DGenerator e(rewriter, op, stride, dilation);
auto res = e.generateNonChanneledConv();
Conv1DGenerator conv1dGen(rewriter, op);
auto res = conv1dGen.generateNonChanneledConv();
if (succeeded(res))
return res;
res = e.generateNwcConv();
res = conv1dGen.generateNwcConv();
if (succeeded(res))
return res;
res = e.generateNcwConv();
res = conv1dGen.generateNcwConv();
if (succeeded(res))
return res;
res = e.generateNwcPooling();
res = conv1dGen.generateNwcPooling();
if (succeeded(res))
return res;
res = e.generateNcwPooling();
res = conv1dGen.generateNcwPooling();
if (succeeded(res))
return res;

Expand All @@ -3940,8 +3976,8 @@ static FailureOr<Operation *> vectorizeConvolution(
vecChDimSize = inputVecSizes[chDimIdx];
vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
}
return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
flatten1DDepthwiseConv);
return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
flatten1DDepthwiseConv);
}

struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
Expand Down
Loading