-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][xegpu] add support for structured control flow ops in workgroup to subgroup distribution #142618
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][xegpu] add support for structured control flow ops in workgroup to subgroup distribution #142618
Changes from all commits
a808983
bc3b74b
392dfb0
449a2ed
b2032a4
bf37af1
605eee0
87934c4
689bb05
1ca7b30
e650862
3544073
8c91b39
89cac4d
f39fe3c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,9 +13,11 @@ | |
#include "mlir/Dialect/Index/IR/IndexDialect.h" | ||
#include "mlir/Dialect/Index/IR/IndexOps.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/SCF/Transforms/Patterns.h" | ||
#include "mlir/Dialect/Utils/IndexingUtils.h" | ||
#include "mlir/Dialect/XeGPU/IR/XeGPU.h" | ||
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h" | ||
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
namespace mlir { | ||
|
@@ -29,6 +31,29 @@ using namespace mlir; | |
|
||
namespace { | ||
|
||
static std::pair<SmallVector<int64_t>, int> | ||
getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { | ||
int count = 1; | ||
SmallVector<int64_t> sgShape(shape); | ||
|
||
if (layout && layout.isWgLayout()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isWgLayout seems confusing, I think it should be called isSgLayout since it describes how the subgroups are laid out There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is an interface defined in a previous PR. I think we can create a small fix PR if we plan to change it. |
||
DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout(); | ||
auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef()); | ||
if (DenseI32ArrayAttr sgDataAttr = layout.getSgData()) | ||
sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef()); | ||
else | ||
sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape); | ||
SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape); | ||
// Clamp distUnit to the original shape to handle cases where data is | ||
// shared among subgroups, which may cause distUnit to exceed the original | ||
// shape. | ||
for (size_t i = 0; i < distUnit.size(); ++i) | ||
distUnit[i] = std::min(shape[i], distUnit[i]); | ||
count = computeProduct(shape) / computeProduct(distUnit); | ||
} | ||
return std::make_pair(sgShape, count); | ||
} | ||
|
||
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor | ||
/// from a workgroup descriptor. It replaces the offsets and sizes with | ||
/// appropriate values for the subgroup. | ||
|
@@ -129,18 +154,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { | |
return rewriter.notifyMatchFailure( | ||
op, "sgLayout attribute is required in layout"); | ||
|
||
SmallVector<int64_t> sgShape; | ||
if (auto sgDataAttr = layout.getSgData()) { | ||
sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef()); | ||
} else { | ||
assert(wgShape.size() == sgLayout.size() && | ||
"sgLayout and wgShape must have the same rank"); | ||
sgShape.reserve(wgShape.size()); | ||
for (size_t i = 0; i < wgShape.size(); ++i) { | ||
assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero"); | ||
sgShape.push_back(wgShape[i] / sgLayout[i]); | ||
} | ||
} | ||
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; | ||
|
||
// TODO : Handle order attribute | ||
// Get the subgroup ID | ||
|
@@ -266,15 +280,15 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> { | |
if (resultTy.getRank() != 2) | ||
return failure(); | ||
|
||
auto originalLayout = | ||
llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout")); | ||
auto originalLayout = xegpu::getLayoutAttr(op.getResult()); | ||
if (!originalLayout) | ||
return failure(); | ||
|
||
SmallVector<Value> newDpasOps; | ||
size_t i = 0; | ||
SmallVector<Value> newDpasOps; | ||
for (auto aVec : adaptor.getLhs()) { | ||
for (auto bVec : adaptor.getRhs()) { | ||
|
||
llvm::SmallVector<Value> operands({aVec, bVec}); | ||
Value tmpC; | ||
if (op.getAcc()) { | ||
|
@@ -288,10 +302,10 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> { | |
llvm::cast<VectorType>(bVec.getType()).getShape(); | ||
VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]}, | ||
resultTy.getElementType()); | ||
tmpC = rewriter.create<xegpu::DpasOp>( | ||
loc, resTy, operands, | ||
llvm::ArrayRef<NamedAttribute>( | ||
{"layout_result_0", originalLayout.dropSgLayoutAndData()})); | ||
tmpC = rewriter.create<xegpu::DpasOp>(loc, resTy, operands); | ||
xegpu::setLayoutAttr(cast<OpResult>(tmpC), | ||
originalLayout.dropSgLayoutAndData()); | ||
|
||
newDpasOps.push_back(tmpC); | ||
} | ||
} | ||
|
@@ -314,14 +328,90 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> { | |
} | ||
}; | ||
|
||
// Handles UnrealizedConversionCastOp generated during | ||
// SCFStructuralTypeConversions (step 1). This op may appear as either a | ||
// target or source materialization for Vector values, e.g.: | ||
// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ... | ||
// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32> | ||
// it could be either 1:N or N:1 cast. In both cases, the pattern | ||
// simply forwards the inputs to the outputs using 1:1 or 1:N interface. | ||
// for example, the following scf::forOp | ||
// ``` | ||
// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) { | ||
// %n = use(%arg1): vector<128x128xf16> | ||
// scf.yield %n : vector<128x128xf16> | ||
// } | ||
// ``` | ||
// Could be converted to: | ||
// ``` | ||
// %1 = unrealized_conversion_cast %0 | ||
// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16> | ||
// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1) | ||
// -> (vector<16x16xf16>, vector<16x16xf16>) { | ||
// %m = unrealized_conversion_cast %arg1, %arg2 | ||
// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16> | ||
// %n = use(%m): vector<128x128xf16> | ||
// %b = unrealized_conversion_cast %n | ||
// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16> | ||
// scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16> | ||
// } | ||
// %cast = unrealized_conversion_cast %for:2 | ||
// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16> | ||
// ``` | ||
// TODO: remove it when context-aware type converter is ready. | ||
struct UnrealizedConversionCastOpPattern | ||
: public OpConversionPattern<mlir::UnrealizedConversionCastOp> { | ||
using OpConversionPattern< | ||
mlir::UnrealizedConversionCastOp>::OpConversionPattern; | ||
|
||
mlir::LogicalResult | ||
matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs()); | ||
|
||
auto inputTy = dyn_cast<VectorType>(inputs[0].getType()); | ||
auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType()); | ||
|
||
if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) || | ||
!llvm::all_equal(ValueRange(inputs).getTypes())) | ||
return failure(); | ||
|
||
// Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...". | ||
// It is generated by source materialization (e.g., inits to scf forOp). | ||
// The input values provided by the adaptor should already be distributed, | ||
// and their types should correspond exactly to the result types of the | ||
// operation. | ||
if (op.getNumOperands() == 1 && | ||
llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) { | ||
rewriter.replaceOp(op, inputs); | ||
return success(); | ||
} | ||
|
||
// Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>". | ||
// It is generated by target materialization (e.g., arguments/results | ||
// of scf forOp). All input values must have the same vector type, and | ||
// their shape must be evenly divisible by the output vector's shape | ||
// (determined by the nature of the workgroup to subgroup distribution). | ||
// TODO: it is not safe to do such forwarding, since such an N:1 cast could be | ||
charithaintc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// from others. | ||
if (op.getNumResults() == 1 && | ||
computeShapeRatio(outputTy.getShape(), inputTy.getShape())) { | ||
rewriter.replaceOpWithMultiple(op, {inputs}); | ||
return success(); | ||
} | ||
|
||
return mlir::failure(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace mlir { | ||
namespace xegpu { | ||
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { | ||
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, | ||
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>( | ||
patterns.getContext()); | ||
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, | ||
UnrealizedConversionCastOpPattern>(patterns.getContext()); | ||
} | ||
} // namespace xegpu | ||
} // namespace mlir | ||
|
@@ -334,9 +424,68 @@ struct XeGPUWgToSgDistributePass | |
} // namespace | ||
|
||
void XeGPUWgToSgDistributePass::runOnOperation() { | ||
// Track existing UnrealizedConversionCastOps | ||
SmallVector<Operation *> existingCastOps; | ||
getOperation()->walk([&](UnrealizedConversionCastOp castOp) { | ||
existingCastOps.push_back(castOp.getOperation()); | ||
}); | ||
|
||
{ | ||
// Step 1: Apply SCFStructuralTypeConversions to SCF operations with | ||
// VectorType operands. This first converts such operands to | ||
// RankedTensorType, propagates the layout attribute into the encoding | ||
// attribute, and finally converts the RankedTensorType to VectorType based | ||
// on the encoding. | ||
|
||
TypeConverter converter; | ||
converter.addConversion([&](Type type) -> Type { return type; }); | ||
converter.addConversion( | ||
[&](RankedTensorType type, | ||
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> { | ||
Type elemTy = type.getElementType(); | ||
ArrayRef<int64_t> shape = type.getShape(); | ||
|
||
int count; | ||
SmallVector<int64_t> subShape; | ||
std::tie(subShape, count) = getSgShapeAndCount( | ||
shape, | ||
dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding())); | ||
|
||
auto newTy = VectorType::get(subShape, elemTy); | ||
result.append(count, newTy); | ||
return success(); | ||
}); | ||
|
||
xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), | ||
converter); | ||
} | ||
|
||
// Step 2: Perform workgroup to subgroup distribution for TensorDesc values, | ||
// as well as XeGPU, Arith, and Vector operations. | ||
MLIRContext *ctx = &getContext(); | ||
RewritePatternSet patterns(ctx); | ||
ConversionTarget target(*ctx); | ||
TypeConverter converter; | ||
converter.addConversion([&](Type type) -> Type { return type; }); | ||
converter.addConversion( | ||
charithaintc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
[&](xegpu::TensorDescType type, | ||
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> { | ||
Type elemTy = type.getElementType(); | ||
ArrayRef<int64_t> shape = type.getShape(); | ||
|
||
int count; | ||
SmallVector<int64_t> subShape; | ||
xegpu::LayoutAttr layout = type.getLayoutAttr(); | ||
std::tie(subShape, count) = getSgShapeAndCount(shape, layout); | ||
|
||
if (layout) | ||
layout = layout.dropSgLayoutAndData(); | ||
|
||
auto newTy = xegpu::TensorDescType::get( | ||
type.getContext(), subShape, elemTy, type.getEncoding(), layout); | ||
result.append(count, newTy); | ||
return success(); | ||
}); | ||
|
||
auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType { | ||
if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op)) | ||
|
@@ -353,26 +502,49 @@ void XeGPUWgToSgDistributePass::runOnOperation() { | |
}; | ||
|
||
auto isLegal = [&](xegpu::LayoutAttr layout) -> bool { | ||
return !layout || layout.getSgLayout() == nullptr; | ||
return !layout || !layout.isWgLayout(); | ||
}; | ||
|
||
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp, | ||
xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp, | ||
xegpu::PrefetchNdOp>([=](Operation *op) -> bool { | ||
auto tdescTy = getTensorDescType(op); | ||
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout()); | ||
auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout()); | ||
return isLegal(layout); | ||
}); | ||
|
||
target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool { | ||
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout")); | ||
auto layout = xegpu::getLayoutAttr(op.getResult()); | ||
return isLegal(layout); | ||
}); | ||
|
||
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>( | ||
[=](UnrealizedConversionCastOp op) { | ||
return llvm::is_contained(existingCastOps, op.getOperation()); | ||
}); | ||
|
||
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); | ||
|
||
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, | ||
target); | ||
xegpu::populateXeGPUWgToSgDistributePatterns(patterns); | ||
if (failed( | ||
applyPartialConversion(getOperation(), target, std::move(patterns)))) | ||
return signalPassFailure(); | ||
|
||
// Remove sg_layout and sg_data attributes from the Layout | ||
// attribute for each VectorType result of the operation. | ||
// For Structured Control Flow ops, the layout is simply removed, | ||
// since in the 1:N case, the layouts for the new results are missing. | ||
// A layout propagation pass will be activated. | ||
getOperation()->walk([](Operation *op) { | ||
for (OpResult result : op->getOpResults()) { | ||
std::string name = xegpu::getLayoutName(result); | ||
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) { | ||
op->removeAttr(name); | ||
if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) | ||
op->setAttr(name, layout.dropSgLayoutAndData()); | ||
} | ||
} | ||
}); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this go to XeGPUUtils so XeGPU blocking could use it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, they are different logic. blocking is using inst_data, here it is using sg_layout and sg_data.