
[MLIR][XeGPU] Xegpu distribution patterns for load_nd, store_nd, and create_nd_tdesc. #112945


Closed · wants to merge 4 commits
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -20,6 +20,8 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <utility>

namespace mlir {

// Forward declarations.
@@ -324,6 +326,24 @@ namespace matcher {
bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);

} // namespace matcher

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
const std::function<bool(Operation *)> &fn);

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes);

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` returns the index of each new output.
vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes,
llvm::SmallVector<size_t> &indices);

Member commented on lines +329 to +346:

These utilities should be in a separate header file, VectorDistributeUtils.h. Also, should they be in the mlir namespace? These are utilities related to a specific transformation, and I'm not sure exposing them to the entire namespace is a good idea.

Contributor (Author) replied:

done in #114208

} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace xegpu {

/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
void populateXeGPUDistributePatterns(RewritePatternSet &patterns);

} // namespace xegpu
} // namespace mlir
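
For readers wiring these patterns into a pass, the sketch below shows one way populateXeGPUDistributePatterns could be consumed with the greedy rewrite driver. The pass name and anchoring are assumptions for illustration only and are not part of this patch.

```cpp
// Hypothetical test pass that applies the XeGPU distribution patterns.
// Only populateXeGPUDistributePatterns comes from this patch; the rest is
// standard MLIR pass boilerplate used for illustration.
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
struct TestXeGPUDistributePass
    : public PassWrapper<TestXeGPUDistributePass, OperationPass<>> {
  StringRef getArgument() const final { return "test-xegpu-distribute"; }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUDistributePatterns(patterns);
    // Apply the patterns greedily to the operation this pass is anchored on.
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```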
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
// If the types matches there is no distribution.
if (expanded == distributed)
return success();
auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
Contributor commented:

If you want to generalize distribution to work on more than vector types, this whole thing needs to be moved out of the vector dialect and made an interface. Checking for just ShapedType here seems like a violation.

auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
if (!expandedVecType || !distributedVecType)
return op->emitOpError("expected vector type for distributed operands.");
return op->emitOpError("expected shaped type for distributed operands.");
if (expandedVecType.getRank() != distributedVecType.getRank() ||
expandedVecType.getElementType() != distributedVecType.getElementType())
return op->emitOpError(
"expected distributed vectors to have same rank and element type.");
"expected distributed types to have same rank and element type.");
Member commented on lines -6561 to +6568:

Can you split out these vector op changes into a separate patch? They are unrelated to XeGPU and are separate from this patch. We also need to update the documentation for these ops if we are doing this.

Contributor (Author) replied:

sure, #114215


SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
continue;
if (eDim % dDim != 0)
return op->emitOpError()
<< "expected expanded vector dimension #" << i << " (" << eDim
<< ") to be a multipler of the distributed vector dimension ("
<< "expected expanded type dimension #" << i << " (" << eDim
<< ") to be a multipler of the distributed type dimension ("
<< dDim << ")";
scales[i] = eDim / dDim;
}
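
To make the relaxed check concrete: after this change, any pair of shaped types with matching rank and element type verifies, as long as each expanded dimension is an exact multiple of the corresponding distributed dimension. The snippet below is a standalone illustration of that dimension/scale computation, with hypothetical shapes; it mirrors the loop above but is plain C++, not MLIR.

```cpp
// Standalone illustration of the per-dimension check above, using
// hypothetical shapes: an expanded 32x64 value distributed into 32x2
// per-lane pieces. Each expanded dim must divide evenly by the distributed
// dim; the quotient is the distribution factor for that dim.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> expandedShape = {32, 64};
  std::vector<int64_t> distributedShape = {32, 2};
  std::vector<int64_t> scales(expandedShape.size(), 1);
  for (size_t i = 0; i < expandedShape.size(); ++i) {
    assert(expandedShape[i] % distributedShape[i] == 0 &&
           "expanded dim must be a multiple of the distributed dim");
    scales[i] = expandedShape[i] / distributedShape[i];
  }
  // scales == {1, 32}: dimension 1 is split across 32 lanes.
  return 0;
}
```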
80 changes: 1 addition & 79 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -160,92 +161,13 @@ struct DistributedLoadStoreHelper {

} // namespace

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes) {
// Create a new op before the existing one, with the extra operands.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(warpOp);
auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

Region &opBody = warpOp.getBodyRegion();
Region &newOpBody = newWarpOp.getBodyRegion();
Block &newOpFirstBlock = newOpBody.front();
rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
rewriter.eraseBlock(&newOpFirstBlock);
assert(newWarpOp.getWarpRegion().hasOneBlock() &&
"expected WarpOp with single block");

auto yield =
cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

rewriter.modifyOpInPlace(
yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` return the index of each new output.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes,
llvm::SmallVector<size_t> &indices) {
SmallVector<Type> types(warpOp.getResultTypes().begin(),
warpOp.getResultTypes().end());
auto yield = cast<vector::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
yield.getOperands().end());
for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
if (yieldValues.insert(std::get<0>(newRet))) {
types.push_back(std::get<1>(newRet));
indices.push_back(yieldValues.size() - 1);
} else {
// If the value already exit the region don't create a new output.
for (auto [idx, yieldOperand] :
llvm::enumerate(yieldValues.getArrayRef())) {
if (yieldOperand == std::get<0>(newRet)) {
indices.push_back(idx);
break;
}
}
}
}
yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, yieldValues.getArrayRef(), types);
rewriter.replaceOp(warpOp,
newWarpOp.getResults().take_front(warpOp.getNumResults()));
return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
function_ref<bool(Value)> definedOutside) {
return llvm::all_of(op->getOperands(), definedOutside) &&
isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which statifies the filter lamdba
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
const std::function<bool(Operation *)> &fn) {
auto yield = cast<vector::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
for (OpOperand &yieldOperand : yield->getOpOperands()) {
Value yieldValues = yieldOperand.get();
Operation *definedOp = yieldValues.getDefiningOp();
if (definedOp && fn(definedOp)) {
if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
return &yieldOperand;
}
}
return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVectorUtils
VectorUtils.cpp
VectorDistributeUtils.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Utils
91 changes: 91 additions & 0 deletions mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
@@ -0,0 +1,91 @@
//===- VectorDistributeUtils.cpp - MLIR Utilities VectorOps distribution -===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utility methods used by Vector dialect distribution.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Utils/VectorUtils.h"

using namespace mlir;

mlir::OpOperand *
mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
const std::function<bool(Operation *)> &fn) {
auto yield = cast<vector::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
for (mlir::OpOperand &yieldOperand : yield->getOpOperands()) {
Value yieldValues = yieldOperand.get();
Operation *definedOp = yieldValues.getDefiningOp();
if (definedOp && fn(definedOp)) {
if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
return &yieldOperand;
}
}
return {};
}

vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes) {
// Create a new op before the existing one, with the extra operands.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(warpOp);
auto newWarpOp = rewriter.create<vector::WarpExecuteOnLane0Op>(
warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

Region &opBody = warpOp.getBodyRegion();
Region &newOpBody = newWarpOp.getBodyRegion();
Block &newOpFirstBlock = newOpBody.front();
rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
rewriter.eraseBlock(&newOpFirstBlock);
assert(newWarpOp.getWarpRegion().hasOneBlock() &&
"expected WarpOp with single block");

auto yield =
cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

rewriter.modifyOpInPlace(
yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
return newWarpOp;
}

vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndAppendReturns(
RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes,
llvm::SmallVector<size_t> &indices) {
SmallVector<Type> types(warpOp.getResultTypes().begin(),
warpOp.getResultTypes().end());
auto yield = cast<vector::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
yield.getOperands().end());
for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
if (yieldValues.insert(std::get<0>(newRet))) {
types.push_back(std::get<1>(newRet));
indices.push_back(yieldValues.size() - 1);
} else {
// If the value already exits the region, don't create a new output.
for (auto [idx, yieldOperand] :
llvm::enumerate(yieldValues.getArrayRef())) {
if (yieldOperand == std::get<0>(newRet)) {
indices.push_back(idx);
break;
}
}
}
}
yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
vector::WarpExecuteOnLane0Op newWarpOp =
moveRegionToNewWarpOpAndReplaceReturns(rewriter, warpOp,
yieldValues.getArrayRef(), types);
rewriter.replaceOp(warpOp,
newWarpOp.getResults().take_front(warpOp.getNumResults()));
return newWarpOp;
}
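
For orientation, the sketch below shows how a distribution pattern typically drives these two utilities together: find a live yielded value with getWarpResult, then rebuild the warp op with moveRegionToNewWarpOpAndAppendReturns so additional values become warp results. The filter (vector.transfer_read), the helper name, and the rewrite outline are illustrative assumptions, not code from this patch.

```cpp
// Illustrative-only usage of the utilities above; the filtered op and the
// follow-up rewrite are placeholders.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static LogicalResult sinkOneYieldedValue(RewriterBase &rewriter,
                                         vector::WarpExecuteOnLane0Op warpOp) {
  // 1. Find a live warp result whose yielded value is produced by an op we
  //    know how to distribute (hypothetically, a vector.transfer_read).
  OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
    return isa<vector::TransferReadOp>(op);
  });
  if (!yieldOperand)
    return failure();
  Operation *producer = yieldOperand->get().getDefiningOp();

  // 2. Rebuild the warp op so the producer's first operand is also yielded
  //    as a warp result; `indices` records where each appended value landed
  //    (values already yielded are reused rather than duplicated).
  Value escaping = producer->getOperand(0);
  SmallVector<Value> newYieldedValues = {escaping};
  SmallVector<Type> newReturnTypes = {escaping.getType()};
  SmallVector<size_t> indices;
  vector::WarpExecuteOnLane0Op newWarpOp =
      moveRegionToNewWarpOpAndAppendReturns(rewriter, warpOp, newYieldedValues,
                                            newReturnTypes, indices);

  // 3. A real pattern would now build a distributed clone of the producer
  //    after newWarpOp using newWarpOp.getResult(indices[0]) and rewrite the
  //    old uses; that part is omitted here.
  (void)newWarpOp;
  return success();
}
```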
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUFoldAliasOps.cpp
XeGPUDistribute.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
@@ -12,6 +13,10 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
MLIRIR
MLIRMemRefDialect
MLIRXeGPUDialect
MLIRVectorDialect
MLIRVectorUtils
MLIRArithDialect
MLIRFuncDialect
MLIRPass
MLIRTransforms
)