Skip to content

[mlir][ROCDL][~NFC] Migrate to LLVM dialect default builders #125609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 41 additions & 131 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
overloadedOperands, traits, numResults, requiresAccessGroup,
requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;

// Subclass to save typing and ease readibility when there aren't overloaded
// operands or memory accesses.
class ROCDL_ConcreteNonMemIntrOp<string mnemonic, list<Trait> traits,
int numResults, list<int> immArgPositions = [],
list<string> immArgNames = []>
: ROCDL_IntrOp<mnemonic, [], [], traits, numResults, 0, 0,
immArgPositions, immArgNames>;
//===----------------------------------------------------------------------===//
// ROCDL special register op definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -150,37 +157,26 @@ class ROCDL_MbcntOp<string mnemonic> :
def ROCDL_MbcntLoOp : ROCDL_MbcntOp<"lo">;
def ROCDL_MbcntHiOp : ROCDL_MbcntOp<"hi">;

def ROCDL_DsSwizzleOp :
ROCDL_Op<"ds_swizzle">,
Results<(outs I32:$res)>,
Arguments<(ins I32:$src,
I32:$offset)>
{
string llvmBuilder = [{
$res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_ds_swizzle, {$src, $offset});
}];
def ROCDL_DsSwizzleOp : ROCDL_ConcreteNonMemIntrOp<"ds_swizzle", [], 1>,
Arguments<(ins I32:$src,
I32:$offset)> {
let results = (outs I32:$res);
let assemblyFormat = [{
$src `,` $offset attr-dict `:` `(` type($src) `,` type($offset) `)` `->` type($res)
}];
}

def ROCDL_DsBpermuteOp :
ROCDL_Op<"ds_bpermute">,
Results<(outs I32:$res)>,
Arguments<(ins I32:$index,
I32:$src)>
{
string llvmBuilder = [{
$res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_ds_bpermute, {$index, $src});
}];
def ROCDL_DsBpermuteOp : ROCDL_ConcreteNonMemIntrOp<"ds_bpermute", [], 1>,
Arguments<(ins I32:$index,
I32:$src)> {
let results = (outs I32:$res);
let assemblyFormat = [{
$index `,` $src attr-dict `:` `(` type($index) `,` type($src) `)` `->` type($res)
}];
}

def ROCDL_BallotOp :
ROCDL_Op<"ballot">,
Results<(outs LLVM_Type:$res)>,
ROCDL_IntrOp<"ballot", [0], [], [], 1>,
Arguments<(ins I1:$pred)> {
let summary = "Vote across thread group";

Expand All @@ -189,11 +185,6 @@ def ROCDL_BallotOp :
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];

string llvmBuilder = [{
$res = createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_ballot, {$pred}, {$_resultType});
}];

let assemblyFormat = "$pred attr-dict `:` type($res)";
}

Expand Down Expand Up @@ -249,18 +240,12 @@ def ROCDL_GridDimZOp : ROCDL_DimGetterFunctionOp<"grid.dim.z",

// Emits the waintcnt instruction. The bitfield's semantics depend
// on the target chipset
def ROCDL_WaitcntOp : ROCDL_Op<"waitcnt">, Arguments<(ins I32Attr:$bitfield)> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_waitcnt,
{builder.getInt32($bitfield)});
}];
def ROCDL_SWaitcntOp : ROCDL_ConcreteNonMemIntrOp<"s.waitcnt", [], 0, [0], ["bitfield"]>,
Arguments<(ins I32Attr:$bitfield)> {
let assemblyFormat = "attr-dict $bitfield";
}

def ROCDL_SBarrierOp : ROCDL_Op<"s.barrier"> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_barrier);
}];
def ROCDL_SBarrierOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier", [], 0> {
let assemblyFormat = "attr-dict";
}

Expand All @@ -276,68 +261,51 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
let assemblyFormat = "attr-dict";
}

def ROCDL_BarrierSignalOp : ROCDL_IntrOp<"s.barrier.signal", [], [], [], 0, 0, 0, [0], ["id"]>,
def ROCDL_BarrierSignalOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.signal", [], 0, [0], ["id"]>,
Arguments<(ins I32Attr:$id)> {
let results = (outs);
let assemblyFormat = "$id attr-dict";
}

def ROCDL_BarrierWaitOp : ROCDL_IntrOp<"s.barrier.wait", [], [], [], 0, 0, 0, [0], ["id"]>,
def ROCDL_BarrierWaitOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.wait", [], 0, [0], ["id"]>,
Arguments<(ins I16Attr:$id)> {
let results = (outs);
let assemblyFormat = "$id attr-dict";
string llvmBuilder =
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_barrier_wait,builder.getInt16(op.getId()));";
}

def ROCDL_WaitDscntOp: ROCDL_IntrOp<"s.wait.dscnt", [], [], [], 0, 0, 0, [0], ["id"]>,
def ROCDL_WaitDscntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.dscnt", [], 0, [0], ["id"]>,
Arguments<(ins I16Attr:$id)> {
let results = (outs);
let assemblyFormat = "$id attr-dict";
}

def ROCDL_SetPrioOp : ROCDL_IntrOp<"s.setprio", [], [], [], 0>,
def ROCDL_SetPrioOp : ROCDL_ConcreteNonMemIntrOp<"s.setprio", [], 0, [0], ["priority"]>,
Arguments<(ins I16Attr:$priority)> {
let results = (outs);
let assemblyFormat = "$priority attr-dict";
string llvmBuilder =
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_s_setprio,builder.getInt16(op.getPriority()));";
}

def ROCDL_SchedBarrier : ROCDL_IntrOp<"sched.barrier", [], [], [], 0>,
def ROCDL_SchedBarrier : ROCDL_ConcreteNonMemIntrOp<"sched.barrier", [], 0, [0],["mask"]>,
Arguments<(ins I32Attr:$mask)> {
let results = (outs);
let assemblyFormat = "$mask attr-dict";
string llvmBuilder =
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_sched_barrier,builder.getInt32(op.getMask()));";
}

def ROCDL_SchedGroupBarrier : ROCDL_IntrOp<"sched.group.barrier", [], [], [], 0>,
Arguments<(ins I32Attr:$mask, I32Attr:$size, I32Attr:$groupId)> {
let results = (outs);
def ROCDL_SchedGroupBarrier
: ROCDL_ConcreteNonMemIntrOp<"sched.group.barrier", [], 0,
[0, 1, 2], ["mask", "size", "groupId"]>,
Arguments<(ins I32Attr:$mask, I32Attr:$size, I32Attr:$groupId)> {
let assemblyFormat = "$mask `,` $size `,` $groupId attr-dict";
string llvmBuilder = [{
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_sched_group_barrier,
{builder.getInt32(op.getMask()), builder.getInt32(op.getSize()), builder.getInt32(op.getGroupId())});
}];
}

def ROCDL_IglpOpt : ROCDL_IntrOp<"iglp.opt", [], [], [], 0>,
def ROCDL_IglpOpt : ROCDL_ConcreteNonMemIntrOp<"iglp.opt", [], 0, [0], ["variant"]>,
Arguments<(ins I32Attr:$variant)> {
let results = (outs);
let assemblyFormat = "$variant attr-dict";
string llvmBuilder =
"createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_iglp_opt,builder.getInt32(op.getVariant()));";
}

//===---------------------------------------------------------------------===//
// Xdlops intrinsics

class ROCDL_Mfma_IntrOp<string mnemonic, list<Trait> traits = []> :
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".","_", mnemonic),
[], [], traits, 1>,
ROCDL_IntrOp<mnemonic, [], [], traits, 1>,
Arguments<(ins Variadic<LLVM_Type>:$args)> {
let assemblyFormat =
"$args attr-dict `:` functional-type($args, $res)";
Expand All @@ -347,9 +315,7 @@ class ROCDL_Mfma_IntrOp<string mnemonic, list<Trait> traits = []> :
// MFMA intrinsics with overloaded operands
class ROCDL_Mfma_OO_IntrOp<string mnemonic, list<int> overloadedOperands,
list<Trait> traits = []> :
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".","_", mnemonic),
[], overloadedOperands, traits, 1>,
ROCDL_IntrOp<mnemonic, [], overloadedOperands, traits, 1>,
Arguments<(ins Variadic<LLVM_Type>:$args)> {
let assemblyFormat =
"$args attr-dict `:` functional-type($args, $res)";
Expand Down Expand Up @@ -430,9 +396,7 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
// WMMA intrinsics
class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
list<Trait> traits = []> :
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".","_", mnemonic),
[0], overloadedOperands, traits, 1>,
ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
Arguments<(ins Variadic<LLVM_Type>:$args)> {
let assemblyFormat =
"$args attr-dict `:` functional-type($args, $res)";
Expand Down Expand Up @@ -572,50 +536,32 @@ def ROCDL_RawPtrBufferAtomicFaddOp : ROCDL_RawPtrBufferAtomicNoRet<"fadd">;
// Raw buffer load/store intrinsics

def ROCDL_RawBufferLoadOp :
ROCDL_Op<"raw.buffer.load">,
Results<(outs LLVM_Type:$res)>,
ROCDL_IntrOp<"raw.buffer.load", [0], [], [], 1>,
Arguments<(ins LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)> {
string llvmBuilder = [{
$res = createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_load, {$rsrc, $offset,
$soffset, $aux}, {$_resultType});
}];
let hasCustomAssemblyFormat = 1;
}

def ROCDL_RawBufferStoreOp :
ROCDL_Op<"raw.buffer.store">,
ROCDL_IntrOp<"raw.buffer.store", [], [0], [], 0>,
Arguments<(ins LLVM_Type:$vdata,
LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)>{
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_store, {$vdata, $rsrc,
$offset, $soffset, $aux}, {vdataType});
}];
let hasCustomAssemblyFormat = 1;
}

def ROCDL_RawBufferAtomicCmpSwap :
ROCDL_Op<"raw.buffer.atomic.cmpswap", [AllTypesMatch<["res", "src", "cmp"]>]>,
Results<(outs LLVM_Type:$res)>,
ROCDL_IntrOp<"raw.buffer.atomic.cmpswap", [], [0], [AllTypesMatch<["res", "src", "cmp"]>], 1>,
Arguments<(ins LLVM_Type:$src,
LLVM_Type:$cmp,
LLVM_Type:$rsrc,
I32:$offset,
I32:$soffset,
I32:$aux)>{
string llvmBuilder = [{
$res = createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_atomic_cmpswap, {$src, $cmp, $rsrc,
$offset, $soffset, $aux}, {$_resultType});
}];
let assemblyFormat = [{
attr-dict `(` operands `)` `:` type($res) `,` type($rsrc)
}];
Expand All @@ -625,100 +571,64 @@ def ROCDL_RawBufferAtomicCmpSwap :
// MI-100 and MI-200 buffer atomic floating point add intrinsic

def ROCDL_RawBufferAtomicFAddOp :
ROCDL_Op<"raw.buffer.atomic.fadd">,
ROCDL_IntrOp<"raw.buffer.atomic.fadd", [], [0], [], 0>,
Arguments<(ins LLVM_Type:$vdata,
LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)>{
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_atomic_fadd, {$vdata, $rsrc,
$offset, $soffset, $aux}, {vdataType});
}];
let hasCustomAssemblyFormat = 1;
}

//===---------------------------------------------------------------------===//
// Buffer atomic floating point max intrinsic. GFX9 does not support fp32.

def ROCDL_RawBufferAtomicFMaxOp :
ROCDL_Op<"raw.buffer.atomic.fmax">,
ROCDL_IntrOp<"raw.buffer.atomic.fmax", [], [0], [], 0>,
Arguments<(ins LLVM_Type:$vdata,
LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)>{
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_atomic_fmax, {$vdata, $rsrc,
$offset, $soffset, $aux}, {vdataType});
}];
let hasCustomAssemblyFormat = 1;
}

//===---------------------------------------------------------------------===//
// Buffer atomic signed integer max intrinsic.

def ROCDL_RawBufferAtomicSMaxOp :
ROCDL_Op<"raw.buffer.atomic.smax">,
ROCDL_IntrOp<"raw.buffer.atomic.smax", [], [0], [], 0>,
Arguments<(ins LLVM_Type:$vdata,
LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)>{
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_atomic_smax, {$vdata, $rsrc,
$offset, $soffset, $aux}, {vdataType});
}];
let hasCustomAssemblyFormat = 1;
}

//===---------------------------------------------------------------------===//
// Buffer atomic unsigned integer min intrinsic.

def ROCDL_RawBufferAtomicUMinOp :
ROCDL_Op<"raw.buffer.atomic.umin">,
ROCDL_IntrOp<"raw.buffer.atomic.umin", [], [0], [], 0>,
Arguments<(ins LLVM_Type:$vdata,
LLVM_Type:$rsrc,
LLVM_Type:$offset,
LLVM_Type:$soffset,
LLVM_Type:$aux)>{
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_raw_buffer_atomic_umin, {$vdata, $rsrc,
$offset, $soffset, $aux}, {vdataType});
}];
let hasCustomAssemblyFormat = 1;
}

// DPP Update intrinsic
def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
[AllTypesMatch<["res", "src", "old"]>], 1>,
[AllTypesMatch<["res", "src", "old"]>], 1, 0, 0,
[2, 3, 4, 5], ["dppCtrl", "rowMask", "bankMask", "boundCtrl"]>,
Arguments<(ins LLVM_Type:$old, LLVM_Type:$src, I32Attr:$dppCtrl, I32Attr:$rowMask,
I32Attr:$bankMask, I1Attr:$boundCtrl)> {
let results = (outs LLVM_Type:$res);
let assemblyFormat = [{
attr-dict $old `,` $src `with` $dppCtrl `,` $rowMask `,` $bankMask `,` $boundCtrl `:` type($src)
}];
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getSrc().getType());
llvm::Value *args[] = {
moduleTranslation.lookupValue(op.getOld()),
moduleTranslation.lookupValue(op.getSrc()),
builder.getInt32(op.getDppCtrl()),
builder.getInt32(op.getRowMask()),
builder.getInt32(op.getBankMask()),
builder.getInt1(op.getBoundCtrl())
};
$res = createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
}];
}

//===---------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
<< chipset.majorVersion;

Location loc = op->getLoc();
rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
} else {
Location loc = op->getLoc();
Expand Down
Loading