Skip to content

Commit 4d21da0

Browse files
authored
[mlir] Return vectorized values instead of replacing (#144158)
Updates the linalg::vectorize function to return a `FailureOr<VectorizationResult>` containing the values to replace the original operation, instead of directly replacing the original operation. This aligns better with the style of transforms used with the TilingInterface, and gives more control to users over the lowering, since it allows for additional transformation of the IR before replacement. There was already a `VectorizationResult` defined, which was used for the internal vectorize implementation using `CustomVectorizationHook`s, so the old struct is renamed to `VectorizationHookResult`. Note for integration: The replacement of the original operation is now the responsibility of the caller, so wherever `linalg::vectorize` is used, the caller must also do `rewriter.replaceOp(vectorizeResults->replacements)`. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent 280f60e commit 4d21da0

File tree

3 files changed

+86
-78
lines changed

3 files changed

+86
-78
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -854,17 +854,23 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
854854
/// to work (these are checked by the vectorizer itself).
855855
bool hasVectorizationImpl(Operation *);
856856

857-
/// Emit a suitable vector form for an operation. If provided,
858-
/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
859-
/// must match the rank of the iteration space of the operation and the sizes
860-
/// must be smaller or equal than their counterpart interation space sizes, if
861-
/// static. `inputVectorShapes` also allows the vectorization of operations with
862-
/// dynamic shapes.
863-
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
864-
ArrayRef<int64_t> inputVectorSizes = {},
865-
ArrayRef<bool> inputScalableVecDims = {},
866-
bool vectorizeNDExtract = false,
867-
bool flatten1DDepthwiseConv = false);
857+
/// Transformation information returned after vectorizing.
858+
struct VectorizationResult {
859+
/// Results of the vectorization transform to replace the original operation.
860+
SmallVector<Value> replacements;
861+
};
862+
/// Returns a `VectorizationResult` containing the results of the vectorized op,
863+
/// or failure if the transformation fails. If provided, `inputVectorSizes` are
864+
/// used to vectorize this operation. `inputVectorSizes` must match the rank of
865+
/// the iteration space of the operation and the input vector sizes must be
866+
/// greater than or equal to their counterpart iteration space sizes, if static.
867+
/// `inputVectorShapes` also allows the vectorization of operations with dynamic
868+
/// shapes.
869+
FailureOr<VectorizationResult>
870+
vectorize(RewriterBase &rewriter, Operation *op,
871+
ArrayRef<int64_t> inputVectorSizes = {},
872+
ArrayRef<bool> inputScalableVecDims = {},
873+
bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
868874

869875
/// Emit a suitable vector form for a Copy op with fully static shape.
870876
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3823,9 +3823,14 @@ struct VectorizationPattern : public RewritePattern {
38233823
if (!linalg::hasVectorizationImpl(op))
38243824
return rewriter.notifyMatchFailure(op,
38253825
"Unsupported Op, cannot vectorize");
3826-
return vectorize(rewriter, op, /*inputVectorSizes=*/{},
3827-
/*inputScalableVecDims=*/{}, vectorizeNDExtract,
3828-
flatten1DDepthwiseConv);
3826+
FailureOr<VectorizationResult> vectorResults =
3827+
vectorize(rewriter, op, /*inputVectorSizes=*/{},
3828+
/*inputScalableVecDims=*/{}, vectorizeNDExtract,
3829+
flatten1DDepthwiseConv);
3830+
if (failed(vectorResults))
3831+
return failure();
3832+
rewriter.replaceOp(op, vectorResults->replacements);
3833+
return success();
38293834
}
38303835

38313836
private:
@@ -3914,13 +3919,14 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
39143919
return mlir::emitSilenceableFailure(target->getLoc())
39153920
<< "Unsupported Op, cannot vectorize";
39163921
}
3917-
3918-
if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3919-
getScalableSizes(),
3920-
getVectorizeNdExtract().value_or(false)))) {
3922+
FailureOr<VectorizationResult> vectorResults =
3923+
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
3924+
getVectorizeNdExtract().value_or(false));
3925+
if (failed(vectorResults)) {
39213926
return mlir::emitSilenceableFailure(target->getLoc())
39223927
<< "Attempted to vectorize, but failed";
39233928
}
3929+
rewriter.replaceOp(target, vectorResults->replacements);
39243930
}
39253931

39263932
return DiagnosedSilenceableFailure::success();

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 56 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -551,9 +551,10 @@ enum class Conv1DOpOrder {
551551
Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
552552
};
553553

554-
/// Helper data structure to represent the result of vectorization.
555-
/// In certain specific cases, like terminators, we do not want to propagate/
556-
enum VectorizationStatus {
554+
/// Helper data structure to represent the result of vectorization for a single
555+
/// operation. In certain specific cases, like terminators, we do not want to
556+
/// propagate.
557+
enum VectorizationHookStatus {
557558
/// Op failed to vectorize.
558559
Failure = 0,
559560
/// Op vectorized and custom function took care of replacement logic
@@ -564,9 +565,12 @@ enum VectorizationStatus {
564565
// TODO: support values if Op vectorized to Many-Ops whose results we need to
565566
// aggregate for replacement.
566567
};
567-
struct VectorizationResult {
568+
/// VectorizationHookResult contains the vectorized op returned from a
569+
/// CustomVectorizationHook. This is an internal implementation detail of
570+
/// linalg vectorization, not to be confused with VectorizationResult.
571+
struct VectorizationHookResult {
568572
/// Return status from vectorizing the current op.
569-
enum VectorizationStatus status = VectorizationStatus::Failure;
573+
enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
570574
/// New vectorized operation to replace the current op.
571575
/// Replacement behavior is specified by `status`.
572576
Operation *newOp;
@@ -728,22 +732,22 @@ using CustomVectorizationPrecondition =
728732
// assuming all its vectorized operands are already in the IRMapping.
729733
// Return nullptr if the Operation cannot be vectorized.
730734
using CustomVectorizationHook =
731-
std::function<VectorizationResult(Operation *, const IRMapping &)>;
735+
std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
732736

733737
/// Helper function to vectorize the terminator of a `linalgOp`. New result
734738
/// vector values are appended to `newResults`. Return
735-
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
736-
/// should not try to map produced operations and instead return the results
737-
/// using the `newResults` vector making them available to the vectorization
738-
/// algorithm for RAUW. This function is meant to be used as a
739+
/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
740+
/// that it should not try to map produced operations and instead return the
741+
/// results using the `newResults` vector making them available to the
742+
/// vectorization algorithm for RAUW. This function is meant to be used as a
739743
/// CustomVectorizationHook.
740-
static VectorizationResult
744+
static VectorizationHookResult
741745
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
742746
const IRMapping &bvm, VectorizationState &state,
743747
LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
744748
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
745749
if (!yieldOp)
746-
return VectorizationResult{VectorizationStatus::Failure, nullptr};
750+
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
747751
for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
748752
// TODO: Scan for an opportunity for reuse.
749753
// TODO: use a map.
@@ -755,20 +759,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
755759
newResults.push_back(newResult);
756760
}
757761

758-
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
762+
return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
759763
}
760764

761765
/// Helper function to vectorize the index operations of a `linalgOp`. Return
762-
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
766+
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
763767
/// should map the produced operations. This function is meant to be used as a
764768
/// CustomVectorizationHook.
765-
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
766-
VectorizationState &state,
767-
Operation *op,
768-
LinalgOp linalgOp) {
769+
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
770+
VectorizationState &state,
771+
Operation *op,
772+
LinalgOp linalgOp) {
769773
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
770774
if (!indexOp)
771-
return VectorizationResult{VectorizationStatus::Failure, nullptr};
775+
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
772776
auto loc = indexOp.getLoc();
773777
// Compute the static loop sizes of the index op.
774778
ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -782,7 +786,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
782786
// dimension of the iteration space since the vectorization algorithm in this
783787
// case can handle the broadcast.
784788
if (dim == targetShape.size() - 1)
785-
return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
789+
return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
786790
// Otherwise permute the targetShape to move the index dimension last,
787791
// broadcast the one-dimensional index vector to the permuted shape, and
788792
// finally transpose the broadcasted index vector to undo the permutation.
@@ -800,7 +804,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
800804
std::swap(transposition.back(), transposition[dim]);
801805
auto transposeOp =
802806
rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
803-
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
807+
return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
804808
}
805809

806810
/// Helper function to check if the tensor.extract can be vectorized by the
@@ -1098,15 +1102,15 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
10981102
}
10991103

11001104
/// Helper function to vectorize the tensor.extract operations. Returns
1101-
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1105+
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
11021106
/// should map the produced operations. This function is meant to be used as a
11031107
/// CustomVectorizationHook.
1104-
static VectorizationResult
1108+
static VectorizationHookResult
11051109
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11061110
Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
11071111
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
11081112
if (!extractOp)
1109-
return VectorizationResult{VectorizationStatus::Failure, nullptr};
1113+
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
11101114
auto loc = extractOp.getLoc();
11111115

11121116
// Compute the static loop sizes of the extract op.
@@ -1138,7 +1142,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11381142
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
11391143

11401144
LDBG("Vectorised as gather load: " << extractOp << "\n");
1141-
return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1145+
return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
11421146
}
11431147

11441148
// 2. Handle:
@@ -1202,7 +1206,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12021206
mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
12031207

12041208
LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1205-
return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
1209+
return VectorizationHookResult{VectorizationHookStatus::NewOp,
1210+
maskedReadOp};
12061211
}
12071212

12081213
// 2b. Handle contiguous access.
@@ -1228,7 +1233,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12281233
inBounds);
12291234

12301235
LDBG("Vectorised as contiguous load: " << extractOp);
1231-
return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1236+
return VectorizationHookResult{VectorizationHookStatus::NewOp,
1237+
transferReadOp};
12321238
}
12331239

12341240
/// Emit reduction operations if the shapes of the value to reduce is different
@@ -1268,9 +1274,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
12681274
/// This function assumes all operands of `op` have been vectorized and are in
12691275
/// the `bvm` mapping. As a consequence, this function is meant to be called on
12701276
/// a topologically-sorted list of ops.
1271-
/// This function does not update `bvm` but returns a VectorizationStatus that
1272-
/// instructs the caller what `bvm` update needs to occur.
1273-
static VectorizationResult
1277+
/// This function does not update `bvm` but returns a VectorizationHookStatus
1278+
/// that instructs the caller what `bvm` update needs to occur.
1279+
static VectorizationHookResult
12741280
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12751281
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
12761282
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1279,8 +1285,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12791285
// 1. Try to apply any CustomVectorizationHook.
12801286
if (!customVectorizationHooks.empty()) {
12811287
for (auto &customFunc : customVectorizationHooks) {
1282-
VectorizationResult result = customFunc(op, bvm);
1283-
if (result.status == VectorizationStatus::Failure)
1288+
VectorizationHookResult result = customFunc(op, bvm);
1289+
if (result.status == VectorizationHookStatus::Failure)
12841290
continue;
12851291
return result;
12861292
}
@@ -1289,11 +1295,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12891295
// 2. Constant ops don't get vectorized but rather broadcasted at their users.
12901296
// Clone so that the constant is not confined to the linalgOp block .
12911297
if (isa<arith::ConstantOp, func::ConstantOp>(op))
1292-
return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1298+
return VectorizationHookResult{VectorizationHookStatus::NewOp,
1299+
rewriter.clone(*op)};
12931300

12941301
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
12951302
if (!OpTrait::hasElementwiseMappableTraits(op))
1296-
return VectorizationResult{VectorizationStatus::Failure, nullptr};
1303+
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
12971304

12981305
// 4 . Check if the operation is a reduction.
12991306
SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1316,7 +1323,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13161323
reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
13171324
reductionOperands[0].second, bvm);
13181325
if (reduceOp)
1319-
return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1326+
return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
13201327
}
13211328

13221329
// 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1356,8 +1363,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13561363
: resultType);
13571364
}
13581365
// d. Build and return the new op.
1359-
return VectorizationResult{
1360-
VectorizationStatus::NewOp,
1366+
return VectorizationHookResult{
1367+
VectorizationHookStatus::NewOp,
13611368
rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
13621369
resultTypes, op->getAttrs())};
13631370
}
@@ -1461,34 +1468,34 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14611468
SmallVector<CustomVectorizationHook> hooks;
14621469
// 4a. Register CustomVectorizationHook for yieldOp.
14631470
CustomVectorizationHook vectorizeYield =
1464-
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1471+
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14651472
return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
14661473
};
14671474
hooks.push_back(vectorizeYield);
14681475

14691476
// 4b. Register CustomVectorizationHook for indexOp.
14701477
CustomVectorizationHook vectorizeIndex =
1471-
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1478+
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14721479
return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
14731480
};
14741481
hooks.push_back(vectorizeIndex);
14751482

14761483
// 4c. Register CustomVectorizationHook for extractOp.
14771484
CustomVectorizationHook vectorizeExtract =
1478-
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1485+
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14791486
return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
14801487
};
14811488
hooks.push_back(vectorizeExtract);
14821489

14831490
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
14841491
for (Operation &op : block->getOperations()) {
1485-
VectorizationResult result =
1492+
VectorizationHookResult result =
14861493
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1487-
if (result.status == VectorizationStatus::Failure) {
1494+
if (result.status == VectorizationHookStatus::Failure) {
14881495
LDBG("failed to vectorize: " << op << "\n");
14891496
return failure();
14901497
}
1491-
if (result.status == VectorizationStatus::NewOp) {
1498+
if (result.status == VectorizationHookStatus::NewOp) {
14921499
Operation *maybeMaskedOp =
14931500
state.maskOperation(rewriter, result.newOp, linalgOp);
14941501
LDBG("New vector op: " << *maybeMaskedOp << "\n");
@@ -2525,17 +2532,11 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
25252532
tensor::InsertSliceOp>(op);
25262533
}
25272534

2528-
/// Emit a suitable vector form for an operation. If provided,
2529-
/// `inputVectorSizes` are used to vectorize this operation.
2530-
/// `inputVectorSizes` must match the rank of the iteration space of the
2531-
/// operation and the input vector sizes must be greater than or equal to
2532-
/// their counterpart iteration space sizes, if static. `inputVectorShapes`
2533-
/// also allows the vectorization of operations with dynamic shapes.
2534-
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2535-
ArrayRef<int64_t> inputVectorSizes,
2536-
ArrayRef<bool> inputScalableVecDims,
2537-
bool vectorizeNDExtract,
2538-
bool flatten1DDepthwiseConv) {
2535+
FailureOr<VectorizationResult>
2536+
mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2537+
ArrayRef<int64_t> inputVectorSizes,
2538+
ArrayRef<bool> inputScalableVecDims,
2539+
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
25392540
LDBG("Attempting to vectorize:\n" << *op << "\n");
25402541
LDBG("Input vector sizes: ");
25412542
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2617,12 +2618,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
26172618
return failure();
26182619
}
26192620

2620-
if (!results.empty())
2621-
rewriter.replaceOp(op, results);
2622-
else
2623-
rewriter.eraseOp(op);
2624-
2625-
return success();
2621+
return VectorizationResult{results};
26262622
}
26272623

26282624
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,

0 commit comments

Comments
 (0)