Skip to content

Commit efd0a7f

Browse files
authored
[mlir][ROCDL][~NFC] Migrate to LLVM dialect default builders (#125609)
There were a bunch of spots in ROCDL.td where we were defining our own llvmBuilder call which could have been generated using the default built-in one on LLVM_IntrOpBase. This commit cleans up such usages in the interests of potentinally enabling ROCDL import in the future and of making best practices more obvious. The one breaking change is renaming WaitcntOp to SWaitcntOp, which should have minimal impact.
1 parent 5812d0b commit efd0a7f

File tree

5 files changed

+60
-151
lines changed

5 files changed

+60
-151
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 41 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,13 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
100100
overloadedOperands, traits, numResults, requiresAccessGroup,
101101
requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
102102

103+
// Subclass to save typing and ease readibility when there aren't overloaded
104+
// operands or memory accesses.
105+
class ROCDL_ConcreteNonMemIntrOp<string mnemonic, list<Trait> traits,
106+
int numResults, list<int> immArgPositions = [],
107+
list<string> immArgNames = []>
108+
: ROCDL_IntrOp<mnemonic, [], [], traits, numResults, 0, 0,
109+
immArgPositions, immArgNames>;
103110
//===----------------------------------------------------------------------===//
104111
// ROCDL special register op definitions
105112
//===----------------------------------------------------------------------===//
@@ -150,37 +157,26 @@ class ROCDL_MbcntOp<string mnemonic> :
150157
def ROCDL_MbcntLoOp : ROCDL_MbcntOp<"lo">;
151158
def ROCDL_MbcntHiOp : ROCDL_MbcntOp<"hi">;
152159

153-
def ROCDL_DsSwizzleOp :
154-
ROCDL_Op<"ds_swizzle">,
155-
Results<(outs I32:$res)>,
156-
Arguments<(ins I32:$src,
157-
I32:$offset)>
158-
{
159-
string llvmBuilder = [{
160-
$res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_ds_swizzle, {$src, $offset});
161-
}];
160+
def ROCDL_DsSwizzleOp : ROCDL_ConcreteNonMemIntrOp<"ds_swizzle", [], 1>,
161+
Arguments<(ins I32:$src,
162+
I32:$offset)> {
163+
let results = (outs I32:$res);
162164
let assemblyFormat = [{
163165
$src `,` $offset attr-dict `:` `(` type($src) `,` type($offset) `)` `->` type($res)
164166
}];
165167
}
166168

167-
def ROCDL_DsBpermuteOp :
168-
ROCDL_Op<"ds_bpermute">,
169-
Results<(outs I32:$res)>,
170-
Arguments<(ins I32:$index,
171-
I32:$src)>
172-
{
173-
string llvmBuilder = [{
174-
$res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_ds_bpermute, {$index, $src});
175-
}];
169+
def ROCDL_DsBpermuteOp : ROCDL_ConcreteNonMemIntrOp<"ds_bpermute", [], 1>,
170+
Arguments<(ins I32:$index,
171+
I32:$src)> {
172+
let results = (outs I32:$res);
176173
let assemblyFormat = [{
177174
$index `,` $src attr-dict `:` `(` type($index) `,` type($src) `)` `->` type($res)
178175
}];
179176
}
180177

181178
def ROCDL_BallotOp :
182-
ROCDL_Op<"ballot">,
183-
Results<(outs LLVM_Type:$res)>,
179+
ROCDL_IntrOp<"ballot", [0], [], [], 1>,
184180
Arguments<(ins I1:$pred)> {
185181
let summary = "Vote across thread group";
186182

@@ -189,11 +185,6 @@ def ROCDL_BallotOp :
189185
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
190186
}];
191187

192-
string llvmBuilder = [{
193-
$res = createIntrinsicCall(builder,
194-
llvm::Intrinsic::amdgcn_ballot, {$pred}, {$_resultType});
195-
}];
196-
197188
let assemblyFormat = "$pred attr-dict `:` type($res)";
198189
}
199190

@@ -249,18 +240,12 @@ def ROCDL_GridDimZOp : ROCDL_DimGetterFunctionOp<"grid.dim.z",
249240

250241
// Emits the waintcnt instruction. The bitfield's semantics depend
251242
// on the target chipset
252-
def ROCDL_WaitcntOp : ROCDL_Op<"waitcnt">, Arguments<(ins I32Attr:$bitfield)> {
253-
string llvmBuilder = [{
254-
createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_waitcnt,
255-
{builder.getInt32($bitfield)});
256-
}];
243+
def ROCDL_SWaitcntOp : ROCDL_ConcreteNonMemIntrOp<"s.waitcnt", [], 0, [0], ["bitfield"]>,
244+
Arguments<(ins I32Attr:$bitfield)> {
257245
let assemblyFormat = "attr-dict $bitfield";
258246
}
259247

260-
def ROCDL_SBarrierOp : ROCDL_Op<"s.barrier"> {
261-
string llvmBuilder = [{
262-
createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_barrier);
263-
}];
248+
def ROCDL_SBarrierOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier", [], 0> {
264249
let assemblyFormat = "attr-dict";
265250
}
266251

@@ -276,68 +261,51 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
276261
let assemblyFormat = "attr-dict";
277262
}
278263

279-
def ROCDL_BarrierSignalOp : ROCDL_IntrOp<"s.barrier.signal", [], [], [], 0, 0, 0, [0], ["id"]>,
264+
def ROCDL_BarrierSignalOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.signal", [], 0, [0], ["id"]>,
280265
Arguments<(ins I32Attr:$id)> {
281266
let results = (outs);
282267
let assemblyFormat = "$id attr-dict";
283268
}
284269

285-
def ROCDL_BarrierWaitOp : ROCDL_IntrOp<"s.barrier.wait", [], [], [], 0, 0, 0, [0], ["id"]>,
270+
def ROCDL_BarrierWaitOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.wait", [], 0, [0], ["id"]>,
286271
Arguments<(ins I16Attr:$id)> {
287272
let results = (outs);
288273
let assemblyFormat = "$id attr-dict";
289-
string llvmBuilder =
290-
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_barrier_wait,builder.getInt16(op.getId()));";
291274
}
292275

293-
def ROCDL_WaitDscntOp: ROCDL_IntrOp<"s.wait.dscnt", [], [], [], 0, 0, 0, [0], ["id"]>,
276+
def ROCDL_WaitDscntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.dscnt", [], 0, [0], ["id"]>,
294277
Arguments<(ins I16Attr:$id)> {
295278
let results = (outs);
296279
let assemblyFormat = "$id attr-dict";
297280
}
298281

299-
def ROCDL_SetPrioOp : ROCDL_IntrOp<"s.setprio", [], [], [], 0>,
282+
def ROCDL_SetPrioOp : ROCDL_ConcreteNonMemIntrOp<"s.setprio", [], 0, [0], ["priority"]>,
300283
Arguments<(ins I16Attr:$priority)> {
301-
let results = (outs);
302284
let assemblyFormat = "$priority attr-dict";
303-
string llvmBuilder =
304-
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_setprio,builder.getInt16(op.getPriority()));";
305285
}
306286

307-
def ROCDL_SchedBarrier : ROCDL_IntrOp<"sched.barrier", [], [], [], 0>,
287+
def ROCDL_SchedBarrier : ROCDL_ConcreteNonMemIntrOp<"sched.barrier", [], 0, [0],["mask"]>,
308288
Arguments<(ins I32Attr:$mask)> {
309-
let results = (outs);
310289
let assemblyFormat = "$mask attr-dict";
311-
string llvmBuilder =
312-
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_sched_barrier,builder.getInt32(op.getMask()));";
313290
}
314291

315-
def ROCDL_SchedGroupBarrier : ROCDL_IntrOp<"sched.group.barrier", [], [], [], 0>,
316-
Arguments<(ins I32Attr:$mask, I32Attr:$size, I32Attr:$groupId)> {
317-
let results = (outs);
292+
def ROCDL_SchedGroupBarrier
293+
: ROCDL_ConcreteNonMemIntrOp<"sched.group.barrier", [], 0,
294+
[0, 1, 2], ["mask", "size", "groupId"]>,
295+
Arguments<(ins I32Attr:$mask, I32Attr:$size, I32Attr:$groupId)> {
318296
let assemblyFormat = "$mask `,` $size `,` $groupId attr-dict";
319-
string llvmBuilder = [{
320-
createIntrinsicCall(builder,
321-
llvm::Intrinsic::amdgcn_sched_group_barrier,
322-
{builder.getInt32(op.getMask()), builder.getInt32(op.getSize()), builder.getInt32(op.getGroupId())});
323-
}];
324297
}
325298

326-
def ROCDL_IglpOpt : ROCDL_IntrOp<"iglp.opt", [], [], [], 0>,
299+
def ROCDL_IglpOpt : ROCDL_ConcreteNonMemIntrOp<"iglp.opt", [], 0, [0], ["variant"]>,
327300
Arguments<(ins I32Attr:$variant)> {
328-
let results = (outs);
329301
let assemblyFormat = "$variant attr-dict";
330-
string llvmBuilder =
331-
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_iglp_opt,builder.getInt32(op.getVariant()));";
332302
}
333303

334304
//===---------------------------------------------------------------------===//
335305
// Xdlops intrinsics
336306

337307
class ROCDL_Mfma_IntrOp<string mnemonic, list<Trait> traits = []> :
338-
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
339-
"amdgcn_" # !subst(".","_", mnemonic),
340-
[], [], traits, 1>,
308+
ROCDL_IntrOp<mnemonic, [], [], traits, 1>,
341309
Arguments<(ins Variadic<LLVM_Type>:$args)> {
342310
let assemblyFormat =
343311
"$args attr-dict `:` functional-type($args, $res)";
@@ -347,9 +315,7 @@ class ROCDL_Mfma_IntrOp<string mnemonic, list<Trait> traits = []> :
347315
// MFMA intrinsics with overloaded operands
348316
class ROCDL_Mfma_OO_IntrOp<string mnemonic, list<int> overloadedOperands,
349317
list<Trait> traits = []> :
350-
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
351-
"amdgcn_" # !subst(".","_", mnemonic),
352-
[], overloadedOperands, traits, 1>,
318+
ROCDL_IntrOp<mnemonic, [], overloadedOperands, traits, 1>,
353319
Arguments<(ins Variadic<LLVM_Type>:$args)> {
354320
let assemblyFormat =
355321
"$args attr-dict `:` functional-type($args, $res)";
@@ -430,9 +396,7 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
430396
// WMMA intrinsics
431397
class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
432398
list<Trait> traits = []> :
433-
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
434-
"amdgcn_" # !subst(".","_", mnemonic),
435-
[0], overloadedOperands, traits, 1>,
399+
ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
436400
Arguments<(ins Variadic<LLVM_Type>:$args)> {
437401
let assemblyFormat =
438402
"$args attr-dict `:` functional-type($args, $res)";
@@ -572,50 +536,32 @@ def ROCDL_RawPtrBufferAtomicFaddOp : ROCDL_RawPtrBufferAtomicNoRet<"fadd">;
572536
// Raw buffer load/store intrinsics
573537

574538
def ROCDL_RawBufferLoadOp :
575-
ROCDL_Op<"raw.buffer.load">,
576-
Results<(outs LLVM_Type:$res)>,
539+
ROCDL_IntrOp<"raw.buffer.load", [0], [], [], 1>,
577540
Arguments<(ins LLVM_Type:$rsrc,
578541
LLVM_Type:$offset,
579542
LLVM_Type:$soffset,
580543
LLVM_Type:$aux)> {
581-
string llvmBuilder = [{
582-
$res = createIntrinsicCall(builder,
583-
llvm::Intrinsic::amdgcn_raw_buffer_load, {$rsrc, $offset,
584-
$soffset, $aux}, {$_resultType});
585-
}];
586544
let hasCustomAssemblyFormat = 1;
587545
}
588546

589547
def ROCDL_RawBufferStoreOp :
590-
ROCDL_Op<"raw.buffer.store">,
548+
ROCDL_IntrOp<"raw.buffer.store", [], [0], [], 0>,
591549
Arguments<(ins LLVM_Type:$vdata,
592550
LLVM_Type:$rsrc,
593551
LLVM_Type:$offset,
594552
LLVM_Type:$soffset,
595553
LLVM_Type:$aux)>{
596-
string llvmBuilder = [{
597-
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
598-
createIntrinsicCall(builder,
599-
llvm::Intrinsic::amdgcn_raw_buffer_store, {$vdata, $rsrc,
600-
$offset, $soffset, $aux}, {vdataType});
601-
}];
602554
let hasCustomAssemblyFormat = 1;
603555
}
604556

605557
def ROCDL_RawBufferAtomicCmpSwap :
606-
ROCDL_Op<"raw.buffer.atomic.cmpswap", [AllTypesMatch<["res", "src", "cmp"]>]>,
607-
Results<(outs LLVM_Type:$res)>,
558+
ROCDL_IntrOp<"raw.buffer.atomic.cmpswap", [], [0], [AllTypesMatch<["res", "src", "cmp"]>], 1>,
608559
Arguments<(ins LLVM_Type:$src,
609560
LLVM_Type:$cmp,
610561
LLVM_Type:$rsrc,
611562
I32:$offset,
612563
I32:$soffset,
613564
I32:$aux)>{
614-
string llvmBuilder = [{
615-
$res = createIntrinsicCall(builder,
616-
llvm::Intrinsic::amdgcn_raw_buffer_atomic_cmpswap, {$src, $cmp, $rsrc,
617-
$offset, $soffset, $aux}, {$_resultType});
618-
}];
619565
let assemblyFormat = [{
620566
attr-dict `(` operands `)` `:` type($res) `,` type($rsrc)
621567
}];
@@ -625,100 +571,64 @@ def ROCDL_RawBufferAtomicCmpSwap :
625571
// MI-100 and MI-200 buffer atomic floating point add intrinsic
626572

627573
def ROCDL_RawBufferAtomicFAddOp :
628-
ROCDL_Op<"raw.buffer.atomic.fadd">,
574+
ROCDL_IntrOp<"raw.buffer.atomic.fadd", [], [0], [], 0>,
629575
Arguments<(ins LLVM_Type:$vdata,
630576
LLVM_Type:$rsrc,
631577
LLVM_Type:$offset,
632578
LLVM_Type:$soffset,
633579
LLVM_Type:$aux)>{
634-
string llvmBuilder = [{
635-
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
636-
createIntrinsicCall(builder,
637-
llvm::Intrinsic::amdgcn_raw_buffer_atomic_fadd, {$vdata, $rsrc,
638-
$offset, $soffset, $aux}, {vdataType});
639-
}];
640580
let hasCustomAssemblyFormat = 1;
641581
}
642582

643583
//===---------------------------------------------------------------------===//
644584
// Buffer atomic floating point max intrinsic. GFX9 does not support fp32.
645585

646586
def ROCDL_RawBufferAtomicFMaxOp :
647-
ROCDL_Op<"raw.buffer.atomic.fmax">,
587+
ROCDL_IntrOp<"raw.buffer.atomic.fmax", [], [0], [], 0>,
648588
Arguments<(ins LLVM_Type:$vdata,
649589
LLVM_Type:$rsrc,
650590
LLVM_Type:$offset,
651591
LLVM_Type:$soffset,
652592
LLVM_Type:$aux)>{
653-
string llvmBuilder = [{
654-
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
655-
createIntrinsicCall(builder,
656-
llvm::Intrinsic::amdgcn_raw_buffer_atomic_fmax, {$vdata, $rsrc,
657-
$offset, $soffset, $aux}, {vdataType});
658-
}];
659593
let hasCustomAssemblyFormat = 1;
660594
}
661595

662596
//===---------------------------------------------------------------------===//
663597
// Buffer atomic signed integer max intrinsic.
664598

665599
def ROCDL_RawBufferAtomicSMaxOp :
666-
ROCDL_Op<"raw.buffer.atomic.smax">,
600+
ROCDL_IntrOp<"raw.buffer.atomic.smax", [], [0], [], 0>,
667601
Arguments<(ins LLVM_Type:$vdata,
668602
LLVM_Type:$rsrc,
669603
LLVM_Type:$offset,
670604
LLVM_Type:$soffset,
671605
LLVM_Type:$aux)>{
672-
string llvmBuilder = [{
673-
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
674-
createIntrinsicCall(builder,
675-
llvm::Intrinsic::amdgcn_raw_buffer_atomic_smax, {$vdata, $rsrc,
676-
$offset, $soffset, $aux}, {vdataType});
677-
}];
678606
let hasCustomAssemblyFormat = 1;
679607
}
680608

681609
//===---------------------------------------------------------------------===//
682610
// Buffer atomic unsigned integer min intrinsic.
683611

684612
def ROCDL_RawBufferAtomicUMinOp :
685-
ROCDL_Op<"raw.buffer.atomic.umin">,
613+
ROCDL_IntrOp<"raw.buffer.atomic.umin", [], [0], [], 0>,
686614
Arguments<(ins LLVM_Type:$vdata,
687615
LLVM_Type:$rsrc,
688616
LLVM_Type:$offset,
689617
LLVM_Type:$soffset,
690618
LLVM_Type:$aux)>{
691-
string llvmBuilder = [{
692-
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
693-
createIntrinsicCall(builder,
694-
llvm::Intrinsic::amdgcn_raw_buffer_atomic_umin, {$vdata, $rsrc,
695-
$offset, $soffset, $aux}, {vdataType});
696-
}];
697619
let hasCustomAssemblyFormat = 1;
698620
}
699621

700622
// DPP Update intrinsic
701623
def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
702-
[AllTypesMatch<["res", "src", "old"]>], 1>,
624+
[AllTypesMatch<["res", "src", "old"]>], 1, 0, 0,
625+
[2, 3, 4, 5], ["dppCtrl", "rowMask", "bankMask", "boundCtrl"]>,
703626
Arguments<(ins LLVM_Type:$old, LLVM_Type:$src, I32Attr:$dppCtrl, I32Attr:$rowMask,
704627
I32Attr:$bankMask, I1Attr:$boundCtrl)> {
705628
let results = (outs LLVM_Type:$res);
706629
let assemblyFormat = [{
707630
attr-dict $old `,` $src `with` $dppCtrl `,` $rowMask `,` $bankMask `,` $boundCtrl `:` type($src)
708631
}];
709-
string llvmBuilder = [{
710-
auto vdataType = moduleTranslation.convertType(op.getSrc().getType());
711-
llvm::Value *args[] = {
712-
moduleTranslation.lookupValue(op.getOld()),
713-
moduleTranslation.lookupValue(op.getSrc()),
714-
builder.getInt32(op.getDppCtrl()),
715-
builder.getInt32(op.getRowMask()),
716-
builder.getInt32(op.getBankMask()),
717-
builder.getInt1(op.getBoundCtrl())
718-
};
719-
$res = createIntrinsicCall(builder,
720-
llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
721-
}];
722632
}
723633

724634
//===---------------------------------------------------------------------===//

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
343343
<< chipset.majorVersion;
344344

345345
Location loc = op->getLoc();
346-
rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
346+
rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
347347
rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
348348
} else {
349349
Location loc = op->getLoc();

0 commit comments

Comments
 (0)