
Commit 35c19fd

[mlir][vector] Support warp distribution of transfer_read with dependencies (#77779)
Support distribution of `vector.transfer_read` ops when operands are defined inside the region of `warp_execute_on_lane_0` (except for the buffer from which the op is reading). Such IR was previously not supported. The indices and the padding value are now also distributed.

This also simplifies the implementation considerably: the original implementation created a new `transfer_read` op and then checked whether the new op was valid; if not, the rewrite pattern failed. That was a bit hacky, and it also violated the rewrite pattern API (detected by `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`), because the IR was modified even though the pattern returned "failure".
1 parent 4d46721 commit 35c19fd
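
As a rough sketch of the new behavior (hypothetical IR for illustration only; the "some_def" op, the shapes, and the value names are made up and do not come from this commit's tests), an index computed inside the region can now be yielded out of the warp op, and the distributed transfer_read is rebuilt after it with the per-lane offset folded in:

    // Before distribution: the read depends on %idx, which is defined inside the region.
    %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
      %idx = "some_def"() : () -> (index)
      %v = vector.transfer_read %buf[%idx], %pad : memref<64xf32>, vector<64xf32>
      vector.yield %v : vector<64xf32>
    }

    // After distribution: the index is yielded as an extra warp-op result and the
    // read (now of the distributed vector<2xf32>) is created after the warp op.
    %idx_out = vector.warp_execute_on_lane_0(%laneid)[32] -> (index) {
      %idx = "some_def"() : () -> (index)
      vector.yield %idx : index
    }
    %off = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 2)>()[%idx_out, %laneid]
    %r = vector.transfer_read %buf[%off], %pad : memref<64xf32>, vector<2xf32>

The padding value (and the mask, if present) is forwarded out of the region in the same way when it is defined inside.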

File tree (2 files changed, +85 -73 lines):

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 56 additions & 64 deletions
@@ -819,9 +819,15 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return isa<vector::TransferReadOp>(op) && op->hasOneUse();
     });
     if (!operand)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a vector.transfer_read op");
     auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
 
+    // Source must be defined outside of the region.
+    if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
+      return rewriter.notifyMatchFailure(
+          read, "source must be defined outside of the region");
+
     unsigned operandIndex = operand->getOperandNumber();
     Value distributedVal = warpOp.getResult(operandIndex);
 
@@ -832,10 +838,25 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     AffineMap map = calculateImplicitMap(sequentialType, distributedType);
     AffineMap indexMap = map.compose(read.getPermutationMap());
 
-    // Distribute the mask if present.
+    // Try to delinearize the lane ID to match the rank expected for
+    // distribution.
+    SmallVector<Value> delinearizedIds;
+    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
+                           distributedType.getShape(), warpOp.getWarpSize(),
+                           warpOp.getLaneid(), delinearizedIds)) {
+      return rewriter.notifyMatchFailure(
+          read, "cannot delinearize lane ID for distribution");
+    }
+    assert(!delinearizedIds.empty() || map.getNumResults() == 0);
+
+    // Distribute indices and the mask (if present).
     OpBuilder::InsertionGuard g(rewriter);
-    WarpExecuteOnLane0Op newWarpOp = warpOp;
-    Value newMask = read.getMask();
+    SmallVector<Value> additionalResults(indices.begin(), indices.end());
+    SmallVector<Type> additionalResultTypes(indices.size(),
+                                            rewriter.getIndexType());
+    additionalResults.push_back(read.getPadding());
+    additionalResultTypes.push_back(read.getPadding().getType());
+
     bool hasMask = false;
     if (read.getMask()) {
       hasMask = true;
@@ -846,42 +867,26 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
       // by shape information on the warp op, and thus requires materializing
       // the permutation in IR.
       if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
-        return failure();
+        return rewriter.notifyMatchFailure(
+            read, "non-trivial permutation maps not supported");
       VectorType maskType =
           getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
-      SmallVector<size_t> newRetIndices;
-      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-          rewriter, warpOp, ValueRange{read.getMask()}, TypeRange{maskType},
-          newRetIndices);
-      newMask = newWarpOp.getResult(newRetIndices[0]);
-      distributedVal = newWarpOp.getResult(operandIndex);
-    } else {
-      // This pattern does not actually change the warp op directly. Instead it
-      // just rewrites a new transfer read (when not masked) outside of the warp
-      // op and replaces the correponding result. There are then follow up
-      // patterns to erase now dead results of the warp op. This erasure allows
-      // propagation to continue, but this pattern on its own never actually
-      // tells the pattern rewriter that the warp op "changed." Notify the
-      // rewriter here that the warp op is changing. Similar situations are
-      // noted in following patterns.
-      rewriter.startRootUpdate(warpOp);
+      additionalResults.push_back(read.getMask());
+      additionalResultTypes.push_back(maskType);
     }
 
-    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, additionalResults, additionalResultTypes,
+        newRetIndices);
+    distributedVal = newWarpOp.getResult(operandIndex);
 
-    // Try to delinearize the lane ID to match the rank expected for
-    // distribution.
-    SmallVector<Value> delinearizedIds;
-    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
-                           distributedType.getShape(), newWarpOp.getWarpSize(),
-                           newWarpOp.getLaneid(), delinearizedIds)) {
-      if (!hasMask)
-        rewriter.cancelRootUpdate(warpOp);
-      return rewriter.notifyMatchFailure(
-          read, "cannot delinearize lane ID for distribution");
-    }
-    assert(!delinearizedIds.empty() || map.getNumResults() == 0);
+    // Distributed indices were appended first.
+    SmallVector<Value> newIndices;
+    for (int64_t i = 0, e = indices.size(); i < e; ++i)
+      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
 
+    rewriter.setInsertionPointAfter(newWarpOp);
     for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
       AffineExpr d0, d1;
       bindDims(read.getContext(), d0, d1);
@@ -891,42 +896,23 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
       unsigned indexPos = indexExpr.getPosition();
       unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
       int64_t scale = distributedType.getDimSize(vectorPos);
-      indices[indexPos] = affine::makeComposedAffineApply(
+      newIndices[indexPos] = affine::makeComposedAffineApply(
           rewriter, read.getLoc(), d0 + scale * d1,
-          {indices[indexPos], delinearizedIds[vectorPos]});
+          {newIndices[indexPos], delinearizedIds[vectorPos]});
     }
+
+    // Distributed padding value was appended right after the indices.
+    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
+    // Distributed mask value was added at the end (if the op has a mask).
+    Value newMask =
+        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
+                : Value();
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
-        read.getPermutationMapAttr(), read.getPadding(), newMask,
+        read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
+        read.getPermutationMapAttr(), newPadding, newMask,
         read.getInBoundsAttr());
 
-    // Check that the produced operation is legal.
-    // The transfer op may be reading from values that are defined within
-    // warpOp's body, which is illegal.
-    // We do the check late because incdices may be changed by
-    // makeComposeAffineApply. This rewrite may remove dependencies from
-    // warpOp's body.
-    // E.g., warpop {
-    //   %idx = affine.apply...[%outsideDef]
-    //   ... = transfer_read ...[%idx]
-    // }
-    // will be rewritten in:
-    // warpop {
-    // }
-    // %new_idx = affine.apply...[%outsideDef]
-    // ... = transfer_read ...[%new_idx]
-    if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
-          return (newRead.getMask() && value == newRead.getMask()) ||
-                 newWarpOp.isDefinedOutsideOfRegion(value);
-        })) {
-      if (!hasMask)
-        rewriter.cancelRootUpdate(warpOp);
-      return failure();
-    }
-
     rewriter.replaceAllUsesWith(distributedVal, newRead);
-    if (!hasMask)
-      rewriter.finalizeRootUpdate(warpOp);
     return success();
   }
 };
@@ -1315,6 +1301,12 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
     VectorType extractSrcType = extractOp.getSourceVectorType();
+    // TODO: Supported shuffle types should be parameterizable, similar to
+    // `WarpShuffleFromIdxFn`.
+    if (!extractSrcType.getElementType().isF32() &&
+        !extractSrcType.getElementType().isInteger(32))
+      return rewriter.notifyMatchFailure(
+          extractOp, "only f32/i32 element types are supported");
     bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
     Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 29 additions & 9 deletions
@@ -899,6 +899,25 @@ func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {
 
 // -----
 
+// Index-typed values cannot be shuffled at the moment.
+
+// CHECK-PROP-LABEL: func.func @vector_extractelement_1d_index(
+// CHECK-PROP:         vector.warp_execute_on_lane_0(%{{.*}})[32] -> (index) {
+// CHECK-PROP:           "some_def"
+// CHECK-PROP:           vector.extractelement
+// CHECK-PROP:           vector.yield {{.*}} : index
+// CHECK-PROP:         }
+func.func @vector_extractelement_1d_index(%laneid: index, %pos: index) -> (index) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (index) {
+    %0 = "some_def"() : () -> (vector<96xindex>)
+    %1 = vector.extractelement %0[%pos : index] : vector<96xindex>
+    vector.yield %1 : index
+  }
+  return %r : index
+}
+
+// -----
+
 // CHECK-PROP: func @lane_dependent_warp_propagate_read
 // CHECK-PROP-SAME: %[[ID:.*]]: index
 func.func @lane_dependent_warp_propagate_read(
@@ -1248,12 +1267,12 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
 
 // -----
 
-// Check that we don't propagate transfer_reads that have dependencies on
-// values inside the warp_execute_on_lane_0.
-// In this case, propagating would create transfer_read that depends on the
-// extractelment defined in the body.
+// Make sure that all operands of the transfer_read op are properly propagated.
+// The vector.extractelement op cannot be propagated because index-typed
+// shuffles are not supported at the moment.
 
-// CHECK-PROP-LABEL: func @transfer_read_no_prop(
+// CHECK-PROP: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-PROP-LABEL: func @transfer_read_prop_operands(
 // CHECK-PROP-SAME: %[[IN2:[^ :]*]]: vector<1x2xindex>,
 // CHECK-PROP-SAME: %[[AR1:[^ :]*]]: memref<1x4x2xi32>,
 // CHECK-PROP-SAME: %[[AR2:[^ :]*]]: memref<1x4x1024xf32>)
@@ -1264,10 +1283,11 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
 // CHECK-PROP:   %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32>
 // CHECK-PROP:   %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex>
 // CHECK-PROP:   %[[EXTRACTELT:.*]] = vector.extractelement %[[CAST]][{{.*}}: i32] : vector<64xindex>
-// CHECK-PROP:   %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[EXTRACTELT]], %[[C0]]],
-// CHECK-PROP:   vector.yield %[[TRANSFERREAD]] : vector<64xf32>
-// CHECK-PROP:   return %[[W]]
-func.func @transfer_read_no_prop(%in2: vector<1x2xindex>, %ar1 : memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> {
+// CHECK-PROP:     vector.yield %[[EXTRACTELT]] : index
+// CHECK-PROP:   %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[THREADID]]]
+// CHECK-PROP:   %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[W]], %[[APPLY]]],
+// CHECK-PROP:   return %[[TRANSFERREAD]]
+func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 : memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> {
   %0 = gpu.thread_id x
   %c0_i32 = arith.constant 0 : i32
   %c0 = arith.constant 0 : index

0 commit comments