Skip to content

[mlir][Vector] Move insert/extractelement distribution patterns to insert/extract #116425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 133 additions & 107 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1229,28 +1229,9 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
VectorType extractSrcType = extractOp.getSourceVectorType();
Location loc = extractOp.getLoc();

// "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
assert(extractSrcType.getRank() > 0 &&
"vector.extract does not support rank 0 sources");

// "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
// canonicalized to %v.
if (extractOp.getNumIndices() == 0)
// For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
if (extractSrcType.getRank() <= 1) {
return failure();

// Rewrite vector.extract with 1d source to vector.extractelement.
if (extractSrcType.getRank() == 1) {
if (extractOp.hasDynamicPosition())
// TODO: Dynamic position not supported yet.
return failure();

assert(extractOp.getNumIndices() == 1 && "expected 1 index");
int64_t pos = extractOp.getStaticPosition()[0];
rewriter.setInsertionPoint(extractOp);
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
extractOp, extractOp.getVector(),
rewriter.create<arith::ConstantIndexOp>(loc, pos));
return success();
}

// All following cases are 2d or higher dimensional source vectors.
Expand Down Expand Up @@ -1313,22 +1294,27 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};

/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
PatternBenefit b = 1)
/// Pattern to move out vector.extract with a scalar result.
/// Only supports 1-D and 0-D sources for now.
struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
PatternBenefit b = 1)
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
warpShuffleFromIdxFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
VectorType extractSrcType = extractOp.getSourceVectorType();
// Only supports 1-D or 0-D sources for now.
if (extractSrcType.getRank() > 1) {
return rewriter.notifyMatchFailure(
extractOp, "only 0-D or 1-D source supported for now");
}
// TODO: Supported shuffle types should be parameterizable, similar to
// `WarpShuffleFromIdxFn`.
if (!extractSrcType.getElementType().isF32() &&
Expand All @@ -1340,7 +1326,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
VectorType distributedVecType;
if (!is0dOrVec1Extract) {
assert(extractSrcType.getRank() == 1 &&
"expected that extractelement src rank is 0 or 1");
"expected that extract src rank is 0 or 1");
if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
return failure();
int64_t elementsPerLane =
Expand All @@ -1352,10 +1338,11 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Yield source vector and position (if present) from warp op.
SmallVector<Value> additionalResults{extractOp.getVector()};
SmallVector<Type> additionalResultTypes{distributedVecType};
if (static_cast<bool>(extractOp.getPosition())) {
additionalResults.push_back(extractOp.getPosition());
additionalResultTypes.push_back(extractOp.getPosition().getType());
}
additionalResults.append(
SmallVector<Value>(extractOp.getDynamicPosition()));
additionalResultTypes.append(
SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));

Location loc = extractOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
Expand All @@ -1368,39 +1355,33 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
// All lanes extract the scalar.
if (is0dOrVec1Extract) {
Value newExtract;
if (extractSrcType.getRank() == 1) {
newExtract = rewriter.create<vector::ExtractElementOp>(
loc, distributedVec,
rewriter.create<arith::ConstantIndexOp>(loc, 0));

} else {
newExtract =
rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
}
SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
newExtract =
rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
}

int64_t staticPos = extractOp.getStaticPosition()[0];
OpFoldResult pos = ShapedType::isDynamic(staticPos)
? (newWarpOp->getResult(newRetIndices[1]))
: OpFoldResult(rewriter.getIndexAttr(staticPos));
// 1d extract: Distribute the source vector. One lane extracts and shuffles
// the value to all other lanes.
int64_t elementsPerLane = distributedVecType.getShape()[0];
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
// tid of extracting thread: pos / elementsPerLane
Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
loc, sym0.ceilDiv(elementsPerLane),
newWarpOp->getResult(newRetIndices[1]));
Value broadcastFromTid = affine::makeComposedAffineApply(
rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
// Extract at position: pos % elementsPerLane
Value pos =
Value newPos =
elementsPerLane == 1
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
: rewriter
.create<affine::AffineApplyOp>(
loc, sym0 % elementsPerLane,
newWarpOp->getResult(newRetIndices[1]))
.getResult();
: affine::makeComposedAffineApply(rewriter, loc,
sym0 % elementsPerLane, pos);
Value extracted =
rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);

// Shuffle the extracted value to all lanes.
Value shuffled = warpShuffleFromIdxFn(
Expand All @@ -1413,31 +1394,59 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};

struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
/// Rewrite a yielded vector.extractelement into the equivalent
/// vector.extract, so that the remaining distribution patterns only have to
/// deal with vector.extract.
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpExtractElement(MLIRContext *ctx, PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yielded =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
    if (!yielded)
      return failure();
    auto extractOp = yielded->get().getDefiningOp<vector::ExtractElementOp>();
    // A 0-D source carries no position operand; forward it only when present.
    SmallVector<OpFoldResult> indices;
    if (Value pos = extractOp.getPosition())
      indices.push_back(pos);
    rewriter.setInsertionPoint(extractOp);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getVector(), indices);
    return success();
  }
};

/// Pattern to move out vector.insert with a scalar input.
/// Only supports 1-D and 0-D destinations for now.
struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
VectorType vecType = insertOp.getDestVectorType();
VectorType distrType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
bool hasPos = static_cast<bool>(insertOp.getPosition());

// Only supports 1-D or 0-D destinations for now.
if (vecType.getRank() > 1) {
return rewriter.notifyMatchFailure(
insertOp, "only 0-D or 1-D source supported for now");
}

// Yield destination vector, source scalar and position from warp op.
SmallVector<Value> additionalResults{insertOp.getDest(),
insertOp.getSource()};
SmallVector<Type> additionalResultTypes{distrType,
insertOp.getSource().getType()};
if (hasPos) {
additionalResults.push_back(insertOp.getPosition());
additionalResultTypes.push_back(insertOp.getPosition().getType());
}
additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
additionalResultTypes.append(
SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));

Location loc = insertOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
Expand All @@ -1446,13 +1455,26 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
Value newSource = newWarpOp->getResult(newRetIndices[1]);
Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
rewriter.setInsertionPointAfter(newWarpOp);

OpFoldResult pos;
if (vecType.getRank() != 0) {
int64_t staticPos = insertOp.getStaticPosition()[0];
pos = ShapedType::isDynamic(staticPos)
? (newWarpOp->getResult(newRetIndices[2]))
: OpFoldResult(rewriter.getIndexAttr(staticPos));
}

// This condition is always true for 0-d vectors.
if (vecType == distrType) {
// Broadcast: Simply move the vector.insertelement op out.
Value newInsert = rewriter.create<vector::InsertElementOp>(
loc, newSource, distributedVec, newPos);
Value newInsert;
SmallVector<OpFoldResult> indices;
if (pos) {
indices.push_back(pos);
}
newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
distributedVec, indices);
// Broadcast: Simply move the vector.insert op out.
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newInsert);
return success();
Expand All @@ -1462,16 +1484,11 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
int64_t elementsPerLane = distrType.getShape()[0];
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
// tid of extracting thread: pos / elementsPerLane
Value insertingLane = rewriter.create<affine::AffineApplyOp>(
loc, sym0.ceilDiv(elementsPerLane), newPos);
Value insertingLane = affine::makeComposedAffineApply(
rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
// Insert position: pos % elementsPerLane
Value pos =
elementsPerLane == 1
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
: rewriter
.create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
newPos)
.getResult();
OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
rewriter, loc, sym0 % elementsPerLane, pos);
Value isInsertingLane = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
Value newResult =
Expand All @@ -1480,8 +1497,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
loc, isInsertingLane,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
Value newInsert = builder.create<vector::InsertElementOp>(
loc, newSource, distributedVec, pos);
Value newInsert = builder.create<vector::InsertOp>(
loc, newSource, distributedVec, newPos);
builder.create<scf::YieldOp>(loc, newInsert);
},
/*elseBuilder=*/
Expand All @@ -1506,25 +1523,13 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
Location loc = insertOp.getLoc();

// "vector.insert %v, %v[] : ..." can be canonicalized to %v.
if (insertOp.getNumIndices() == 0)
// For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
if (insertOp.getDestVectorType().getRank() <= 1) {
return failure();

// Rewrite vector.insert with 1d dest to vector.insertelement.
if (insertOp.getDestVectorType().getRank() == 1) {
if (insertOp.hasDynamicPosition())
// TODO: Dynamic position not supported yet.
return failure();

assert(insertOp.getNumIndices() == 1 && "expected 1 index");
int64_t pos = insertOp.getStaticPosition()[0];
rewriter.setInsertionPoint(insertOp);
rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
insertOp, insertOp.getSource(), insertOp.getDest(),
rewriter.create<arith::ConstantIndexOp>(loc, pos));
return success();
}

// All following cases are 2d or higher dimensional source vectors.

if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
// There is no distribution, this is a broadcast. Simply move the insert
// out of the warp op.
Expand Down Expand Up @@ -1620,9 +1625,30 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};

/// Rewrite a yielded vector.insertelement into the equivalent vector.insert,
/// so that the remaining distribution patterns only have to deal with
/// vector.insert.
struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yielded =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
    if (!yielded)
      return failure();
    auto insertOp = yielded->get().getDefiningOp<vector::InsertElementOp>();
    // A 0-D destination carries no position operand; forward it only when
    // present.
    SmallVector<OpFoldResult> indices;
    if (Value pos = insertOp.getPosition())
      indices.push_back(pos);
    rewriter.setInsertionPoint(insertOp);
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        insertOp, insertOp.getSource(), insertOp.getDest(), indices);
    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't change
/// the order of execution. This creates a new scf.for region after the
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
Expand Down Expand Up @@ -1668,8 +1694,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!forOp)
return failure();
// Collect Values that come from the warp op but are outside the forOp.
// Those Value needs to be returned by the original warpOp and passed to the
// new op.
// Those Value needs to be returned by the original warpOp and passed to
// the new op.
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> inputTypes;
SmallVector<Type> distTypes;
Expand Down Expand Up @@ -1715,8 +1741,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);

// Create a new for op outside the region with a WarpExecuteOnLane0Op region
// inside.
// Create a new for op outside the region with a WarpExecuteOnLane0Op
// region inside.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newOperands);
Expand Down Expand Up @@ -1778,8 +1804,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector size matching
/// the warpOp size. E.g.:
/// The vector is reduced in parallel. Currently limited to vector size
/// matching the warpOp size. E.g.:
/// ```
/// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
/// %0 = "some_def"() : () -> (vector<32xf32>)
Expand Down Expand Up @@ -1880,13 +1906,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
patterns
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractElement>(patterns.getContext(),
warpShuffleFromIdxFn, benefit);
patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
}
Expand Down
Loading
Loading