[mlir][xegpu] add support for structure control flow ops in workgroup to subgroup distribution #142618

Merged: 15 commits, Jun 13, 2025
mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h (3 additions, 0 deletions)

@@ -26,6 +26,9 @@ class TensorDescType;

namespace xegpu {

/// Flatten a set of ValueRanges into a single SmallVector<Value>.
SmallVector<Value> flattenValues(ArrayRef<ValueRange> values);
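
A typical use, mirroring how the 1:N conversion pattern in this PR calls it (fragment only; `adaptor` here stands for the pattern's OneToNOpAdaptor, as in UnrealizedConversionCastOpPattern below):

```cpp
// Collapse the adaptor's ArrayRef<ValueRange> segments into one flat list of
// already-distributed values.
SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
```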

/// If the tensor descriptor has a layout attribute, it is used in SIMT mode.
/// In this mode, the distributed vector shape is determined as follows:
/// Definitions:

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (196 additions, 24 deletions)

@@ -13,9 +13,11 @@
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
@@ -29,6 +31,29 @@ using namespace mlir;

namespace {

static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {

Reviewer: can this go to XeGPUUtils so XeGPU blocking could use it?

Author: No, they are different logic: blocking uses inst_data, while here it uses sg_layout and sg_data.

int count = 1;
SmallVector<int64_t> sgShape(shape);

if (layout && layout.isWgLayout()) {
Reviewer (Contributor): isWgLayout seems confusing; I think it should be called isSgLayout, since it describes how the subgroups are laid out.

Author: This is an interface defined in a previous PR. I think we can create a small fix PR if we plan to change it.

DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
else
sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
// Clamp distUnit to the original shape to handle cases where data is
// shared among subgroups, which may cause distUnit to exceed the original
// shape.
for (size_t i = 0; i < distUnit.size(); ++i)
distUnit[i] = std::min(shape[i], distUnit[i]);
count = computeProduct(shape) / computeProduct(distUnit);
}
return std::make_pair(sgShape, count);
}
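
For concreteness, here is a standalone sketch of the same arithmetic with a hypothetical worked example; the helper name `sgShapeAndCount`, the plain STL types, and the shape/layout values are all illustrative and not part of the patch:

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Standalone sketch of the getSgShapeAndCount arithmetic above, using plain
// STL types instead of MLIR attributes.
static std::pair<std::vector<int64_t>, int64_t>
sgShapeAndCount(const std::vector<int64_t> &shape,
                const std::vector<int64_t> &sgLayout,
                const std::vector<int64_t> &sgData) {
  // sg_data, if present, gives the per-subgroup tile; otherwise divide the
  // workgroup shape by sg_layout.
  std::vector<int64_t> sgShape = sgData;
  if (sgShape.empty())
    for (size_t i = 0; i < shape.size(); ++i)
      sgShape.push_back(shape[i] / sgLayout[i]);

  // One distribution unit covers sg_layout * sg_data elements per dimension,
  // clamped to the original shape when data is shared among subgroups; count
  // is how many such units are needed to tile the whole workgroup shape.
  int64_t total = 1, unit = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    total *= shape[i];
    unit *= std::min<int64_t>(shape[i], sgLayout[i] * sgShape[i]);
  }
  return {sgShape, total / unit};
}

// Example: shape = {256, 256}, sgLayout = {8, 4}, sgData = {32, 32}
//   -> sgShape = {32, 32}, distUnit = {256, 128},
//      count   = (256 * 256) / (256 * 128) = 2,
// i.e. each subgroup ends up owning two 32x32 tiles.
```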

/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
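
As a rough, hypothetical sketch of the effect (illustrative IR only; the exact xegpu syntax, the sg_layout/sg_data values, and the %off_y/%off_x offset computation are assumptions, not taken from the patch):

```cpp
// Workgroup level: one descriptor for the whole 128x128 tile, carrying the
// subgroup layout.
//   %wg = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf16>
//       -> !xegpu.tensor_desc<128x128xf16,
//            #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 32]>>
//
// Subgroup level: a 32x32 descriptor whose offsets %off_y, %off_x are derived
// from the subgroup id, with sg_layout/sg_data dropped from the layout.
//   %sg = xegpu.create_nd_tdesc %src[%off_y, %off_x] : memref<128x128xf16>
//       -> !xegpu.tensor_desc<32x32xf16>
```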
@@ -129,18 +154,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
return rewriter.notifyMatchFailure(
op, "sgLayout attribute is required in layout");

- SmallVector<int64_t> sgShape;
- if (auto sgDataAttr = layout.getSgData()) {
-   sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
- } else {
-   assert(wgShape.size() == sgLayout.size() &&
-          "sgLayout and wgShape must have the same rank");
-   sgShape.reserve(wgShape.size());
-   for (size_t i = 0; i < wgShape.size(); ++i) {
-     assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
-     sgShape.push_back(wgShape[i] / sgLayout[i]);
-   }
- }
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

// TODO : Handle order attribute
// Get the subgroup ID
@@ -266,15 +280,15 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
if (resultTy.getRank() != 2)
return failure();

- auto originalLayout =
-     llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+ auto originalLayout = xegpu::getLayoutAttr(op.getResult());
if (!originalLayout)
return failure();

- SmallVector<Value> newDpasOps;
size_t i = 0;
+ SmallVector<Value> newDpasOps;
for (auto aVec : adaptor.getLhs()) {
for (auto bVec : adaptor.getRhs()) {

llvm::SmallVector<Value> operands({aVec, bVec});
Value tmpC;
if (op.getAcc()) {
@@ -288,10 +302,10 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
llvm::cast<VectorType>(bVec.getType()).getShape();
VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
resultTy.getElementType());
- tmpC = rewriter.create<xegpu::DpasOp>(
-     loc, resTy, operands,
-     llvm::ArrayRef<NamedAttribute>(
-         {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
+ tmpC = rewriter.create<xegpu::DpasOp>(loc, resTy, operands);
+ xegpu::setLayoutAttr(cast<OpResult>(tmpC),
+                      originalLayout.dropSgLayoutAndData());

newDpasOps.push_back(tmpC);
}
}
@@ -314,14 +328,90 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};

// Handles UnrealizedConversionCastOp generated during
// SCFStructuralTypeConversions (step 1). This op may appear as either a
// target or source materialization for Vector values, e.g.:
// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
// It can be either a 1:N or an N:1 cast. In both cases, the pattern simply
// forwards the inputs to the outputs via the 1:1 or 1:N conversion interface.
// For example, the following scf.for op:
// ```
// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
// %n = use(%arg1): vector<128x128xf16>
// scf.yield %n : vector<128x128xf16>
// }
// ```
// could be converted to:
// ```
// %1 = unrealized_conversion_cast %0
// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
//     -> (vector<16x16xf16>, vector<16x16xf16>) {
// %m = unrealized_conversion_cast %arg1, %arg2
// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
// %n = use(%m): vector<128x128xf16>
// %b = unrealized_conversion_cast %n
// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
// scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
// }
// %cast = unrealized_conversion_cast %for#0, %for#1
// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
// ```
// TODO: remove this pattern once a context-aware type converter is available.
struct UnrealizedConversionCastOpPattern
: public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
using OpConversionPattern<
mlir::UnrealizedConversionCastOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());

auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());

if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
!llvm::all_equal(ValueRange(inputs).getTypes()))
return failure();

// Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
// It is generated by source materialization (e.g., for the init args of an
// scf.for op).
// The input values provided by the adaptor should already be distributed,
// and their types should correspond exactly to the result types of the
// operation.
if (op.getNumOperands() == 1 &&
llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
rewriter.replaceOp(op, inputs);
return success();
}

// Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
// It is generated by target materialization (e.g., for the arguments and
// results of an scf.for op). All input values must have the same vector
// type, and their shape must be evenly divisible by the output vector's
// shape (as determined by the nature of the workgroup to subgroup
// distribution).
// TODO: forwarding like this is not fully safe, since such an N:1 cast
// could also originate from sources other than the SCF type conversion.
if (op.getNumResults() == 1 &&
computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
rewriter.replaceOpWithMultiple(op, {inputs});
return success();
}

return mlir::failure();
}
};

} // namespace

namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
- WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
-     patterns.getContext());
+ WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+ UnrealizedConversionCastOpPattern>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -334,9 +424,68 @@ struct XeGPUWgToSgDistributePass
} // namespace

void XeGPUWgToSgDistributePass::runOnOperation() {
// Track existing UnrealizedConversionCastOps
SmallVector<Operation *> existingCastOps;
getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
existingCastOps.push_back(castOp.getOperation());
});

{
// Step 1: Apply SCFStructuralTypeConversions to SCF operations with
// VectorType operands. This first converts such operands to
// RankedTensorType, propagates the layout attribute into the encoding
// attribute, and finally converts the RankedTensorType to VectorType based
// on the encoding.

TypeConverter converter;
converter.addConversion([&](Type type) -> Type { return type; });
converter.addConversion(
[&](RankedTensorType type,
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
Type elemTy = type.getElementType();
ArrayRef<int64_t> shape = type.getShape();

int count;
SmallVector<int64_t> subShape;
std::tie(subShape, count) = getSgShapeAndCount(
shape,
dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));

auto newTy = VectorType::get(subShape, elemTy);
result.append(count, newTy);
return success();
});

xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
converter);
}
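
As a hypothetical illustration of what this Step-1 converter produces (example type and layout values only, reusing the numbers from the sketch after getSgShapeAndCount):

```cpp
// A tensor whose encoding carries a workgroup layout, e.g.
//   tensor<256x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
// has getSgShapeAndCount = ([32, 32], 2), so the converter expands it into
// two subgroup-sized results:
//   vector<32x32xf32>, vector<32x32xf32>
// This 1:N expansion is what later produces the unrealized_conversion_cast
// ops handled by UnrealizedConversionCastOpPattern above.
```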

// Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
// as well as XeGPU, Arith, and Vector operations.
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
ConversionTarget target(*ctx);
TypeConverter converter;
converter.addConversion([&](Type type) -> Type { return type; });
converter.addConversion(
[&](xegpu::TensorDescType type,
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
Type elemTy = type.getElementType();
ArrayRef<int64_t> shape = type.getShape();

int count;
SmallVector<int64_t> subShape;
xegpu::LayoutAttr layout = type.getLayoutAttr();
std::tie(subShape, count) = getSgShapeAndCount(shape, layout);

if (layout)
layout = layout.dropSgLayoutAndData();

auto newTy = xegpu::TensorDescType::get(
type.getContext(), subShape, elemTy, type.getEncoding(), layout);
result.append(count, newTy);
return success();
});
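
Similarly, a hypothetical mapping for TensorDescType (the lane_layout/lane_data fields below are assumed for illustration; dropSgLayoutAndData keeps the non-subgroup part of the layout):

```cpp
//   !xegpu.tensor_desc<256x256xf32,
//       #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32],
//                     lane_layout = [1, 16], lane_data = [1, 1]>>
// becomes count = 2 copies of
//   !xegpu.tensor_desc<32x32xf32,
//       #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
```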

auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
@@ -353,26 +502,49 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
};

auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
- return !layout || layout.getSgLayout() == nullptr;
+ return !layout || !layout.isWgLayout();
};

target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
auto tdescTy = getTensorDescType(op);
- auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+ auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
return isLegal(layout);
});

target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
- auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+ auto layout = xegpu::getLayoutAttr(op.getResult());
return isLegal(layout);
});

target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
[=](UnrealizedConversionCastOp op) {
return llvm::is_contained(existingCastOps, op.getOperation());
});

target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
target);
xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
return signalPassFailure();

// Remove the sg_layout and sg_data fields from the layout attribute attached
// to each VectorType result of an operation.
// For structured control flow ops, the layout is removed entirely, since in
// the 1:N case the layouts for the new results are missing; the layout
// propagation pass will be activated afterwards to recover them.
getOperation()->walk([](Operation *op) {
for (OpResult result : op->getOpResults()) {
std::string name = xegpu::getLayoutName(result);
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
op->removeAttr(name);
if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op))
op->setAttr(name, layout.dropSgLayoutAndData());
}
}
});
}
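
For example (hypothetical attribute values; the attribute name is the one produced by xegpu::getLayoutName on the first result), the cleanup walk above behaves roughly as follows:

```cpp
// Before the walk:
//   layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32],
//                                   lane_layout = [1, 16], lane_data = [1, 1]>
// After the walk, on an ordinary (non-SCF) op:
//   layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
// After the walk, on scf.for / scf.if / scf.while / scf.condition:
//   (the attribute is removed entirely)
```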
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (3 additions, 3 deletions)

@@ -27,7 +27,7 @@
using namespace mlir;

/// Convert an ArrayRef<ValueRange> into a SmallVector<Value>.
- static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> xegpu::flattenValues(ArrayRef<ValueRange> values) {
SmallVector<Value> result;
for (const auto &vals : values)
llvm::append_range(result, vals);
@@ -271,7 +271,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
auto resultTy = dyn_cast<RankedTensorType>(result.getType());

// Only look at ops casting from VectorType to RankedTensorType
- if (!isa<VectorType>(inputTy) || !isa<RankedTensorType>(resultTy))
+ if (!inputTy || !resultTy)
return WalkResult::skip();

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
@@ -342,7 +342,7 @@ }
}

if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
- SmallVector<Value> values = flattenValues(adaptor.getInputs());
+ SmallVector<Value> values = xegpu::flattenValues(adaptor.getInputs());
auto newOp = rewriter.create<UnrealizedConversionCastOp>(
op.getLoc(), outputTy, values);
rewriter.replaceOp(op, newOp);