@@ -502,8 +502,11 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
502
502
503
503
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl (
504
504
RewriterBase &rewriter, Operation *target,
505
- const SmallVectorImpl<int64_t > &blockDim, bool syncAfterDistribute,
506
- std::optional<TransformOpInterface> transformOp,
505
+ const SmallVectorImpl<int64_t > &blockDim,
506
+ function_ref<void (RewriterBase &, scf::ForeachThreadOp,
507
+ SmallVectorImpl<Value> &)>
508
+ threadIdGenerator,
509
+ bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
507
510
const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
508
511
DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success ();
509
512
target->walk ([&](scf::ForeachThreadOp foreachThreadOp) {
@@ -517,14 +520,8 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
517
520
foreachThreadOp.getMapping (), transformOp);
518
521
if (diag.succeeded ()) {
519
522
rewriter.setInsertionPoint (foreachThreadOp);
520
- IndexType indexType = rewriter.getIndexType ();
521
- SmallVector<Value> threadOps{
522
- rewriter.create <ThreadIdOp>(foreachThreadOp.getLoc (), indexType,
523
- Dimension::x),
524
- rewriter.create <ThreadIdOp>(foreachThreadOp.getLoc (), indexType,
525
- Dimension::y),
526
- rewriter.create <ThreadIdOp>(foreachThreadOp.getLoc (), indexType,
527
- Dimension::z)};
523
+ SmallVector<Value> threadOps;
524
+ threadIdGenerator (rewriter, foreachThreadOp, threadOps);
528
525
diag = rewriteOneForeachThreadToGpuThreads (
529
526
rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute,
530
527
transformOp, threadMappingAttributes);
@@ -562,10 +559,20 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
562
559
GPUThreadMappingAttr::get (ctx, Threads::DimX),
563
560
GPUThreadMappingAttr::get (ctx, Threads::DimY),
564
561
GPUThreadMappingAttr::get (ctx, Threads::DimZ)};
565
-
562
+ auto threadIdGenerator = [](RewriterBase &rewriter,
563
+ scf::ForeachThreadOp foreachThreadOp,
564
+ SmallVectorImpl<Value> &threadIds) {
565
+ IndexType indexType = rewriter.getIndexType ();
566
+ threadIds.assign ({rewriter.create <ThreadIdOp>(foreachThreadOp->getLoc (),
567
+ indexType, Dimension::x),
568
+ rewriter.create <ThreadIdOp>(foreachThreadOp->getLoc (),
569
+ indexType, Dimension::y),
570
+ rewriter.create <ThreadIdOp>(foreachThreadOp->getLoc (),
571
+ indexType, Dimension::z)});
572
+ };
566
573
diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl (
567
- rewriter, target, blockDim, getSyncAfterDistribute (), transformOp ,
568
- threadMappingAttributes);
574
+ rewriter, target, blockDim, threadIdGenerator, getSyncAfterDistribute (),
575
+ transformOp, threadMappingAttributes);
569
576
570
577
if (diag.succeeded ()) {
571
578
diag = alterGpuLaunch (rewriter, gpuLaunch, transformOp, std::nullopt,
0 commit comments