Skip to content

[mlir][LLVM] Refactor how range() annotations are handled for ROCDL intrinsics #107658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,40 @@ def LLVM_TBAATagArrayAttr
let constBuilderCall = ?;
}

//===----------------------------------------------------------------------===//
// ConstantRangeAttr
//===----------------------------------------------------------------------===//
def LLVM_ConstantRangeAttr : LLVM_Attr<"ConstantRange", "constant_range"> {
let parameters = (ins
"::llvm::APInt":$lower,
"::llvm::APInt":$upper
);
let summary = "A range of two integers, corresponding to LLVM's ConstantRange";
let description = [{
A pair of two integers, mapping to the ConstantRange structure in LLVM IR,
which is allowed to wrap or be empty.

The range represented is [Lower, Upper), and is either signed or unsigned
depending on context.

`lower` and `upper` must have the same width.

Syntax:
```
`<` `i`(width($lower)) $lower `,` $upper `>`
}];

let builders = [
AttrBuilder<(ins "uint32_t":$bitWidth, "int64_t":$lower, "int64_t":$upper), [{
return $_get($_ctxt, ::llvm::APInt(bitWidth, lower), ::llvm::APInt(bitWidth, upper));
}]>
];

let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some round-trip tests in the LLVM dialect test folder? Exercising various of the supported bitwidth.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sadly, we don't currently have the roundtrip tests set up for ROCDL at all ... should I add them in this PR? Or should I pick some intrinsic as a target for these annotations?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can also say that we do have import tests in the NVVM PR that's stacked on top of this one



//===----------------------------------------------------------------------===//
// VScaleRangeAttr
//===----------------------------------------------------------------------===//
Expand Down
37 changes: 31 additions & 6 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -319,17 +319,19 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
string immArgPositionsCpp = "{" # !interleave(immArgPositions, ", ") # "}";
string immArgAttrNamesCpp = "{" # !interleave(!foreach(name, immArgAttrNames,
"StringLiteral(\"" # name # "\")"), ", ") # "}";
let llvmBuilder = [{
string baseLlvmBuilder = [{
auto *inst = LLVM::detail::createIntrinsicCall(
builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
(void) inst;
}] # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "")
}];
string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
let llvmBuilder = baseLlvmBuilder # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "")
# !if(!gt(requiresAliasAnalysis, 0), setAliasAnalysisMetadataCode, "")
# !if(!gt(numResults, 0), "$res = inst;", "");
# baseLlvmBuilderCoda;

string mlirBuilder = [{
string baseMlirBuilder = [{
SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
Expand All @@ -345,9 +347,32 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
}] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{
auto op = $_builder.create<$_qualCppClassName>(
$_location, resultTypes, mlirOperands, mlirAttrs);
}] # !if(!gt(requiresFastmath, 0),
}];
string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
"moduleImport.setFastmathFlagsAttr(inst, op);", "")
# !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
# baseMlirBuilderCoda;

// Code for handling a `range` attribute that holds the constant range of the
// intrinsic's result (if one is specified at the call site). This is intended
// for GPU IDs and other calls where range() is meaningful. It expects
// an optional LLVM_ConstantRangeAttr named `range` to be present on the
// operation. These are included to abstract out common code in several
// dialects.
string setRangeRetAttrCode = [{
if ($range) {
inst->addRangeRetAttr(::llvm::ConstantRange(
$range->getLower(), $range->getUpper()));
}
}];
string importRangeRetAttrCode = [{
// Note: we don't want to look in to the declaration here.
auto rangeAttr = inst->getAttributes().getRetAttr(::llvm::Attribute::Range);
if (rangeAttr.isValid()) {
const ::llvm::ConstantRange& value = rangeAttr.getValueAsConstantRange();
op.setRangeAttr(::mlir::LLVM::ConstantRangeAttr::get($_builder.getContext(), value.getLower(), value.getUpper()));
}
}];
}

// Base class for LLVM intrinsic operations, should not be used directly. Places
Expand Down
61 changes: 37 additions & 24 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,36 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
// ROCDL special register op definitions
//===----------------------------------------------------------------------===//

class ROCDL_SpecialRegisterOp<string mnemonic,
list<Trait> traits = []> :
ROCDL_Op<mnemonic, !listconcat(traits, [Pure])>,
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
string llvmBuilder = "$res = createIntrinsicCallWithRange(builder,"
# "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic)
# ", op->getAttrOfType<::mlir::DenseI32ArrayAttr>(\"range\"));";
let assemblyFormat = "attr-dict `:` type($res)";
class ROCDL_SpecialIdRegisterOp<string mnemonic> :
ROCDL_IntrPure1Op<mnemonic>,
Arguments<(ins OptionalAttr<LLVM_ConstantRangeAttr>:$range)> {
string llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
string mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda;

let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";

// Temporaly builder until Nvidia ops also support range attributes.
let builders = [
OpBuilder<(ins "Type":$resultType), [{
build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
}]>
];
}

class ROCDL_DeviceFunctionOp<string mnemonic, string device_function,
class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
int parameter, list<Trait> traits = []> :
ROCDL_Op<mnemonic, !listconcat(traits, [Pure])>,
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
string llvmBuilder = "$res = createDeviceFunctionCall(builder, \""
Results<(outs LLVM_Type:$res)>, Arguments<(ins OptionalAttr<LLVM_ConstantRangeAttr>:$range)> {
string llvmBuilder = "$res = createDimGetterFunctionCall(builder, op, \""
# device_function # "\", " # parameter # ");";
let assemblyFormat = "attr-dict `:` type($res)";
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";

// Temporaly builder until Nvidia ops also support range attributes.
let builders = [
OpBuilder<(ins "Type":$resultType), [{
build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
}]>
];
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -181,33 +194,33 @@ def ROCDL_BallotOp :
//===----------------------------------------------------------------------===//
// Thread index and Block index

def ROCDL_ThreadIdXOp : ROCDL_SpecialRegisterOp<"workitem.id.x">;
def ROCDL_ThreadIdYOp : ROCDL_SpecialRegisterOp<"workitem.id.y">;
def ROCDL_ThreadIdZOp : ROCDL_SpecialRegisterOp<"workitem.id.z">;
def ROCDL_ThreadIdXOp : ROCDL_SpecialIdRegisterOp<"workitem.id.x">;
def ROCDL_ThreadIdYOp : ROCDL_SpecialIdRegisterOp<"workitem.id.y">;
def ROCDL_ThreadIdZOp : ROCDL_SpecialIdRegisterOp<"workitem.id.z">;

def ROCDL_BlockIdXOp : ROCDL_SpecialRegisterOp<"workgroup.id.x">;
def ROCDL_BlockIdYOp : ROCDL_SpecialRegisterOp<"workgroup.id.y">;
def ROCDL_BlockIdZOp : ROCDL_SpecialRegisterOp<"workgroup.id.z">;
def ROCDL_BlockIdXOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.x">;
def ROCDL_BlockIdYOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.y">;
def ROCDL_BlockIdZOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.z">;

//===----------------------------------------------------------------------===//
// Thread range and Block range

def ROCDL_BlockDimXOp : ROCDL_DeviceFunctionOp<"workgroup.dim.x",
def ROCDL_BlockDimXOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.x",
"__ockl_get_local_size", 0>;

def ROCDL_BlockDimYOp : ROCDL_DeviceFunctionOp<"workgroup.dim.y",
def ROCDL_BlockDimYOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.y",
"__ockl_get_local_size", 1>;

def ROCDL_BlockDimZOp : ROCDL_DeviceFunctionOp<"workgroup.dim.z",
def ROCDL_BlockDimZOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.z",
"__ockl_get_local_size", 2>;

def ROCDL_GridDimXOp : ROCDL_DeviceFunctionOp<"grid.dim.x",
def ROCDL_GridDimXOp : ROCDL_DimGetterFunctionOp<"grid.dim.x",
"__ockl_get_num_groups", 0>;

def ROCDL_GridDimYOp : ROCDL_DeviceFunctionOp<"grid.dim.y",
def ROCDL_GridDimYOp : ROCDL_DimGetterFunctionOp<"grid.dim.y",
"__ockl_get_num_groups", 1>;

def ROCDL_GridDimZOp : ROCDL_DeviceFunctionOp<"grid.dim.z",
def ROCDL_GridDimZOp : ROCDL_DimGetterFunctionOp<"grid.dim.z",
"__ockl_get_num_groups", 2>;

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {

if (upperBound && intrType != IntrType::None) {
int32_t min = (intrType == IntrType::Dim ? 1 : 0);
int32_t max = *upperBound - (intrType == IntrType::Id ? 0 : 1);
newOp->setAttr(
"range", DenseI32ArrayAttr::get(op.getContext(), ArrayRef{min, max}));
int32_t max = *upperBound + (intrType == IntrType::Id ? 0 : 1);
newOp->setAttr("range", LLVM::ConstantRangeAttr::get(
rewriter.getContext(), 32, min, max));
}
if (indexBitwidth > 32) {
newOp = rewriter.create<LLVM::SExtOp>(
Expand Down
41 changes: 41 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,47 @@ DIRecursiveTypeAttrInterface DISubprogramAttr::getRecSelf(DistinctAttr recId) {
{}, {}, {}, {}, {}, 0, 0, {}, {}, {}, {});
}

//===----------------------------------------------------------------------===//
// ConstantRangeAttr
//===----------------------------------------------------------------------===//

Attribute ConstantRangeAttr::parse(AsmParser &parser, Type odsType) {
llvm::SMLoc loc = parser.getCurrentLocation();
IntegerType widthType;
if (parser.parseLess() || parser.parseType(widthType) ||
parser.parseComma()) {
return Attribute{};
}
unsigned bitWidth = widthType.getWidth();
APInt lower(bitWidth, 0);
APInt upper(bitWidth, 0);
if (parser.parseInteger(lower) || parser.parseComma() ||
parser.parseInteger(upper) || parser.parseGreater())
return Attribute{};
// For some reason, 0 is always parsed as 64-bits, fix that if needed.
if (lower.isZero())
lower = lower.sextOrTrunc(bitWidth);
if (upper.isZero())
upper = upper.sextOrTrunc(bitWidth);
return parser.getChecked<ConstantRangeAttr>(loc, parser.getContext(), lower,
upper);
}

void ConstantRangeAttr::print(AsmPrinter &printer) const {
printer << "<i" << getLower().getBitWidth() << ", " << getLower() << ", "
<< getUpper() << ">";
}

LogicalResult
ConstantRangeAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
APInt lower, APInt upper) {
if (lower.getBitWidth() != upper.getBitWidth())
return emitError()
<< "expected lower and upper to have matching bitwidths but got "
<< lower.getBitWidth() << " vs. " << upper.getBitWidth();
return success();
}

//===----------------------------------------------------------------------===//
// TargetFeaturesAttr
//===----------------------------------------------------------------------===//
Expand Down
35 changes: 15 additions & 20 deletions mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,13 @@ using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;

static llvm::Value *createIntrinsicCallWithRange(llvm::IRBuilderBase &builder,
llvm::Intrinsic::ID intrinsic,
DenseI32ArrayAttr maybeRange) {
auto *inst = llvm::cast<llvm::CallInst>(
createIntrinsicCall(builder, intrinsic, {}, {}));
if (maybeRange) {
llvm::ConstantRange Range(APInt(32, maybeRange[0]),
APInt(32, maybeRange[1]));
inst->addRangeRetAttr(Range);
}
return inst;
}

// Create a call to ROCm-Device-Library function
// Currently this routine will work only for calling ROCDL functions that
// take a single int32 argument. It is likely that the interface of this
// function will change to make it more generic.
static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder,
StringRef fnName, int parameter) {
// Create a call to ROCm-Device-Library function that returns an ID.
// This is intended to specifically call device functions that fetch things like
// block or grid dimensions, and so is limited to functions that take one
// integer parameter.
static llvm::Value *createDimGetterFunctionCall(llvm::IRBuilderBase &builder,
Operation *op, StringRef fnName,
int parameter) {
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::FunctionType *functionType = llvm::FunctionType::get(
llvm::Type::getInt64Ty(module->getContext()), // return type.
Expand All @@ -54,7 +42,14 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder,
module->getOrInsertFunction(fnName, functionType).getCallee());
llvm::Value *fnOp0 = llvm::ConstantInt::get(
llvm::Type::getInt32Ty(module->getContext()), parameter);
return builder.CreateCall(fn, ArrayRef<llvm::Value *>(fnOp0));
auto *call = builder.CreateCall(fn, ArrayRef<llvm::Value *>(fnOp0));
if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
// Zero-extend to 64 bits because the GPU dialect uses 32-bit bounds but
// these ockl functions are defined to be 64-bits
call->addRangeRetAttr(llvm::ConstantRange(rangeAttr.getLower().zext(64),
rangeAttr.getUpper().zext(64)));
}
return call;
}

namespace {
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,18 @@ gpu.module @test_module {
{known_block_size = array<i32: 8, 12, 16>,
known_grid_size = array<i32: 20, 24, 28>} {

// CHECK: rocdl.workitem.id.x {range = array<i32: 0, 8>} : i32
// CHECK: rocdl.workitem.id.x range <i32, 0, 8> : i32
%tIdX = gpu.thread_id x
// CHECK: rocdl.workitem.id.y {range = array<i32: 0, 12>} : i32
// CHECK: rocdl.workitem.id.y range <i32, 0, 12> : i32
%tIdY = gpu.thread_id y
// CHECK: rocdl.workitem.id.z {range = array<i32: 0, 16>} : i32
// CHECK: rocdl.workitem.id.z range <i32, 0, 16> : i32
%tIdZ = gpu.thread_id z

// CHECK: rocdl.workgroup.id.x {range = array<i32: 0, 20>} : i32
// CHECK: rocdl.workgroup.id.x range <i32, 0, 20> : i32
%bIdX = gpu.block_id x
// CHECK: rocdl.workgroup.id.y {range = array<i32: 0, 24>} : i32
// CHECK: rocdl.workgroup.id.y range <i32, 0, 24> : i32
%bIdY = gpu.block_id y
// CHECK: rocdl.workgroup.id.z {range = array<i32: 0, 28>} : i32
// CHECK: rocdl.workgroup.id.z range <i32, 0, 28> : i32
%bIdZ = gpu.block_id z

// "Usage" to make the ID calls not die
Expand Down
4 changes: 3 additions & 1 deletion mlir/test/Target/LLVMIR/rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ llvm.func @rocdl_special_regs() -> i32 {
%12 = rocdl.grid.dim.z : i64

// CHECK: call range(i32 0, 64) i32 @llvm.amdgcn.workitem.id.x()
%13 = rocdl.workitem.id.x {range = array<i32: 0, 64>} : i32
%13 = rocdl.workitem.id.x range <i32, 0, 64> : i32

// CHECK: call range(i64 1, 65) i64 @__ockl_get_local_size(i32 0)
%14 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64
llvm.return %1 : i32
}

Expand Down
Loading