Skip to content

[mlir][vector] Support warp distribution of transfer_read with dependencies #77779

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 56 additions & 64 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -819,9 +819,15 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
return isa<vector::TransferReadOp>(op) && op->hasOneUse();
});
if (!operand)
return failure();
return rewriter.notifyMatchFailure(
warpOp, "warp result is not a vector.transfer_read op");
auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

// Source must be defined outside of the region.
if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
return rewriter.notifyMatchFailure(
read, "source must be defined outside of the region");

unsigned operandIndex = operand->getOperandNumber();
Value distributedVal = warpOp.getResult(operandIndex);

Expand All @@ -832,10 +838,25 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
AffineMap map = calculateImplicitMap(sequentialType, distributedType);
AffineMap indexMap = map.compose(read.getPermutationMap());

// Distribute the mask if present.
// Try to delinearize the lane ID to match the rank expected for
// distribution.
SmallVector<Value> delinearizedIds;
if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looked at the implementation of delinearizeLaneId (not for this commit).
It looks like it is duplicating functionality from Dialect/Utils/IndexingUtils.
Refactoring to reuse/extend IndexingUtils would also be most welcome.

distributedType.getShape(), warpOp.getWarpSize(),
warpOp.getLaneid(), delinearizedIds)) {
return rewriter.notifyMatchFailure(
read, "cannot delinearize lane ID for distribution");
}
assert(!delinearizedIds.empty() || map.getNumResults() == 0);

// Distribute indices and the mask (if present).
OpBuilder::InsertionGuard g(rewriter);
WarpExecuteOnLane0Op newWarpOp = warpOp;
Value newMask = read.getMask();
SmallVector<Value> additionalResults(indices.begin(), indices.end());
SmallVector<Type> additionalResultTypes(indices.size(),
rewriter.getIndexType());
additionalResults.push_back(read.getPadding());
additionalResultTypes.push_back(read.getPadding().getType());

bool hasMask = false;
if (read.getMask()) {
hasMask = true;
Expand All @@ -846,42 +867,26 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
// by shape information on the warp op, and thus requires materializing
// the permutation in IR.
if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
return failure();
return rewriter.notifyMatchFailure(
read, "non-trivial permutation maps not supported");
VectorType maskType =
getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
SmallVector<size_t> newRetIndices;
newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, ValueRange{read.getMask()}, TypeRange{maskType},
newRetIndices);
newMask = newWarpOp.getResult(newRetIndices[0]);
distributedVal = newWarpOp.getResult(operandIndex);
} else {
// This pattern does not actually change the warp op directly. Instead it
// just rewrites a new transfer read (when not masked) outside of the warp
// op and replaces the correponding result. There are then follow up
// patterns to erase now dead results of the warp op. This erasure allows
// propagation to continue, but this pattern on its own never actually
// tells the pattern rewriter that the warp op "changed." Notify the
// rewriter here that the warp op is changing. Similar situations are
// noted in following patterns.
rewriter.startRootUpdate(warpOp);
additionalResults.push_back(read.getMask());
additionalResultTypes.push_back(maskType);
}

rewriter.setInsertionPointAfter(newWarpOp);
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, additionalResults, additionalResultTypes,
newRetIndices);
distributedVal = newWarpOp.getResult(operandIndex);

// Try to delinearize the lane ID to match the rank expected for
// distribution.
SmallVector<Value> delinearizedIds;
if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
distributedType.getShape(), newWarpOp.getWarpSize(),
newWarpOp.getLaneid(), delinearizedIds)) {
if (!hasMask)
rewriter.cancelRootUpdate(warpOp);
return rewriter.notifyMatchFailure(
read, "cannot delinearize lane ID for distribution");
}
assert(!delinearizedIds.empty() || map.getNumResults() == 0);
// Distributed indices were appended first.
SmallVector<Value> newIndices;
for (int64_t i = 0, e = indices.size(); i < e; ++i)
newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

rewriter.setInsertionPointAfter(newWarpOp);
for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
AffineExpr d0, d1;
bindDims(read.getContext(), d0, d1);
Expand All @@ -891,42 +896,23 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
int64_t scale = distributedType.getDimSize(vectorPos);
indices[indexPos] = affine::makeComposedAffineApply(
newIndices[indexPos] = affine::makeComposedAffineApply(
rewriter, read.getLoc(), d0 + scale * d1,
{indices[indexPos], delinearizedIds[vectorPos]});
{newIndices[indexPos], delinearizedIds[vectorPos]});
}

// Distributed padding value was appended right after the indices.
Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
// Distributed mask value was added at the end (if the op has a mask).
Value newMask =
hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
: Value();
auto newRead = rewriter.create<vector::TransferReadOp>(
read.getLoc(), distributedVal.getType(), read.getSource(), indices,
read.getPermutationMapAttr(), read.getPadding(), newMask,
read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
read.getPermutationMapAttr(), newPadding, newMask,
read.getInBoundsAttr());

// Check that the produced operation is legal.
// The transfer op may be reading from values that are defined within
// warpOp's body, which is illegal.
// We do the check late because incdices may be changed by
// makeComposeAffineApply. This rewrite may remove dependencies from
// warpOp's body.
// E.g., warpop {
// %idx = affine.apply...[%outsideDef]
// ... = transfer_read ...[%idx]
// }
// will be rewritten in:
// warpop {
// }
// %new_idx = affine.apply...[%outsideDef]
// ... = transfer_read ...[%new_idx]
if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
return (newRead.getMask() && value == newRead.getMask()) ||
newWarpOp.isDefinedOutsideOfRegion(value);
})) {
if (!hasMask)
rewriter.cancelRootUpdate(warpOp);
return failure();
}

rewriter.replaceAllUsesWith(distributedVal, newRead);
if (!hasMask)
rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
Expand Down Expand Up @@ -1315,6 +1301,12 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
VectorType extractSrcType = extractOp.getSourceVectorType();
// TODO: Supported shuffle types should be parameterizable, similar to
// `WarpShuffleFromIdxFn`.
if (!extractSrcType.getElementType().isF32() &&
!extractSrcType.getElementType().isInteger(32))
return rewriter.notifyMatchFailure(
extractOp, "only f32/i32 element types are supported");
bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
Type elType = extractSrcType.getElementType();
VectorType distributedVecType;
Expand Down
38 changes: 29 additions & 9 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,25 @@ func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {

// -----

// Index-typed values cannot be shuffled at the moment.

// CHECK-PROP-LABEL: func.func @vector_extractelement_1d_index(
// CHECK-PROP: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (index) {
// CHECK-PROP: "some_def"
// CHECK-PROP: vector.extractelement
// CHECK-PROP: vector.yield {{.*}} : index
// CHECK-PROP: }
func.func @vector_extractelement_1d_index(%laneid: index, %pos: index) -> (index) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (index) {
%0 = "some_def"() : () -> (vector<96xindex>)
%1 = vector.extractelement %0[%pos : index] : vector<96xindex>
vector.yield %1 : index
}
return %r : index
}

// -----

// CHECK-PROP: func @lane_dependent_warp_propagate_read
// CHECK-PROP-SAME: %[[ID:.*]]: index
func.func @lane_dependent_warp_propagate_read(
Expand Down Expand Up @@ -1248,12 +1267,12 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {

// -----

// Check that we don't propagate transfer_reads that have dependencies on
// values inside the warp_execute_on_lane_0.
// In this case, propagating would create transfer_read that depends on the
// extractelment defined in the body.
// Make sure that all operands of the transfer_read op are properly propagated.
// The vector.extractelement op cannot be propagated because index-typed
// shuffles are not supported at the moment.

// CHECK-PROP-LABEL: func @transfer_read_no_prop(
// CHECK-PROP: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 2)>
// CHECK-PROP-LABEL: func @transfer_read_prop_operands(
// CHECK-PROP-SAME: %[[IN2:[^ :]*]]: vector<1x2xindex>,
// CHECK-PROP-SAME: %[[AR1:[^ :]*]]: memref<1x4x2xi32>,
// CHECK-PROP-SAME: %[[AR2:[^ :]*]]: memref<1x4x1024xf32>)
Expand All @@ -1264,10 +1283,11 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32>
// CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex>
// CHECK-PROP: %[[EXTRACTELT:.*]] = vector.extractelement %[[CAST]][{{.*}}: i32] : vector<64xindex>
// CHECK-PROP: %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[EXTRACTELT]], %[[C0]]],
// CHECK-PROP: vector.yield %[[TRANSFERREAD]] : vector<64xf32>
// CHECK-PROP: return %[[W]]
func.func @transfer_read_no_prop(%in2: vector<1x2xindex>, %ar1 : memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> {
// CHECK-PROP: vector.yield %[[EXTRACTELT]] : index
// CHECK-PROP: %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[THREADID]]]
// CHECK-PROP: %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[W]], %[[APPLY]]],
// CHECK-PROP: return %[[TRANSFERREAD]]
func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 : memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> {
%0 = gpu.thread_id x
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
Expand Down