Commit beaffb0
[mlir][transform] Decouple GPUDeviceMapping attribute from the GPU transform dialect code generator
`DeviceMappingAttrInterface` is implemented as a unifying mechanism for thread mapping. A code generator can use any attribute that implements this interface to lower `scf.foreach_thread` to device-specific code, and it is free to choose its own mapping and interpretation. Currently, the GPU transform dialect supports only `GPUThreadMapping` and `GPUBlockMapping`; however, other mappings should be supported as well. This change addresses that issue: it decouples the GPU transform dialect from `GPUThreadMapping` and `GPUBlockMapping`, so it can now work with any other mapping.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D138020
1 parent 4a77d96 commit beaffb0
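The only contract a custom mapping attribute has to fulfil is the `getMappingId` method added to `DeviceMappingAttrInterface` below. A minimal C++ sketch of such an attribute, assuming a hypothetical downstream `MyWarpMappingAttr` with a `getWarp()` enum accessor (both names are illustrative and not part of this change); it mirrors the `GPUBlockMappingAttr`/`GPUThreadMappingAttr` definitions added to GPUDialect.cpp in this commit:

// Sketch only: MyWarpMappingAttr and its getWarp() accessor are hypothetical.
// The interface, and the pattern of returning the mapped dimension as an
// integer, are exactly what this commit introduces for the GPU attributes.
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"

int64_t MyWarpMappingAttr::getMappingId() const {
  // The id is used by the transform code both as a sort key for the
  // scf.foreach_thread dimensions and as an index into the generated ids.
  return static_cast<int64_t>(getWarp());
}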

File tree

6 files changed (+103, -56 lines)

mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td

Lines changed: 4 additions & 2 deletions
@@ -27,7 +27,8 @@ def ThreadsEnum : I64EnumAttr<"Threads", "threads for loop mapping", [
 }
 
 def GPUThreadMappingAttr
-    : GPU_Attr<"GPUThreadMapping", "thread", [ DeviceMappingAttrInterface ]> {
+    : GPU_Attr<"GPUThreadMapping", "thread", [
+      DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
   let parameters = (ins
     EnumParameter<ThreadsEnum>:$thread
   );
@@ -47,7 +48,8 @@ def BlocksEnum : I64EnumAttr<"Blocks", "threads for loop mapping", [
   let cppNamespace = "::mlir::gpu";
 }
 
-def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [ DeviceMappingAttrInterface ] > {
+def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [
+    DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
   let parameters = (ins
     EnumParameter<BlocksEnum>:$block
   );

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h

Lines changed: 7 additions & 6 deletions
@@ -40,11 +40,11 @@ namespace gpu {
 /// which case, the union of the number of threads is computed and may result in
 /// predication. Dynamic, `scf.foreach_thread` trip counts are currently not
 /// supported. Dynamic block dim sizes are currently not supported.
-DiagnosedSilenceableFailure
-mapNestedForeachToThreadsImpl(RewriterBase &rewriter, Operation *target,
-                              const SmallVectorImpl<int64_t> &blockDim,
-                              bool syncAfterDistribute,
-                              llvm::Optional<TransformOpInterface> transformOp);
+DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
+    RewriterBase &rewriter, Operation *target,
+    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
+    llvm::Optional<TransformOpInterface> transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
 
 /// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is
 /// one-to-one and the induction variables of `scf.foreach_thread` are rewritten
@@ -56,7 +56,8 @@ DiagnosedSilenceableFailure mapForeachToBlocksImpl(
     function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVectorImpl<Value> &)>
         blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp);
+    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);
 
 /// Finds the top level scf::ForeachThreadOp of given target.
 DiagnosedSilenceableFailure
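With the extra mapping-attribute parameter, the choice of attributes moves entirely to the caller. A sketch of how a downstream code generator could drive the same lowering with its own attribute set, reusing the hypothetical `MyWarpMappingAttr` from above (its `get` builder, the `Warps` enum, and the surrounding `ctx`/`rewriter`/`target`/`blockDim`/`transformOp` setup are assumptions); the entry point and its signature are the ones declared here:

// Build the three mapping attributes of the downstream dialect. Any attribute
// implementing DeviceMappingAttrInterface is accepted; MyWarpMappingAttr and
// Warps are hypothetical stand-ins.
SmallVector<DeviceMappingAttrInterface> warpMappingAttributes = {
    MyWarpMappingAttr::get(ctx, Warps::DimX),
    MyWarpMappingAttr::get(ctx, Warps::DimY),
    MyWarpMappingAttr::get(ctx, Warps::DimZ)};

// Lower all nested scf.foreach_thread ops using those attributes.
DiagnosedSilenceableFailure diag =
    mlir::transform::gpu::mapNestedForeachToThreadsImpl(
        rewriter, target, blockDim, /*syncAfterDistribute=*/true, transformOp,
        warpMappingAttributes);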

mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td

Lines changed: 8 additions & 0 deletions
@@ -34,6 +34,14 @@ def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> {
     of the loops it contains to the GPU's parallelism units such as threads and
     thread blocks.
   }];
+
+  let methods = [
+    InterfaceMethod<[{
+        Returns mapping as an integer from the attribute.
+      }],
+      "int64_t", "getMappingId", (ins)
+    >
+  ];
 }
 
 def DeviceMappingArrayAttr :
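The integer returned by `getMappingId` serves two purposes in the transform code further down: it is the key used to sort the `scf.foreach_thread` dimensions, and it indexes into the vector of generated block/thread ids. A self-contained toy in plain C++ (no MLIR types; the ids 0/1/2 for x/y/z are an assumption that matches how the GPU attributes are used in this commit) illustrating that contract:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Toy stand-in for DeviceMappingAttrInterface: an "attribute" is just its id.
struct ToyMappingAttr {
  int64_t mappingId; // assumed ordering: 0 = x, 1 = y, 2 = z
};

int main() {
  // num_threads [10, 20, 30] with mapping = [y, x, z], as in the updated test
  // case at the end of this commit.
  std::vector<std::pair<ToyMappingAttr, int64_t>> mapping = {
      {{1}, 10}, {{0}, 20}, {{2}, 30}};

  // Sort by mapping id -- the same comparator the transform now uses.
  std::sort(mapping.begin(), mapping.end(), [](const auto &a, const auto &b) {
    return a.first.mappingId < b.first.mappingId;
  });

  // After sorting, position i corresponds to hardware dimension i.
  for (std::size_t i = 0; i < mapping.size(); ++i)
    std::cout << "dim " << i << " -> size " << mapping[i].second << "\n";
  // Prints: dim 0 -> size 20, dim 1 -> size 10, dim 2 -> size 30.
}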

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 12 additions & 0 deletions
@@ -33,6 +33,18 @@ using namespace mlir::gpu;
 
 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// GPU Device Mapping Attributes
+//===----------------------------------------------------------------------===//
+
+int64_t GPUBlockMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getBlock());
+}
+
+int64_t GPUThreadMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getThread());
+}
+
 //===----------------------------------------------------------------------===//
 // MMAMatrixType
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 71 additions & 47 deletions
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
@@ -33,6 +34,24 @@ class SimpleRewriter : public PatternRewriter {
 };
 } // namespace
 
+/// Check if given mapping attributes are one of the desired attributes
+static DiagnosedSilenceableFailure checkAttributeType(
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes,
+    const Optional<ArrayAttr> &foreachMapping,
+    llvm::Optional<TransformOpInterface> transformOp) {
+  if (!foreachMapping.has_value())
+    return transformOp->emitSilenceableError() << "mapping must be present";
+
+  if (llvm::any_of(foreachMapping->getValue(),
+                   [&](DeviceMappingAttrInterface map) {
+                     return llvm::find(threadMappingAttributes, map) ==
+                            threadMappingAttributes.end();
+                   }))
+    return transformOp->emitDefiniteFailure()
+           << "mapping must be one of " << threadMappingAttributes;
+  return DiagnosedSilenceableFailure::success();
+}
+
 /// Determines if the size of the kernel configuration is supported by the GPU
 /// architecture being used. It presently makes use of CUDA limitations, however
 /// that aspect may be enhanced for other GPUs.
@@ -157,15 +176,13 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
     function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVectorImpl<Value> &)>
         blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp) {
+    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
   // those right now: the ForeachThreadOp is target-independent and the
   // transform op does not apply to individual ForeachThreadOp.
-  MLIRContext *ctx = foreachThreadOp->getContext();
   Location loc = foreachThreadOp->getLoc();
-  Attribute bX = GPUBlockMappingAttr::get(ctx, Blocks::DimX);
-  Attribute bY = GPUBlockMappingAttr::get(ctx, Blocks::DimY);
-  Attribute bZ = GPUBlockMappingAttr::get(ctx, Blocks::DimZ);
+
   if (foreachThreadOp.getNumResults() > 0)
     return transformOp.emitSilenceableError()
            << "only bufferized scf.foreach_thread lowers to "
@@ -180,23 +197,15 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
     return transformOp.emitSilenceableError()
            << "unsupported dynamic griddim size";
   }
-  if (!foreachThreadOp.getMapping().has_value())
-    return transformOp.emitSilenceableError() << "mapping must be present";
   SmallVector<Attribute> blockMapping =
       llvm::to_vector(foreachThreadOp.getMapping()->getValue());
-  if (llvm::any_of(blockMapping, [](DeviceMappingAttrInterface map) {
-        return !map.isa<GPUBlockMappingAttr>();
-      })) {
-    return transformOp.emitSilenceableError()
-           << "mapping must be #gpu.block<x/y/z/>";
-  }
 
   // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
   SmallVector<Value> numBlocks =
       llvm::to_vector(foreachThreadOp.getNumThreads());
   // Ensure we have 3 block sizes, one for each id.
   Value one;
-  for (auto attr : {bX, bY, bZ}) {
+  for (auto attr : mappingAttributes) {
     if (std::find(blockMapping.begin(), blockMapping.end(), attr) ==
         blockMapping.end()) {
       blockMapping.push_back(attr);
@@ -205,10 +214,10 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
     }
   }
 
-  // Step 2. sort the values by the corresponding GPUBlockMappingAttr.
-  auto comparator = [](Attribute a, Attribute b) -> bool {
-    return static_cast<int64_t>(a.cast<GPUBlockMappingAttr>().getBlock()) <
-           static_cast<int64_t>(b.cast<GPUBlockMappingAttr>().getBlock());
+  // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
+  auto comparator = [&](DeviceMappingAttrInterface a,
+                        DeviceMappingAttrInterface b) -> bool {
+    return a.getMappingId() < b.getMappingId();
   };
   SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
       blockMapping, numBlocks, comparator);
@@ -222,8 +231,9 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
   BlockAndValueMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
-    bvm.map(blockIdx, blockOps[static_cast<int64_t>(
-                          blockDim.cast<GPUBlockMappingAttr>().getBlock())]);
+    bvm.map(blockIdx,
+            blockOps[static_cast<int64_t>(
+                blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
   }
 
   // Step 4. Move the body of foreachThreadOp.
@@ -331,9 +341,17 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
   }
 
   SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
-  diag = mlir::transform::gpu::mapForeachToBlocksImpl(
-      rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim,
-      transformOp);
+  SmallVector<DeviceMappingAttrInterface> blockMappingAttributes = {
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimX),
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimY),
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimZ)};
+
+  diag = checkAttributeType(blockMappingAttributes,
+                            topLevelForeachThreadOp.getMapping(), transformOp);
+  if (diag.succeeded())
+    diag = mlir::transform::gpu::mapForeachToBlocksImpl(
+        rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim,
+        transformOp, blockMappingAttributes);
   if (diag.succeeded()) {
     diag = alterGpuLaunch(rewriter, gpuLaunch,
                           cast<TransformOpInterface>(getOperation()),
@@ -358,7 +376,8 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
 static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
     const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
-    llvm::Optional<TransformOpInterface> transformOp) {
+    llvm::Optional<TransformOpInterface> transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
   // those right now: the ForeachThreadOp is target-independent and the
   // transform op does not apply to individual ForeachThreadOp.
@@ -369,11 +388,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     }
     return emitDefiniteFailure(foreachThreadOp, message);
   };
-  MLIRContext *ctx = foreachThreadOp->getContext();
   Location loc = foreachThreadOp->getLoc();
-  Attribute tX = GPUThreadMappingAttr::get(ctx, Threads::DimX);
-  Attribute tY = GPUThreadMappingAttr::get(ctx, Threads::DimY);
-  Attribute tZ = GPUThreadMappingAttr::get(ctx, Threads::DimZ);
   if (foreachThreadOp.getNumResults() > 0)
     return failureHelper(
         "only bufferized scf.foreach_thread lowers to gpu.thread_id");
@@ -389,20 +404,14 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     return failureHelper("mapping must be present");
   SmallVector<Attribute> threadMapping =
       llvm::to_vector(foreachThreadOp.getMapping()->getValue());
-  if (llvm::any_of(threadMapping, [](DeviceMappingAttrInterface map) {
-        return !map.isa<GPUThreadMappingAttr>();
-      })) {
-    return transformOp->emitSilenceableError()
-           << "mapping must be #gpu.thread<x/y/z/>";
-  }
 
   // Step 1. Complete the threadMapping to a full mapping (with 1s) if
   // necessary.
   SmallVector<Value> numThreads =
      llvm::to_vector(foreachThreadOp.getNumThreads());
   // Ensure we have 3 block sizes, one for each id.
   Value one;
-  for (auto attr : {tX, tY, tZ}) {
+  for (auto attr : threadMappingAttributes) {
     if (std::find(threadMapping.begin(), threadMapping.end(), attr) ==
         threadMapping.end()) {
       threadMapping.push_back(attr);
@@ -411,10 +420,10 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     }
   }
 
-  // Step 2. sort the values by the corresponding GPUThreadMappingAttr.
-  auto comparator = [](Attribute a, Attribute b) -> bool {
-    return static_cast<int64_t>(a.cast<GPUThreadMappingAttr>().getThread()) <
-           static_cast<int64_t>(b.cast<GPUThreadMappingAttr>().getThread());
+  // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
+  auto comparator = [&](DeviceMappingAttrInterface a,
+                        DeviceMappingAttrInterface b) -> bool {
+    return a.getMappingId() < b.getMappingId();
   };
   SmallVector<Value> blockDimValues =
       scf::ForeachThreadOp::getValuesSortedByKey(threadMapping, numThreads,
@@ -434,8 +443,9 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
   BlockAndValueMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
-    bvm.map(blockIdx, threadOps[static_cast<int64_t>(
-                          blockDim.cast<GPUThreadMappingAttr>().getThread())]);
+    bvm.map(
+        blockIdx,
+        threadOps[blockDim.cast<DeviceMappingAttrInterface>().getMappingId()]);
   }
 
   // Step 4. Maybe create conditionals to predicate the region.
@@ -501,12 +511,18 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
     const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    llvm::Optional<TransformOpInterface> transformOp) {
+    llvm::Optional<TransformOpInterface> transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
-    rewriter.setInsertionPoint(foreachThreadOp);
-    diag = rewriteOneForeachThreadToGpuThreads(
-        rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp);
+    diag = checkAttributeType(threadMappingAttributes,
+                              foreachThreadOp.getMapping(), transformOp);
+    if (diag.succeeded()) {
+      rewriter.setInsertionPoint(foreachThreadOp);
+      diag = rewriteOneForeachThreadToGpuThreads(
+          rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp,
+          threadMappingAttributes);
+    }
    return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
   });
   return diag;
@@ -536,11 +552,19 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
     return diag;
   }
 
-  SimpleRewriter rewriter(getContext());
+  MLIRContext *ctx = getContext();
+  SimpleRewriter rewriter(ctx);
   rewriter.setInsertionPoint(target);
 
+  SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
+      GPUThreadMappingAttr::get(ctx, Threads::DimX),
+      GPUThreadMappingAttr::get(ctx, Threads::DimY),
+      GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
+
   diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp);
+      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
+      threadMappingAttributes);
+
   if (diag.succeeded()) {
     diag =
         alterGpuLaunch(rewriter, gpuLaunch, transformOp, llvm::None, llvm::None,

mlir/test/Dialect/GPU/transform-gpu-failing.mlir

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ func.func @map_nested_foreach_to_threads_not_buffer(%x: tensor<32x32xf32>, %y: t
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0
-  %foreach, %tiled = transform.structured.tile_to_foreach_thread_op %matmul num_threads [10, 20, 30]
+  %foreach, %tiled = transform.structured.tile_to_foreach_thread_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
   // expected-error @below {{only bufferized scf.foreach_thread lowers to gpu.thread_id}}
   transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [128, 4, 1] }
