-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[MLIR][XeGPU] Xegpu distribution patterns for load_nd, store_nd, and create_nd_tdesc. #112945
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d292605
7627cf7
de3ae89
f907be6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed, | |
// If the types match there is no distribution. | ||
if (expanded == distributed) | ||
return success(); | ||
auto expandedVecType = llvm::dyn_cast<VectorType>(expanded); | ||
auto distributedVecType = llvm::dyn_cast<VectorType>(distributed); | ||
auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you want to generalize distribution to work on more than vector types, this whole thing needs to be moved out of vector dialect and made an interface. Checking for just |
||
auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed); | ||
if (!expandedVecType || !distributedVecType) | ||
return op->emitOpError("expected vector type for distributed operands."); | ||
return op->emitOpError("expected shaped type for distributed operands."); | ||
if (expandedVecType.getRank() != distributedVecType.getRank() || | ||
expandedVecType.getElementType() != distributedVecType.getElementType()) | ||
return op->emitOpError( | ||
"expected distributed vectors to have same rank and element type."); | ||
"expected distributed types to have same rank and element type."); | ||
Comment on lines
-6561
to
+6568
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you split out these vector op changes into a separate patch? These are unrelated to XeGPU and are separate from this patch. We also need to update documentation for these ops if we are doing this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure, #114215 |
||
|
||
SmallVector<int64_t> scales(expandedVecType.getRank(), 1); | ||
for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) { | ||
|
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed, | |
continue; | ||
if (eDim % dDim != 0) | ||
return op->emitOpError() | ||
<< "expected expanded vector dimension #" << i << " (" << eDim | ||
<< ") to be a multipler of the distributed vector dimension (" | ||
<< "expected expanded type dimension #" << i << " (" << eDim | ||
<< ") to be a multipler of the distributed type dimension (" | ||
<< dDim << ")"; | ||
scales[i] = eDim / dDim; | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
//===- VectorDistributeUtils.cpp - MLIR Utilities VectorOps distribution -===// | ||
// | ||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements utility methods for working with the Vector dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" | ||
|
||
using namespace mlir; | ||
|
||
/// Walk the terminator of `warpOp` and return the first yield operand whose
/// defining operation satisfies `fn` and whose matching warp-op result still
/// has uses. Returns nullptr when no such operand exists.
mlir::OpOperand *
mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
                    const std::function<bool(Operation *)> &fn) {
  auto yieldOp = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (mlir::OpOperand &yieldOperand : yieldOp->getOpOperands()) {
    Operation *producer = yieldOperand.get().getDefiningOp();
    // Skip block arguments (no defining op) and ops rejected by the filter.
    if (!producer || !fn(producer))
      continue;
    // Only report the operand when the corresponding result is actually used.
    if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
      return &yieldOperand;
  }
  return {};
}
|
||
/// Clone `warpOp` into a fresh WarpExecuteOnLane0Op whose result types are
/// `newReturnTypes`, move the body region over, and rewrite the terminator to
/// yield `newYieldedValues`. The original op is left in place; callers are
/// responsible for replacing or erasing it.
vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Build the replacement op immediately before the original one so all
  // surrounding SSA values remain visible to the moved body.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<vector::WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  // Transfer the body: inline the old region, then drop the placeholder block
  // that was created along with the new op.
  Region &srcBody = warpOp.getBodyRegion();
  Region &dstBody = newWarpOp.getBodyRegion();
  Block &placeholderBlock = dstBody.front();
  rewriter.inlineRegionBefore(srcBody, dstBody, dstBody.begin());
  rewriter.eraseBlock(&placeholderBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  // Retarget the terminator so it yields the requested values.
  auto yieldOp =
      cast<vector::YieldOp>(dstBody.getBlocks().begin()->getTerminator());
  rewriter.modifyOpInPlace(yieldOp, [&]() {
    yieldOp.getOperandsMutable().assign(newYieldedValues);
  });
  return newWarpOp;
}
|
||
/// Move the region of `warpOp` into a new WarpExecuteOnLane0Op that, in
/// addition to the original results, returns `newYieldedValues` (with types
/// `newReturnTypes`). Values already yielded by `warpOp` are not duplicated.
/// For each entry of `newYieldedValues`, the index of the corresponding
/// result in the new op is appended to `indices`. The original op is replaced
/// by the new one and the new op is returned.
vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  // `llvm::zip` below silently truncates to the shorter range; make the
  // one-type-per-value precondition explicit instead.
  assert(newYieldedValues.size() == newReturnTypes.size() &&
         "expected one return type per newly yielded value");
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  // Seed the set with the values already yielded so duplicates are detected.
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      // New value: it becomes the last yielded value of the new op.
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // The value already exits the region; reuse the existing result
      // instead of creating a duplicate output.
      for (auto [idx, yieldOperand] :
           llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand == std::get<0>(newRet)) {
          indices.push_back(idx);
          break;
        }
      }
    }
  }
  // Note: every entry of `newYieldedValues` was already inserted (or found as
  // a duplicate) by the loop above, so no further insertion is needed here.
  vector::WarpExecuteOnLane0Op newWarpOp =
      moveRegionToNewWarpOpAndReplaceReturns(rewriter, warpOp,
                                             yieldValues.getArrayRef(), types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These utilities should be in a separate header file VectorDistributeUtils.h; also, should they really be in the top-level mlir namespace? These are utilities related to a specific transformation. I'm not sure exposing them to the entire namespace is a good idea.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done in #114208