@@ -416,6 +416,15 @@ LogicalResult GPUBarrierConversion::matchAndRewrite(
416
416
return success ();
417
417
}
418
418
419
+ template <typename T>
420
+ Value getDimOp (OpBuilder &builder, MLIRContext *ctx, Location loc,
421
+ gpu::Dimension dimension) {
422
+ Type indexType = IndexType::get (ctx);
423
+ IntegerType i32Type = IntegerType::get (ctx, 32 );
424
+ Value dim = builder.create <T>(loc, indexType, dimension);
425
+ return builder.create <arith::IndexCastOp>(loc, i32Type, dim);
426
+ }
427
+
419
428
// ===----------------------------------------------------------------------===//
420
429
// Shuffle
421
430
// ===----------------------------------------------------------------------===//
@@ -436,8 +445,8 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
436
445
shuffleOp, " shuffle width and target subgroup size mismatch" );
437
446
438
447
Location loc = shuffleOp.getLoc ();
439
- Value trueVal = spirv::ConstantOp::getOne (rewriter.getI1Type (),
440
- shuffleOp.getLoc (), rewriter);
448
+ Value validVal = spirv::ConstantOp::getOne (rewriter.getI1Type (),
449
+ shuffleOp.getLoc (), rewriter);
441
450
auto scope = rewriter.getAttr <spirv::ScopeAttr>(spirv::Scope::Subgroup);
442
451
Value result;
443
452
@@ -450,17 +459,65 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
450
459
result = rewriter.create <spirv::GroupNonUniformShuffleOp>(
451
460
loc, scope, adaptor.getValue (), adaptor.getOffset ());
452
461
break ;
453
- case gpu::ShuffleMode::DOWN:
462
+ case gpu::ShuffleMode::DOWN: {
454
463
result = rewriter.create <spirv::GroupNonUniformShuffleDownOp>(
455
464
loc, scope, adaptor.getValue (), adaptor.getOffset ());
465
+
466
+ MLIRContext *ctx = shuffleOp.getContext ();
467
+ Value dimX =
468
+ getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
469
+ Value dimY =
470
+ getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
471
+ Value tidX =
472
+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
473
+ Value tidY =
474
+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
475
+ Value tidZ =
476
+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
477
+ auto i32Type = rewriter.getIntegerType (32 );
478
+ Value tmp1 = rewriter.create <arith::MulIOp>(loc, i32Type, tidZ, dimY);
479
+ Value tmp2 = rewriter.create <arith::AddIOp>(loc, i32Type, tmp1, tidY);
480
+ Value tmp3 = rewriter.create <arith::MulIOp>(loc, i32Type, tmp2, dimX);
481
+ Value landId = rewriter.create <arith::AddIOp>(loc, i32Type, tmp3, tidX);
482
+
483
+ Value resultLandId =
484
+ rewriter.create <arith::AddIOp>(loc, landId, adaptor.getOffset ());
485
+ validVal = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
486
+ resultLandId, adaptor.getWidth ());
456
487
break ;
457
- case gpu::ShuffleMode::UP:
488
+ }
489
+ case gpu::ShuffleMode::UP: {
458
490
result = rewriter.create <spirv::GroupNonUniformShuffleUpOp>(
459
491
loc, scope, adaptor.getValue (), adaptor.getOffset ());
492
+
493
+ MLIRContext *ctx = shuffleOp.getContext ();
494
+ Value dimX =
495
+ getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
496
+ Value dimY =
497
+ getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
498
+ Value tidX =
499
+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
500
+ Value tidY =
501
+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
502
+ Value tidZ =
503
+ getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
504
+ auto i32Type = rewriter.getIntegerType (32 );
505
+ Value tmp1 = rewriter.create <arith::MulIOp>(loc, i32Type, tidZ, dimY);
506
+ Value tmp2 = rewriter.create <arith::AddIOp>(loc, i32Type, tmp1, tidY);
507
+ Value tmp3 = rewriter.create <arith::MulIOp>(loc, i32Type, tmp2, dimX);
508
+ Value landId = rewriter.create <arith::AddIOp>(loc, i32Type, tmp3, tidX);
509
+
510
+ Value resultLandId =
511
+ rewriter.create <arith::SubIOp>(loc, landId, adaptor.getOffset ());
512
+ validVal = rewriter.create <arith::CmpIOp>(
513
+ loc, arith::CmpIPredicate::sge, resultLandId,
514
+ rewriter.create <arith::ConstantOp>(
515
+ loc, i32Type, rewriter.getIntegerAttr (i32Type, 0 )));
460
516
break ;
461
517
}
518
+ }
462
519
463
- rewriter.replaceOp (shuffleOp, {result, trueVal });
520
+ rewriter.replaceOp (shuffleOp, {result, validVal });
464
521
return success ();
465
522
}
466
523
0 commit comments