Skip to content

[mlir][AMDGPU] Implement AMDGPU DPP operation in MLIR. #89233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsAMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -2343,6 +2343,8 @@ def int_amdgcn_buffer_wbinvl1_vol :
// VI Intrinsics
//===----------------------------------------------------------------------===//

// The llvm.amdgcn.mov.dpp.i32 intrinsic represents the mov.dpp operation in AMDGPU.
// This operation is being deprecated and can be replaced with llvm.amdgcn.update.dpp.i32.
// llvm.amdgcn.mov.dpp.i32 <src> <dpp_ctrl> <row_mask> <bank_mask> <bound_ctrl>
def int_amdgcn_mov_dpp :
Intrinsic<[llvm_anyint_ty],
Expand All @@ -2352,6 +2354,10 @@ def int_amdgcn_mov_dpp :
ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>,
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, IntrNoCallback, IntrNoFree]>;

// The llvm.amdgcn.update.dpp.i32 intrinsic represents the update.dpp operation in AMDGPU.
// It takes an old value, a source operand, a DPP control operand, a row mask, a bank mask, and a bound control.
// This operation is equivalent to a sequence of v_mov_b32 operations.
// It is preferred over llvm.amdgcn.mov.dpp.i32 for future use.
// llvm.amdgcn.update.dpp.i32 <old> <src> <dpp_ctrl> <row_mask> <bank_mask> <bound_ctrl>
// Should be equivalent to:
// v_mov_b32 <dest> <old>
Expand Down
55 changes: 55 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,61 @@ def AMDGPU_RawBufferAtomicUminOp :
let hasVerifier = 1;
}

def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
"The possible permutations for a DPP operation",
[
I32EnumAttrCase<"quad_perm", 0>,
I32EnumAttrCase<"row_shl", 1>,
I32EnumAttrCase<"row_shr", 2>,
I32EnumAttrCase<"row_ror", 3>,
I32EnumAttrCase<"wave_shl", 4>,
I32EnumAttrCase<"wave_shr", 5>,
I32EnumAttrCase<"wave_ror", 6>,
I32EnumAttrCase<"wave_rol", 7>,
I32EnumAttrCase<"row_mirror", 8>,
I32EnumAttrCase<"row_half_mirror", 9>,
I32EnumAttrCase<"row_bcast_15", 10>,
I32EnumAttrCase<"row_bcast_31", 11>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::amdgpu";
}

def AMDGPU_DPPPermAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_DPPPerm,
"dpp_perm">;

def AMDGPU_DPPOp : AMDGPU_Op<"dpp", [SameTypeOperands, AllTypesMatch<["result", "old", "src"]>]>,
Arguments<(ins AnyType:$old,
AnyType:$src,
AMDGPU_DPPPermAttr:$kind,
OptionalAttr<AnyAttrOf<[I32Attr, ArrayAttr, UnitAttr]>>:$permArgument,
DefaultValuedAttr<I32Attr, "0xf">:$row_mask,
DefaultValuedAttr<I32Attr, "0xf">:$bank_mask,
DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl)> {
let summary = "AMDGPU DPP operation";
let description = [{
This operation represents DPP functionality in a GPU program.
DPP provides the following operations:
- Full crossbar in a group of four (`quad_perm`)
- Wavefront shift left by one lane (`wave_shl`)
- Wavefront shift right by one lane (`wave_shr`)
- Wavefront rotate right by one lane (`wave_ror`)
- Wavefront rotate left by one lane (`wave_rol`)
- Row shift left by 1–15 lanes (`row_shl`)
- Row shift right by 1–15 lanes (`row_shr`)
- Row rotate right by 1–15 lanes (`row_ror`)
- Reverse within a row (`row_mirror`)
- Reverse within a half-row (`row_half_mirror`)
- Broadcast the 15th lane of each row to the next row (`row_bcast`)
- Broadcast lane 31 to rows 2 and 3 (`row_bcast`)
}];
let results = (outs AnyType:$result);
let assemblyFormat = [{
$old $src $kind (`(` $permArgument^ `)`)? attr-dict `:` type($result)
}];
let hasVerifier = 1;
}

def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
let summary = "Barrier that includes a wait for LDS memory operations.";
let description = [{
Expand Down
24 changes: 24 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,30 @@ def ROCDL_RawBufferAtomicUMinOp :
let hasCustomAssemblyFormat = 1;
}

// DPP Update intrinsic
def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
[AllTypesMatch<["res", "src", "old"]>], 1>,
Arguments<(ins LLVM_Type:$old, LLVM_Type:$src, I32Attr:$dppCtrl, I32Attr:$rowMask,
I32Attr:$bankMask, I1Attr:$boundCtrl)> {
let results = (outs LLVM_Type:$res);
let assemblyFormat = [{
attr-dict $old `,` $src `with` $dppCtrl `,` $rowMask `,` $bankMask `,` $boundCtrl `:` type($src)
}];
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getSrc().getType());
llvm::Value *args[] = {
moduleTranslation.lookupValue(op.getOld()),
moduleTranslation.lookupValue(op.getSrc()),
builder.getInt32(op.getDppCtrl()),
builder.getInt32(op.getRowMask()),
builder.getInt32(op.getBankMask()),
builder.getInt1(op.getBoundCtrl())
};
$res = createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
}];
}

//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
Expand Down
156 changes: 153 additions & 3 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,155 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
return success();
}

// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
AMDGPUDPPLowering(LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
Chipset chipset;

LogicalResult
matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// Convert the source operand to the corresponding LLVM type
Location loc = DppOp.getLoc();
Value src = adaptor.getSrc();
Value old = adaptor.getOld();
Type srcType = src.getType();
Type oldType = old.getType();
Type llvmType = nullptr;
if (srcType.getIntOrFloatBitWidth() < 32) {
llvmType = rewriter.getI32Type();
} else if (isa<FloatType>(srcType)) {
llvmType = (srcType.getIntOrFloatBitWidth() == 32)
? rewriter.getF32Type()
: rewriter.getF64Type();
} else if (isa<IntegerType>(srcType)) {
llvmType = (srcType.getIntOrFloatBitWidth() == 32)
? rewriter.getI32Type()
: rewriter.getI64Type();
}
auto llvmSrcIntType = typeConverter->convertType(
rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

// If the source type is less of 32, use bitcast to convert it to i32.
auto convertOperand = [&](Value operand, Type operandType) {
if (operandType.getIntOrFloatBitWidth() <= 16) {
if (llvm::isa<FloatType>(operandType)) {
operand =
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
}
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
}
return operand;
};

src = convertOperand(src, srcType);
old = convertOperand(old, oldType);

// This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
enum DppCtrl : unsigned {
ROW_SHL0 = 0x100,
ROW_SHR0 = 0x110,
ROW_ROR0 = 0x120,
WAVE_SHL1 = 0x130,
WAVE_ROL1 = 0x134,
WAVE_SHR1 = 0x138,
WAVE_ROR1 = 0x13C,
ROW_MIRROR = 0x140,
ROW_HALF_MIRROR = 0x141,
BCAST15 = 0x142,
BCAST31 = 0x143,
};

auto kind = DppOp.getKind();
auto permArgument = DppOp.getPermArgument();
uint32_t DppCtrl = 0;

switch (kind) {

case DPPPerm::quad_perm:
if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
int32_t i = 0;
for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
uint32_t num = elem.getInt();
DppCtrl |= num << (i * 2);
i++;
}
}
break;
case DPPPerm::row_shl:
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
}
break;
case DPPPerm::row_shr:
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
}
break;
case DPPPerm::row_ror:
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
}
break;
case DPPPerm::wave_shl:
DppCtrl = DppCtrl::WAVE_SHL1;
break;
case DPPPerm::wave_shr:
DppCtrl = DppCtrl::WAVE_SHR1;
break;
case DPPPerm::wave_rol:
DppCtrl = DppCtrl::WAVE_ROL1;
break;
case DPPPerm::wave_ror:
DppCtrl = DppCtrl::WAVE_ROR1;
break;
case DPPPerm::row_mirror:
DppCtrl = DppCtrl::ROW_MIRROR;
break;
case DPPPerm::row_half_mirror:
DppCtrl = DppCtrl::ROW_HALF_MIRROR;
break;
case DPPPerm::row_bcast_15:
DppCtrl = DppCtrl::BCAST15;
break;
case DPPPerm::row_bcast_31:
DppCtrl = DppCtrl::BCAST31;
break;
}

// Check for row_mask, bank_mask, bound_ctrl if they exist and create
// constants
auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

// create a ROCDL_DPPMovOp instruction with the appropriate attributes
auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

Value result = dppMovOp.getRes();
if (srcType.getIntOrFloatBitWidth() < 32) {
result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
if (!llvm::isa<IntegerType>(srcType)) {
result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
}
}

// We are replacing the AMDGPU_DPPOp instruction with the new
// ROCDL_DPPMovOp instruction
rewriter.replaceOp(DppOp, ValueRange(result));
return success();
}
};

struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
ConvertAMDGPUToROCDLPass() = default;
Expand Down Expand Up @@ -895,9 +1044,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicUminOp>,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering,
WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering>(converter, chipset);
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
chipset);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
Expand Down
62 changes: 62 additions & 0 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,68 @@ LogicalResult MFMAOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// DPPOp
//===----------------------------------------------------------------------===//
LogicalResult DPPOp::verify() {
Type srcType = getSrc().getType();
if (srcType.getIntOrFloatBitWidth() > 64) {
return emitOpError("integer and floating point types larger than 64 bits "
"are not supported");
}

DPPPerm kind = getKind();
Attribute permArgument = getPermArgument().value_or(Attribute{});

switch (kind) {

case DPPPerm::quad_perm: {
auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
if (!quadPermAttr || quadPermAttr.size() != 4) {
return emitOpError("quad_perm attribute must have exactly 4 elements");
}
for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
uint32_t num = elem.getInt();
if (num < 0 || num > 3) {
return emitOpError(
"Each element of quad_perm must be in the range [0, 3]");
}
}
} break;

case DPPPerm::row_shl:
case DPPPerm::row_shr:
case DPPPerm::row_ror: {
if (!permArgument) {
return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
"' value not specified");
}
if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
uint32_t attrValue = intAttr.getInt();
if (attrValue < 1 || attrValue > 15) {
return emitOpError("Attribute value must be between 1 and 15");
}
}
} break;

case DPPPerm::wave_shl:
case DPPPerm::wave_shr:
case DPPPerm::wave_rol:
case DPPPerm::wave_ror:
case DPPPerm::row_mirror:
case DPPPerm::row_half_mirror:
case DPPPerm::row_bcast_15:
case DPPPerm::row_bcast_31: {
if (permArgument && !isa<UnitAttr>(permArgument)) {
return emitOpError("Expected unit attribute for permArgument, but found "
"non-trivial argument");
}
break;
}
}
return success();
}

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
Expand Down
Loading
Loading