 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
@@ -33,6 +34,24 @@ class SimpleRewriter : public PatternRewriter {
 };
 } // namespace

+/// Check if given mapping attributes are one of the desired attributes
+static DiagnosedSilenceableFailure checkAttributeType(
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes,
+    const Optional<ArrayAttr> &foreachMapping,
+    llvm::Optional<TransformOpInterface> transformOp) {
+  if (!foreachMapping.has_value())
+    return transformOp->emitSilenceableError() << "mapping must be present";
+
+  if (llvm::any_of(foreachMapping->getValue(),
+                   [&](DeviceMappingAttrInterface map) {
+                     return llvm::find(threadMappingAttributes, map) ==
+                            threadMappingAttributes.end();
+                   }))
+    return transformOp->emitDefiniteFailure()
+           << "mapping must be one of " << threadMappingAttributes;
+  return DiagnosedSilenceableFailure::success();
+}
+
 /// Determines if the size of the kernel configuration is supported by the GPU
 /// architecture being used. It presently makes use of CUDA limitations, however
 /// that aspect may be enhanced for other GPUs.
@@ -157,15 +176,13 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
     function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVectorImpl<Value> &)>
         blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp) {
+    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
   // those right now: the ForeachThreadOp is target-independent and the
   // transform op does not apply to individual ForeachThreadOp.
-  MLIRContext *ctx = foreachThreadOp->getContext();
   Location loc = foreachThreadOp->getLoc();
-  Attribute bX = GPUBlockMappingAttr::get(ctx, Blocks::DimX);
-  Attribute bY = GPUBlockMappingAttr::get(ctx, Blocks::DimY);
-  Attribute bZ = GPUBlockMappingAttr::get(ctx, Blocks::DimZ);
+
   if (foreachThreadOp.getNumResults() > 0)
     return transformOp.emitSilenceableError()
            << "only bufferized scf.foreach_thread lowers to "
@@ -180,23 +197,15 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
     return transformOp.emitSilenceableError()
            << "unsupported dynamic griddim size";
   }
-  if (!foreachThreadOp.getMapping().has_value())
-    return transformOp.emitSilenceableError() << "mapping must be present";
   SmallVector<Attribute> blockMapping =
       llvm::to_vector(foreachThreadOp.getMapping()->getValue());
-  if (llvm::any_of(blockMapping, [](DeviceMappingAttrInterface map) {
-        return !map.isa<GPUBlockMappingAttr>();
-      })) {
-    return transformOp.emitSilenceableError()
-           << "mapping must be #gpu.block<x/y/z/>";
-  }

   // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
   SmallVector<Value> numBlocks =
       llvm::to_vector(foreachThreadOp.getNumThreads());
   // Ensure we have 3 block sizes, one for each id.
   Value one;
-  for (auto attr : {bX, bY, bZ}) {
+  for (auto attr : mappingAttributes) {
     if (std::find(blockMapping.begin(), blockMapping.end(), attr) ==
         blockMapping.end()) {
       blockMapping.push_back(attr);
@@ -205,10 +214,10 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
     }
   }

-  // Step 2. sort the values by the corresponding GPUBlockMappingAttr.
-  auto comparator = [](Attribute a, Attribute b) -> bool {
-    return static_cast<int64_t>(a.cast<GPUBlockMappingAttr>().getBlock()) <
-           static_cast<int64_t>(b.cast<GPUBlockMappingAttr>().getBlock());
+  // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
+  auto comparator = [&](DeviceMappingAttrInterface a,
+                        DeviceMappingAttrInterface b) -> bool {
+    return a.getMappingId() < b.getMappingId();
   };
   SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
       blockMapping, numBlocks, comparator);
@@ -222,8 +231,9 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
   BlockAndValueMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
-    bvm.map(blockIdx, blockOps[static_cast<int64_t>(
-                          blockDim.cast<GPUBlockMappingAttr>().getBlock())]);
+    bvm.map(blockIdx,
+            blockOps[static_cast<int64_t>(
+                blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
   }

   // Step 4. Move the body of foreachThreadOp.
@@ -331,9 +341,17 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
   }

   SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
-  diag = mlir::transform::gpu::mapForeachToBlocksImpl(
-      rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim,
-      transformOp);
+  SmallVector<DeviceMappingAttrInterface> blockMappingAttributes = {
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimX),
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimY),
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimZ)};
+
+  diag = checkAttributeType(blockMappingAttributes,
+                            topLevelForeachThreadOp.getMapping(), transformOp);
+  if (diag.succeeded())
+    diag = mlir::transform::gpu::mapForeachToBlocksImpl(
+        rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim,
+        transformOp, blockMappingAttributes);
   if (diag.succeeded()) {
     diag = alterGpuLaunch(rewriter, gpuLaunch,
                           cast<TransformOpInterface>(getOperation()),
@@ -358,7 +376,8 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
 static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
     const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
-    llvm::Optional<TransformOpInterface> transformOp) {
+    llvm::Optional<TransformOpInterface> transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
   // those right now: the ForeachThreadOp is target-independent and the
   // transform op does not apply to individual ForeachThreadOp.
@@ -369,11 +388,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     }
     return emitDefiniteFailure(foreachThreadOp, message);
   };
-  MLIRContext *ctx = foreachThreadOp->getContext();
   Location loc = foreachThreadOp->getLoc();
-  Attribute tX = GPUThreadMappingAttr::get(ctx, Threads::DimX);
-  Attribute tY = GPUThreadMappingAttr::get(ctx, Threads::DimY);
-  Attribute tZ = GPUThreadMappingAttr::get(ctx, Threads::DimZ);
   if (foreachThreadOp.getNumResults() > 0)
     return failureHelper(
         "only bufferized scf.foreach_thread lowers to gpu.thread_id");
@@ -389,20 +404,14 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     return failureHelper("mapping must be present");
   SmallVector<Attribute> threadMapping =
       llvm::to_vector(foreachThreadOp.getMapping()->getValue());
-  if (llvm::any_of(threadMapping, [](DeviceMappingAttrInterface map) {
-        return !map.isa<GPUThreadMappingAttr>();
-      })) {
-    return transformOp->emitSilenceableError()
-           << "mapping must be #gpu.thread<x/y/z/>";
-  }

   // Step 1. Complete the threadMapping to a full mapping (with 1s) if
   // necessary.
   SmallVector<Value> numThreads =
       llvm::to_vector(foreachThreadOp.getNumThreads());
   // Ensure we have 3 block sizes, one for each id.
   Value one;
-  for (auto attr : {tX, tY, tZ}) {
+  for (auto attr : threadMappingAttributes) {
     if (std::find(threadMapping.begin(), threadMapping.end(), attr) ==
         threadMapping.end()) {
       threadMapping.push_back(attr);
@@ -411,10 +420,10 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     }
   }

-  // Step 2. sort the values by the corresponding GPUThreadMappingAttr.
-  auto comparator = [](Attribute a, Attribute b) -> bool {
-    return static_cast<int64_t>(a.cast<GPUThreadMappingAttr>().getThread()) <
-           static_cast<int64_t>(b.cast<GPUThreadMappingAttr>().getThread());
+  // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
+  auto comparator = [&](DeviceMappingAttrInterface a,
+                        DeviceMappingAttrInterface b) -> bool {
+    return a.getMappingId() < b.getMappingId();
   };
   SmallVector<Value> blockDimValues =
       scf::ForeachThreadOp::getValuesSortedByKey(threadMapping, numThreads,
@@ -434,8 +443,9 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
   BlockAndValueMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
-    bvm.map(blockIdx, threadOps[static_cast<int64_t>(
-                          blockDim.cast<GPUThreadMappingAttr>().getThread())]);
+    bvm.map(
+        blockIdx,
+        threadOps[blockDim.cast<DeviceMappingAttrInterface>().getMappingId()]);
   }

   // Step 4. Maybe create conditionals to predicate the region.
@@ -501,12 +511,18 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
     const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    llvm::Optional<TransformOpInterface> transformOp) {
+    llvm::Optional<TransformOpInterface> transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
-    rewriter.setInsertionPoint(foreachThreadOp);
-    diag = rewriteOneForeachThreadToGpuThreads(
-        rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp);
+    diag = checkAttributeType(threadMappingAttributes,
+                              foreachThreadOp.getMapping(), transformOp);
+    if (diag.succeeded()) {
+      rewriter.setInsertionPoint(foreachThreadOp);
+      diag = rewriteOneForeachThreadToGpuThreads(
+          rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp,
+          threadMappingAttributes);
+    }
     return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
   });
   return diag;
@@ -536,11 +552,19 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
     return diag;
   }

-  SimpleRewriter rewriter(getContext());
+  MLIRContext *ctx = getContext();
+  SimpleRewriter rewriter(ctx);
   rewriter.setInsertionPoint(target);

+  SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
+      GPUThreadMappingAttr::get(ctx, Threads::DimX),
+      GPUThreadMappingAttr::get(ctx, Threads::DimY),
+      GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
+
   diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp);
+      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
+      threadMappingAttributes);
+
   if (diag.succeeded()) {
     diag =
         alterGpuLaunch(rewriter, gpuLaunch, transformOp, llvm::None, llvm::None,
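
Editor's note (not part of the commit): a minimal C++ sketch of how the generalized entry point introduced above could be called with an explicit set of allowed mapping attributes, mirroring what `MapNestedForeachToThreads::applyToOne` does in this diff. The values `rewriter`, `target`, `blockDim`, and `transformOp` are assumed to already exist in the caller's scope; only names defined in this diff are used otherwise.

```cpp
// Allowed thread-mapping attributes; any scf.foreach_thread whose `mapping`
// attribute is not drawn from this set is rejected by checkAttributeType
// before the rewrite is attempted.
SmallVector<DeviceMappingAttrInterface> allowedThreadMappings = {
    GPUThreadMappingAttr::get(ctx, Threads::DimX),
    GPUThreadMappingAttr::get(ctx, Threads::DimY),
    GPUThreadMappingAttr::get(ctx, Threads::DimZ)};

// Walk `target`, mapping each nested scf.foreach_thread to gpu.thread_id
// using the attribute set above (sketch; assumes rewriter/target/blockDim/
// transformOp are available in the surrounding code).
DiagnosedSilenceableFailure diag =
    mlir::transform::gpu::mapNestedForeachToThreadsImpl(
        rewriter, target, blockDim, /*syncAfterDistribute=*/true, transformOp,
        allowedThreadMappings);
```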