@@ -551,9 +551,10 @@ enum class Conv1DOpOrder {
   Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
 };
 
-/// Helper data structure to represent the result of vectorization.
-/// In certain specific cases, like terminators, we do not want to propagate/
-enum VectorizationStatus {
+/// Helper data structure to represent the result of vectorization for a single
+/// operation. In certain specific cases, like terminators, we do not want to
+/// propagate.
+enum VectorizationHookStatus {
   /// Op failed to vectorize.
   Failure = 0,
   /// Op vectorized and custom function took care of replacement logic
@@ -564,9 +565,12 @@ enum VectorizationStatus {
   // TODO: support values if Op vectorized to Many-Ops whose results we need to
   // aggregate for replacement.
 };
-struct VectorizationResult {
+/// VectorizationHookResult contains the vectorized op returned from a
+/// CustomVectorizationHook. This is an internal implementation detail of
+/// linalg vectorization, not to be confused with VectorizationResult.
+struct VectorizationHookResult {
   /// Return status from vectorizing the current op.
-  enum VectorizationStatus status = VectorizationStatus::Failure;
+  enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
   /// New vectorized operation to replace the current op.
   /// Replacement behavior is specified by `status`.
   Operation *newOp;
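The `VectorizationResult` named in the new comment is the struct that `mlir::linalg::vectorize` now returns (see the signature change near the end of this patch); its definition lives in the Linalg transforms header and is not part of this diff. A minimal sketch of the assumed shape of that public struct follows, with the `replacements` field name inferred only from the `VectorizationResult{results}` construction at the bottom of the patch:

```cpp
// Hypothetical sketch, not shown in this diff: the public result type of
// mlir::linalg::vectorize(). The field name `replacements` is an assumption.
struct VectorizationResult {
  // Values meant to replace the results of the original operation; after this
  // change the caller of vectorize() performs the replacement itself.
  SmallVector<Value> replacements;
};
```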
@@ -728,22 +732,22 @@ using CustomVectorizationPrecondition =
 // assuming all its vectorized operands are already in the IRMapping.
 // Return nullptr if the Operation cannot be vectorized.
 using CustomVectorizationHook =
-    std::function<VectorizationResult(Operation *, const IRMapping &)>;
+    std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
 
 /// Helper function to vectorize the terminator of a `linalgOp`. New result
 /// vector values are appended to `newResults`. Return
-/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
-/// should not try to map produced operations and instead return the results
-/// using the `newResults` vector making them available to the vectorization
-/// algorithm for RAUW. This function is meant to be used as a
+/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
+/// that it should not try to map produced operations and instead return the
+/// results using the `newResults` vector making them available to the
+/// vectorization algorithm for RAUW. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult
+static VectorizationHookResult
 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                      const IRMapping &bvm, VectorizationState &state,
                      LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
   if (!yieldOp)
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
     // TODO: Scan for an opportunity for reuse.
     // TODO: use a map.
@@ -755,20 +759,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
       newResults.push_back(newResult);
   }
 
-  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
+  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
 }
 
 /// Helper function to vectorize the index operations of a `linalgOp`. Return
-/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
-                                                VectorizationState &state,
-                                                Operation *op,
-                                                LinalgOp linalgOp) {
+static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
+                                                    VectorizationState &state,
+                                                    Operation *op,
+                                                    LinalgOp linalgOp) {
   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
   if (!indexOp)
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
   auto loc = indexOp.getLoc();
   // Compute the static loop sizes of the index op.
   ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -782,7 +786,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
   // dimension of the iteration space since the vectorization algorithm in this
   // case can handle the broadcast.
   if (dim == targetShape.size() - 1)
-    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
+    return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
   // Otherwise permute the targetShape to move the index dimension last,
   // broadcast the one-dimensional index vector to the permuted shape, and
   // finally transpose the broadcasted index vector to undo the permutation.
@@ -800,7 +804,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
   std::swap(transposition.back(), transposition[dim]);
   auto transposeOp =
       rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
-  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
+  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
 }
 
 /// Helper function to check if the tensor.extract can be vectorized by the
@@ -1098,15 +1102,15 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
 }
 
 /// Helper function to vectorize the tensor.extract operations. Returns
-/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult
+static VectorizationHookResult
 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
   // Compute the static loop sizes of the extract op.
@@ -1138,7 +1142,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
     LDBG("Vectorised as gather load: " << extractOp << "\n");
-    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+    return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
   }
 
   // 2. Handle:
@@ -1202,7 +1206,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
         mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
 
     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
-    return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
+    return VectorizationHookResult{VectorizationHookStatus::NewOp,
+                                   maskedReadOp};
   }
 
   // 2b. Handle contiguous access.
@@ -1228,7 +1233,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       inBounds);
 
   LDBG("Vectorised as contiguous load: " << extractOp);
-  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+  return VectorizationHookResult{VectorizationHookStatus::NewOp,
+                                 transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different
@@ -1268,9 +1274,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
 /// This function assumes all operands of `op` have been vectorized and are in
 /// the `bvm` mapping. As a consequence, this function is meant to be called on
 /// a topologically-sorted list of ops.
-/// This function does not update `bvm` but returns a VectorizationStatus that
-/// instructs the caller what `bvm` update needs to occur.
-static VectorizationResult
+/// This function does not update `bvm` but returns a VectorizationHookStatus
+/// that instructs the caller what `bvm` update needs to occur.
+static VectorizationHookResult
 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1279,8 +1285,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   // 1. Try to apply any CustomVectorizationHook.
   if (!customVectorizationHooks.empty()) {
     for (auto &customFunc : customVectorizationHooks) {
-      VectorizationResult result = customFunc(op, bvm);
-      if (result.status == VectorizationStatus::Failure)
+      VectorizationHookResult result = customFunc(op, bvm);
+      if (result.status == VectorizationHookStatus::Failure)
        continue;
      return result;
    }
@@ -1289,11 +1295,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   // 2. Constant ops don't get vectorized but rather broadcasted at their users.
   // Clone so that the constant is not confined to the linalgOp block.
   if (isa<arith::ConstantOp, func::ConstantOp>(op))
-    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
+    return VectorizationHookResult{VectorizationHookStatus::NewOp,
+                                   rewriter.clone(*op)};
 
   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
   if (!OpTrait::hasElementwiseMappableTraits(op))
-    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
 
   // 4. Check if the operation is a reduction.
   SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1316,7 +1323,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
         reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                        reductionOperands[0].second, bvm);
     if (reduceOp)
-      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
+      return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
   }
 
   // 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1356,8 +1363,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
                               : resultType);
   }
   // d. Build and return the new op.
-  return VectorizationResult{
-      VectorizationStatus::NewOp,
+  return VectorizationHookResult{
+      VectorizationHookStatus::NewOp,
       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                       resultTypes, op->getAttrs())};
 }
@@ -1461,34 +1468,34 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   SmallVector<CustomVectorizationHook> hooks;
   // 4a. Register CustomVectorizationHook for yieldOp.
   CustomVectorizationHook vectorizeYield =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
     return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
   };
   hooks.push_back(vectorizeYield);
 
   // 4b. Register CustomVectorizationHook for indexOp.
   CustomVectorizationHook vectorizeIndex =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);
 
   // 4c. Register CustomVectorizationHook for extractOp.
   CustomVectorizationHook vectorizeExtract =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);
 
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block->getOperations()) {
-    VectorizationResult result =
+    VectorizationHookResult result =
         vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
-    if (result.status == VectorizationStatus::Failure) {
+    if (result.status == VectorizationHookStatus::Failure) {
      LDBG("failed to vectorize: " << op << "\n");
      return failure();
    }
-    if (result.status == VectorizationStatus::NewOp) {
+    if (result.status == VectorizationHookStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG("New vector op: " << *maybeMaskedOp << "\n");
@@ -2525,17 +2532,11 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation.
-/// `inputVectorSizes` must match the rank of the iteration space of the
-/// operation and the input vector sizes must be greater than or equal to
-/// their counterpart iteration space sizes, if static. `inputVectorShapes`
-/// also allows the vectorization of operations with dynamic shapes.
-LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                                      ArrayRef<int64_t> inputVectorSizes,
-                                      ArrayRef<bool> inputScalableVecDims,
-                                      bool vectorizeNDExtract,
-                                      bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult>
+mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        ArrayRef<bool> inputScalableVecDims,
+                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2617,12 +2618,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
     return failure();
   }
 
-  if (!results.empty())
-    rewriter.replaceOp(op, results);
-  else
-    rewriter.eraseOp(op);
-
-  return success();
+  return VectorizationResult{results};
 }
 
 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
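With the replacement logic removed from `vectorize` itself, callers now consume the returned `FailureOr<VectorizationResult>` and mirror the deleted `replaceOp`/`eraseOp` branch on their side. A hedged caller-side sketch, where the `replacements` member name is the same assumption noted earlier:

```cpp
// Sketch of adapting a caller to the new API; it mirrors the logic that the
// hunk above deletes from the body of vectorize() itself.
FailureOr<VectorizationResult> vecResults =
    linalg::vectorize(rewriter, op, inputVectorSizes, inputScalableVecDims,
                      vectorizeNDExtract, flatten1DDepthwiseConv);
if (failed(vecResults))
  return failure();
// Replace the original op with the vectorized values, or erase it if the
// vectorized form produced no results (e.g. memref-based ops).
if (!vecResults->replacements.empty())
  rewriter.replaceOp(op, vecResults->replacements);
else
  rewriter.eraseOp(op);
```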