
Commit 5578bcb

[mlir][xegpu] add support for structured control flow ops in workgroup to subgroup distribution (#142618)
This PR introduces support for `scf::ForOp`, `scf::WhileOp`, `scf::IfOp`, and `scf::ConditionOp` within the workgroup-to-subgroup distribution pass, leveraging `SCFStructuralTypeConversionsAndLegality`.
Parent: 51689c9
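For illustration, here is a minimal sketch of the kind of loop this enables the pass to distribute. The shapes, layout values, and operands are hypothetical, chosen only to show the before/after structure:

```mlir
// Before: a workgroup-level loop carrying a vector<256x128xf32> iter_arg.
%r = scf.for %iv = %c0 to %c1024 step %c256 iter_args(%acc = %init)
    -> (vector<256x128xf32>) {
  %v = xegpu.load_nd %tdesc
      : !xegpu.tensor_desc<256x128xf32,
          #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
      -> vector<256x128xf32>
  %sum = arith.addf %acc, %v : vector<256x128xf32>
  scf.yield %sum : vector<256x128xf32>
}

// After wg-to-sg distribution, each subgroup carries its own
// vector<32x32xf32> piece and the loop signature is rewritten to
// (vector<32x32xf32>) by the SCF structural type conversion.
```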

5 files changed (+430, -31 lines)

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -26,6 +26,9 @@ class TensorDescType;
 
 namespace xegpu {
 
+/// Flatten a set of ValueRange into a single SmallVector<Value>
+SmallVector<Value> flattenValues(ArrayRef<ValueRange> values);
+
 /// If tensor descriptor has a layout attribute it is used in SIMT mode.
 /// In this mode, the distributed vector shape is determined as follows:
 /// Definitions:
```
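A minimal usage sketch of the newly exported helper (the `adaptor` variable is hypothetical): in a 1:N conversion pattern, the adaptor hands each original operand over as a `ValueRange`, and this helper collapses them into one flat list.

```c++
// Inside a 1:N conversion pattern's matchAndRewrite:
// adaptor.getInputs() is ArrayRef<ValueRange>, one range per original operand.
SmallVector<Value> flat = xegpu::flattenValues(adaptor.getInputs());
```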

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 196 additions & 24 deletions
```diff
@@ -13,9 +13,11 @@
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
```
```diff
@@ -29,6 +31,29 @@ using namespace mlir;
 
 namespace {
 
+static std::pair<SmallVector<int64_t>, int>
+getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+  int count = 1;
+  SmallVector<int64_t> sgShape(shape);
+
+  if (layout && layout.isWgLayout()) {
+    DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
+    auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+    if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    else
+      sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
+    // Clamp distUnit to the original shape to handle cases where data is
+    // shared among subgroups, which may cause distUnit to exceed the original
+    // shape.
+    for (size_t i = 0; i < distUnit.size(); ++i)
+      distUnit[i] = std::min(shape[i], distUnit[i]);
+    count = computeProduct(shape) / computeProduct(distUnit);
+  }
+  return std::make_pair(sgShape, count);
+}
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
```
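A worked example of `getSgShapeAndCount` under hypothetical values (not taken from this patch):

```c++
// shape = [256, 128], layout = #xegpu.layout<sg_layout = [4, 4],
//                                            sg_data = [32, 32]>
//   sgShape  = sg_data                    = [32, 32]
//   distUnit = sg_layout * sgShape        = [128, 128]  (within shape,
//                                                        no clamping)
//   count    = (256 * 128) / (128 * 128)  = 2
auto [sgShape, count] = getSgShapeAndCount({256, 128}, layout);
// sgShape == [32, 32], count == 2: each wg-level value becomes two
// sg-level values, distributed round-robin over the 4x4 subgroup grid.
```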
```diff
@@ -129,18 +154,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       return rewriter.notifyMatchFailure(
           op, "sgLayout attribute is required in layout");
 
-    SmallVector<int64_t> sgShape;
-    if (auto sgDataAttr = layout.getSgData()) {
-      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    } else {
-      assert(wgShape.size() == sgLayout.size() &&
-             "sgLayout and wgShape must have the same rank");
-      sgShape.reserve(wgShape.size());
-      for (size_t i = 0; i < wgShape.size(); ++i) {
-        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
-        sgShape.push_back(wgShape[i] / sgLayout[i]);
-      }
-    }
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
 
     // TODO : Handle order attribute
     // Get the subgroup ID
```
```diff
@@ -266,15 +280,15 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
     if (resultTy.getRank() != 2)
       return failure();
 
-    auto originalLayout =
-        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    auto originalLayout = xegpu::getLayoutAttr(op.getResult());
     if (!originalLayout)
       return failure();
 
-    SmallVector<Value> newDpasOps;
     size_t i = 0;
+    SmallVector<Value> newDpasOps;
     for (auto aVec : adaptor.getLhs()) {
       for (auto bVec : adaptor.getRhs()) {
+
         llvm::SmallVector<Value> operands({aVec, bVec});
         Value tmpC;
         if (op.getAcc()) {
@@ -288,10 +302,10 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
             llvm::cast<VectorType>(bVec.getType()).getShape();
         VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                            resultTy.getElementType());
-        tmpC = rewriter.create<xegpu::DpasOp>(
-            loc, resTy, operands,
-            llvm::ArrayRef<NamedAttribute>(
-                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
+        tmpC = rewriter.create<xegpu::DpasOp>(loc, resTy, operands);
+        xegpu::setLayoutAttr(cast<OpResult>(tmpC),
+                             originalLayout.dropSgLayoutAndData());
+
         newDpasOps.push_back(tmpC);
       }
     }
```
```diff
@@ -314,14 +328,90 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// Handles UnrealizedConversionCastOp generated during
+// SCFStructuralTypeConversions (step 1). This op may appear as either a
+// target or source materialization for Vector values, e.g.:
+// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
+// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
+// It can be either a 1:N or an N:1 cast. In both cases, the pattern
+// simply forwards the inputs to the outputs using the 1:1 or 1:N interface.
+// For example, the following scf::ForOp
+// ```
+// %for = scf.for ... iter_args(%arg1 = %0) -> (vector<128x128xf16>) {
+//     %n = use(%arg1): vector<128x128xf16>
+//     scf.yield %n : vector<128x128xf16>
+// }
+// ```
+// could be converted to:
+// ```
+// %1 = unrealized_conversion_cast %0
+//     : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
+// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
+//     -> (vector<16x16xf16>, vector<16x16xf16>) {
+//     %m = unrealized_conversion_cast %arg1, %arg2
+//         : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
+//     %n = use(%m): vector<128x128xf16>
+//     %b = unrealized_conversion_cast %n
+//         : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
+//     scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
+// }
+// %cast = unrealized_conversion_cast %for:2
+//     : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
+// ```
+// TODO: remove this pattern when a context-aware type converter is ready.
+struct UnrealizedConversionCastOpPattern
+    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
+  using OpConversionPattern<
+      mlir::UnrealizedConversionCastOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
+
+    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
+    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
+
+    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
+        !llvm::all_equal(ValueRange(inputs).getTypes()))
+      return failure();
+
+    // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
+    // It is generated by source materialization (e.g., inits to scf::ForOp).
+    // The input values provided by the adaptor should already be distributed,
+    // and their types should correspond exactly to the result types of the
+    // operation.
+    if (op.getNumOperands() == 1 &&
+        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
+      rewriter.replaceOp(op, inputs);
+      return success();
+    }
+
+    // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
+    // It is generated by target materialization (e.g., arguments/results
+    // of scf::ForOp). All input values must have the same vector type, and
+    // their shape must evenly divide the output vector's shape
+    // (determined by the nature of the workgroup to subgroup distribution).
+    // TODO: this forwarding is not safe in general, since such an N:1 cast
+    // could also come from other sources.
+    if (op.getNumResults() == 1 &&
+        computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
+      rewriter.replaceOpWithMultiple(op, {inputs});
+      return success();
+    }
+
+    return mlir::failure();
+  }
+};
+
 } // namespace
 
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
-      patterns.getContext());
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+               UnrealizedConversionCastOpPattern>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
```
```diff
@@ -334,9 +424,68 @@ struct XeGPUWgToSgDistributePass
 } // namespace
 
 void XeGPUWgToSgDistributePass::runOnOperation() {
+  // Track existing UnrealizedConversionCastOps.
+  SmallVector<Operation *> existingCastOps;
+  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
+    existingCastOps.push_back(castOp.getOperation());
+  });
+
+  {
+    // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
+    // VectorType operands. This first converts such operands to
+    // RankedTensorType, propagates the layout attribute into the encoding
+    // attribute, and finally converts the RankedTensorType to VectorType
+    // based on the encoding.
+
+    TypeConverter converter;
+    converter.addConversion([&](Type type) -> Type { return type; });
+    converter.addConversion(
+        [&](RankedTensorType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          Type elemTy = type.getElementType();
+          ArrayRef<int64_t> shape = type.getShape();
+
+          int count;
+          SmallVector<int64_t> subShape;
+          std::tie(subShape, count) = getSgShapeAndCount(
+              shape,
+              dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
+
+          auto newTy = VectorType::get(subShape, elemTy);
+          result.append(count, newTy);
+          return success();
+        });
+
+    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
+                                                       converter);
+  }
+
+  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
+  // as well as XeGPU, Arith, and Vector operations.
   MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(ctx);
   ConversionTarget target(*ctx);
+  TypeConverter converter;
+  converter.addConversion([&](Type type) -> Type { return type; });
+  converter.addConversion(
+      [&](xegpu::TensorDescType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        xegpu::LayoutAttr layout = type.getLayoutAttr();
+        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
+
+        if (layout)
+          layout = layout.dropSgLayoutAndData();
+
+        auto newTy = xegpu::TensorDescType::get(
+            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
+        result.append(count, newTy);
+        return success();
+      });
 
   auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
     if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
```
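The Step 1 flow for a single wg-level value, sketched with hypothetical shapes (count = 1 here, since sg_layout * sg_data covers the shape exactly):

```mlir
// vector<256x128xf32>                                    wg-level value
//   -> tensor<256x128xf32, #xegpu.layout<sg_layout = [8, 4],
//                                        sg_data = [32, 32]>>
//      layout carried as the RankedTensorType encoding
//   -> vector<32x32xf32>                                  sg-level value
```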
```diff
@@ -353,26 +502,49 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   };
 
   auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
-    return !layout || layout.getSgLayout() == nullptr;
+    return !layout || !layout.isWgLayout();
   };
 
   target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                                xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                                xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
     auto tdescTy = getTensorDescType(op);
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
     return isLegal(layout);
   });
 
   target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    auto layout = xegpu::getLayoutAttr(op.getResult());
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+      [=](UnrealizedConversionCastOp op) {
+        return llvm::is_contained(existingCastOps, op.getOperation());
+      });
+
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
+  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                       target);
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
     return signalPassFailure();
+
+  // Remove the sg_layout and sg_data fields from the layout attribute of
+  // each VectorType result of the operation.
+  // For structured control flow ops, the layout is removed entirely, since
+  // in the 1:N case the layouts for the new results are missing. The layout
+  // propagation pass will be activated afterwards to recover them.
+  getOperation()->walk([](Operation *op) {
+    for (OpResult result : op->getOpResults()) {
+      std::string name = xegpu::getLayoutName(result);
+      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
+        op->removeAttr(name);
+        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op))
+          op->setAttr(name, layout.dropSgLayoutAndData());
+      }
+    }
+  });
 }
```
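A sketch of this post-pass cleanup on two hypothetical results: a dpas result keeps the lane-level remainder of its layout, while an scf.for result loses the attribute entirely, to be recomputed by layout propagation:

```mlir
// xegpu.dpas result:
//   {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32],
//                                    lane_layout = [1, 16], lane_data = [1, 1]>}
//   becomes
//   {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
//
// scf.for result:
//   {layout_result_0 = ...}  ->  attribute removed
```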

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 3 additions & 3 deletions
```diff
@@ -27,7 +27,7 @@
 using namespace mlir;
 
 /// convert ArrayRef<ValueRange> into SmallVector<Value>
-static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+SmallVector<Value> xegpu::flattenValues(ArrayRef<ValueRange> values) {
   SmallVector<Value> result;
   for (const auto &vals : values)
     llvm::append_range(result, vals);
@@ -271,7 +271,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
     auto resultTy = dyn_cast<RankedTensorType>(result.getType());
 
     // Only look at ops casting from VectorType to RankedTensorType
-    if (!isa<VectorType>(inputTy) || !isa<RankedTensorType>(resultTy))
+    if (!inputTy || !resultTy)
      return WalkResult::skip();
 
    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
@@ -342,7 +342,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
     }
 
     if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
-      SmallVector<Value> values = flattenValues(adaptor.getInputs());
+      SmallVector<Value> values = xegpu::flattenValues(adaptor.getInputs());
       auto newOp = rewriter.create<UnrealizedConversionCastOp>(
           op.getLoc(), outputTy, values);
       rewriter.replaceOp(op, newOp);
```
