Skip to content

Commit 90a0be9

Browse files
authored
[mlir][LLVM] Refactor how range() annotations are handled for ROCDL intrinsics (#107658)
This commit introduces a ConstantRange attribute to match the ConstantRange attribute type present in LLVM IR. It then refactors the LLVM_IntrOpBase so that the basic part of the intrinsic builder code can be re-used without needing to copy it or get rid of important context. This, along with adding code for handling an optional `range` attribute to that same base, allows us to make the support for range() annotations generic without adding another bit to IntrOpBase. This commit then updates the lowering of index intrinsic operations to use the new ConstantRange attribute and fixes a bug (where we'd be subtracting 1 from upper bounds instead of adding it on operations like gpu.block_dim) along the way. The point of these changes is to enable these range annotations to be used for the corresponding NVVM operations in a future commit.
1 parent a409ebc commit 90a0be9

File tree

8 files changed

+170
-60
lines changed

8 files changed

+170
-60
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,40 @@ def LLVM_TBAATagArrayAttr
10341034
let constBuilderCall = ?;
10351035
}
10361036

1037+
//===----------------------------------------------------------------------===//
1038+
// ConstantRangeAttr
1039+
//===----------------------------------------------------------------------===//
1040+
def LLVM_ConstantRangeAttr : LLVM_Attr<"ConstantRange", "constant_range"> {
1041+
let parameters = (ins
1042+
"::llvm::APInt":$lower,
1043+
"::llvm::APInt":$upper
1044+
);
1045+
let summary = "A range of two integers, corresponding to LLVM's ConstantRange";
1046+
let description = [{
1047+
A pair of two integers, mapping to the ConstantRange structure in LLVM IR,
1048+
which is allowed to wrap or be empty.
1049+
1050+
The range represented is [Lower, Upper), and is either signed or unsigned
1051+
depending on context.
1052+
1053+
`lower` and `upper` must have the same width.
1054+
1055+
Syntax:
1056+
```
1057+
`<` `i`(width($lower)) $lower `,` $upper `>`
1058+
}];
1059+
1060+
let builders = [
1061+
AttrBuilder<(ins "uint32_t":$bitWidth, "int64_t":$lower, "int64_t":$upper), [{
1062+
return $_get($_ctxt, ::llvm::APInt(bitWidth, lower), ::llvm::APInt(bitWidth, upper));
1063+
}]>
1064+
];
1065+
1066+
let hasCustomAssemblyFormat = 1;
1067+
let genVerifyDecl = 1;
1068+
}
1069+
1070+
10371071
//===----------------------------------------------------------------------===//
10381072
// VScaleRangeAttr
10391073
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,17 +319,19 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
319319
string immArgPositionsCpp = "{" # !interleave(immArgPositions, ", ") # "}";
320320
string immArgAttrNamesCpp = "{" # !interleave(!foreach(name, immArgAttrNames,
321321
"StringLiteral(\"" # name # "\")"), ", ") # "}";
322-
let llvmBuilder = [{
322+
string baseLlvmBuilder = [{
323323
auto *inst = LLVM::detail::createIntrinsicCall(
324324
builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
325325
enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
326326
immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
327327
(void) inst;
328-
}] # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "")
328+
}];
329+
string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
330+
let llvmBuilder = baseLlvmBuilder # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "")
329331
# !if(!gt(requiresAliasAnalysis, 0), setAliasAnalysisMetadataCode, "")
330-
# !if(!gt(numResults, 0), "$res = inst;", "");
332+
# baseLlvmBuilderCoda;
331333

332-
string mlirBuilder = [{
334+
string baseMlirBuilder = [{
333335
SmallVector<Value> mlirOperands;
334336
SmallVector<NamedAttribute> mlirAttrs;
335337
if (failed(moduleImport.convertIntrinsicArguments(
@@ -345,9 +347,32 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
345347
}] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{
346348
auto op = $_builder.create<$_qualCppClassName>(
347349
$_location, resultTypes, mlirOperands, mlirAttrs);
348-
}] # !if(!gt(requiresFastmath, 0),
350+
}];
351+
string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
352+
let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
349353
"moduleImport.setFastmathFlagsAttr(inst, op);", "")
350-
# !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
354+
# baseMlirBuilderCoda;
355+
356+
// Code for handling a `range` attribute that holds the constant range of the
357+
// intrinsic's result (if one is specified at the call site). This is intended
358+
// for GPU IDs and other calls where range() is meaningful. It expects
359+
// an optional LLVM_ConstantRangeAttr named `range` to be present on the
360+
// operation. These are included to abstract out common code in several
361+
// dialects.
362+
string setRangeRetAttrCode = [{
363+
if ($range) {
364+
inst->addRangeRetAttr(::llvm::ConstantRange(
365+
$range->getLower(), $range->getUpper()));
366+
}
367+
}];
368+
string importRangeRetAttrCode = [{
369+
// Note: we don't want to look in to the declaration here.
370+
auto rangeAttr = inst->getAttributes().getRetAttr(::llvm::Attribute::Range);
371+
if (rangeAttr.isValid()) {
372+
const ::llvm::ConstantRange& value = rangeAttr.getValueAsConstantRange();
373+
op.setRangeAttr(::mlir::LLVM::ConstantRangeAttr::get($_builder.getContext(), value.getLower(), value.getUpper()));
374+
}
375+
}];
351376
}
352377

353378
// Base class for LLVM intrinsic operations, should not be used directly. Places

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -98,23 +98,36 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
9898
// ROCDL special register op definitions
9999
//===----------------------------------------------------------------------===//
100100

101-
class ROCDL_SpecialRegisterOp<string mnemonic,
102-
list<Trait> traits = []> :
103-
ROCDL_Op<mnemonic, !listconcat(traits, [Pure])>,
104-
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
105-
string llvmBuilder = "$res = createIntrinsicCallWithRange(builder,"
106-
# "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic)
107-
# ", op->getAttrOfType<::mlir::DenseI32ArrayAttr>(\"range\"));";
108-
let assemblyFormat = "attr-dict `:` type($res)";
101+
class ROCDL_SpecialIdRegisterOp<string mnemonic> :
102+
ROCDL_IntrPure1Op<mnemonic>,
103+
Arguments<(ins OptionalAttr<LLVM_ConstantRangeAttr>:$range)> {
104+
string llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
105+
string mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda;
106+
107+
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
108+
109+
// Temporaly builder until Nvidia ops also support range attributes.
110+
let builders = [
111+
OpBuilder<(ins "Type":$resultType), [{
112+
build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
113+
}]>
114+
];
109115
}
110116

111-
class ROCDL_DeviceFunctionOp<string mnemonic, string device_function,
117+
class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
112118
int parameter, list<Trait> traits = []> :
113119
ROCDL_Op<mnemonic, !listconcat(traits, [Pure])>,
114-
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
115-
string llvmBuilder = "$res = createDeviceFunctionCall(builder, \""
120+
Results<(outs LLVM_Type:$res)>, Arguments<(ins OptionalAttr<LLVM_ConstantRangeAttr>:$range)> {
121+
string llvmBuilder = "$res = createDimGetterFunctionCall(builder, op, \""
116122
# device_function # "\", " # parameter # ");";
117-
let assemblyFormat = "attr-dict `:` type($res)";
123+
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
124+
125+
// Temporaly builder until Nvidia ops also support range attributes.
126+
let builders = [
127+
OpBuilder<(ins "Type":$resultType), [{
128+
build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
129+
}]>
130+
];
118131
}
119132

120133
//===----------------------------------------------------------------------===//
@@ -181,33 +194,33 @@ def ROCDL_BallotOp :
181194
//===----------------------------------------------------------------------===//
182195
// Thread index and Block index
183196

184-
def ROCDL_ThreadIdXOp : ROCDL_SpecialRegisterOp<"workitem.id.x">;
185-
def ROCDL_ThreadIdYOp : ROCDL_SpecialRegisterOp<"workitem.id.y">;
186-
def ROCDL_ThreadIdZOp : ROCDL_SpecialRegisterOp<"workitem.id.z">;
197+
def ROCDL_ThreadIdXOp : ROCDL_SpecialIdRegisterOp<"workitem.id.x">;
198+
def ROCDL_ThreadIdYOp : ROCDL_SpecialIdRegisterOp<"workitem.id.y">;
199+
def ROCDL_ThreadIdZOp : ROCDL_SpecialIdRegisterOp<"workitem.id.z">;
187200

188-
def ROCDL_BlockIdXOp : ROCDL_SpecialRegisterOp<"workgroup.id.x">;
189-
def ROCDL_BlockIdYOp : ROCDL_SpecialRegisterOp<"workgroup.id.y">;
190-
def ROCDL_BlockIdZOp : ROCDL_SpecialRegisterOp<"workgroup.id.z">;
201+
def ROCDL_BlockIdXOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.x">;
202+
def ROCDL_BlockIdYOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.y">;
203+
def ROCDL_BlockIdZOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.z">;
191204

192205
//===----------------------------------------------------------------------===//
193206
// Thread range and Block range
194207

195-
def ROCDL_BlockDimXOp : ROCDL_DeviceFunctionOp<"workgroup.dim.x",
208+
def ROCDL_BlockDimXOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.x",
196209
"__ockl_get_local_size", 0>;
197210

198-
def ROCDL_BlockDimYOp : ROCDL_DeviceFunctionOp<"workgroup.dim.y",
211+
def ROCDL_BlockDimYOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.y",
199212
"__ockl_get_local_size", 1>;
200213

201-
def ROCDL_BlockDimZOp : ROCDL_DeviceFunctionOp<"workgroup.dim.z",
214+
def ROCDL_BlockDimZOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.z",
202215
"__ockl_get_local_size", 2>;
203216

204-
def ROCDL_GridDimXOp : ROCDL_DeviceFunctionOp<"grid.dim.x",
217+
def ROCDL_GridDimXOp : ROCDL_DimGetterFunctionOp<"grid.dim.x",
205218
"__ockl_get_num_groups", 0>;
206219

207-
def ROCDL_GridDimYOp : ROCDL_DeviceFunctionOp<"grid.dim.y",
220+
def ROCDL_GridDimYOp : ROCDL_DimGetterFunctionOp<"grid.dim.y",
208221
"__ockl_get_num_groups", 1>;
209222

210-
def ROCDL_GridDimZOp : ROCDL_DeviceFunctionOp<"grid.dim.z",
223+
def ROCDL_GridDimZOp : ROCDL_DimGetterFunctionOp<"grid.dim.z",
211224
"__ockl_get_num_groups", 2>;
212225

213226
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
114114

115115
if (upperBound && intrType != IntrType::None) {
116116
int32_t min = (intrType == IntrType::Dim ? 1 : 0);
117-
int32_t max = *upperBound - (intrType == IntrType::Id ? 0 : 1);
118-
newOp->setAttr(
119-
"range", DenseI32ArrayAttr::get(op.getContext(), ArrayRef{min, max}));
117+
int32_t max = *upperBound + (intrType == IntrType::Id ? 0 : 1);
118+
newOp->setAttr("range", LLVM::ConstantRangeAttr::get(
119+
rewriter.getContext(), 32, min, max));
120120
}
121121
if (indexBitwidth > 32) {
122122
newOp = rewriter.create<LLVM::SExtOp>(

mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,47 @@ DIRecursiveTypeAttrInterface DISubprogramAttr::getRecSelf(DistinctAttr recId) {
232232
{}, {}, {}, {}, {}, 0, 0, {}, {}, {}, {});
233233
}
234234

235+
//===----------------------------------------------------------------------===//
236+
// ConstantRangeAttr
237+
//===----------------------------------------------------------------------===//
238+
239+
Attribute ConstantRangeAttr::parse(AsmParser &parser, Type odsType) {
240+
llvm::SMLoc loc = parser.getCurrentLocation();
241+
IntegerType widthType;
242+
if (parser.parseLess() || parser.parseType(widthType) ||
243+
parser.parseComma()) {
244+
return Attribute{};
245+
}
246+
unsigned bitWidth = widthType.getWidth();
247+
APInt lower(bitWidth, 0);
248+
APInt upper(bitWidth, 0);
249+
if (parser.parseInteger(lower) || parser.parseComma() ||
250+
parser.parseInteger(upper) || parser.parseGreater())
251+
return Attribute{};
252+
// For some reason, 0 is always parsed as 64-bits, fix that if needed.
253+
if (lower.isZero())
254+
lower = lower.sextOrTrunc(bitWidth);
255+
if (upper.isZero())
256+
upper = upper.sextOrTrunc(bitWidth);
257+
return parser.getChecked<ConstantRangeAttr>(loc, parser.getContext(), lower,
258+
upper);
259+
}
260+
261+
void ConstantRangeAttr::print(AsmPrinter &printer) const {
262+
printer << "<i" << getLower().getBitWidth() << ", " << getLower() << ", "
263+
<< getUpper() << ">";
264+
}
265+
266+
LogicalResult
267+
ConstantRangeAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
268+
APInt lower, APInt upper) {
269+
if (lower.getBitWidth() != upper.getBitWidth())
270+
return emitError()
271+
<< "expected lower and upper to have matching bitwidths but got "
272+
<< lower.getBitWidth() << " vs. " << upper.getBitWidth();
273+
return success();
274+
}
275+
235276
//===----------------------------------------------------------------------===//
236277
// TargetFeaturesAttr
237278
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,13 @@ using namespace mlir;
2626
using namespace mlir::LLVM;
2727
using mlir::LLVM::detail::createIntrinsicCall;
2828

29-
static llvm::Value *createIntrinsicCallWithRange(llvm::IRBuilderBase &builder,
30-
llvm::Intrinsic::ID intrinsic,
31-
DenseI32ArrayAttr maybeRange) {
32-
auto *inst = llvm::cast<llvm::CallInst>(
33-
createIntrinsicCall(builder, intrinsic, {}, {}));
34-
if (maybeRange) {
35-
llvm::ConstantRange Range(APInt(32, maybeRange[0]),
36-
APInt(32, maybeRange[1]));
37-
inst->addRangeRetAttr(Range);
38-
}
39-
return inst;
40-
}
41-
42-
// Create a call to ROCm-Device-Library function
43-
// Currently this routine will work only for calling ROCDL functions that
44-
// take a single int32 argument. It is likely that the interface of this
45-
// function will change to make it more generic.
46-
static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder,
47-
StringRef fnName, int parameter) {
29+
// Create a call to ROCm-Device-Library function that returns an ID.
30+
// This is intended to specifically call device functions that fetch things like
31+
// block or grid dimensions, and so is limited to functions that take one
32+
// integer parameter.
33+
static llvm::Value *createDimGetterFunctionCall(llvm::IRBuilderBase &builder,
34+
Operation *op, StringRef fnName,
35+
int parameter) {
4836
llvm::Module *module = builder.GetInsertBlock()->getModule();
4937
llvm::FunctionType *functionType = llvm::FunctionType::get(
5038
llvm::Type::getInt64Ty(module->getContext()), // return type.
@@ -54,7 +42,14 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder,
5442
module->getOrInsertFunction(fnName, functionType).getCallee());
5543
llvm::Value *fnOp0 = llvm::ConstantInt::get(
5644
llvm::Type::getInt32Ty(module->getContext()), parameter);
57-
return builder.CreateCall(fn, ArrayRef<llvm::Value *>(fnOp0));
45+
auto *call = builder.CreateCall(fn, ArrayRef<llvm::Value *>(fnOp0));
46+
if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
47+
// Zero-extend to 64 bits because the GPU dialect uses 32-bit bounds but
48+
// these ockl functions are defined to be 64-bits
49+
call->addRangeRetAttr(llvm::ConstantRange(rangeAttr.getLower().zext(64),
50+
rangeAttr.getUpper().zext(64)));
51+
}
52+
return call;
5853
}
5954

6055
namespace {

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,18 @@ gpu.module @test_module {
7777
{known_block_size = array<i32: 8, 12, 16>,
7878
known_grid_size = array<i32: 20, 24, 28>} {
7979

80-
// CHECK: rocdl.workitem.id.x {range = array<i32: 0, 8>} : i32
80+
// CHECK: rocdl.workitem.id.x range <i32, 0, 8> : i32
8181
%tIdX = gpu.thread_id x
82-
// CHECK: rocdl.workitem.id.y {range = array<i32: 0, 12>} : i32
82+
// CHECK: rocdl.workitem.id.y range <i32, 0, 12> : i32
8383
%tIdY = gpu.thread_id y
84-
// CHECK: rocdl.workitem.id.z {range = array<i32: 0, 16>} : i32
84+
// CHECK: rocdl.workitem.id.z range <i32, 0, 16> : i32
8585
%tIdZ = gpu.thread_id z
8686

87-
// CHECK: rocdl.workgroup.id.x {range = array<i32: 0, 20>} : i32
87+
// CHECK: rocdl.workgroup.id.x range <i32, 0, 20> : i32
8888
%bIdX = gpu.block_id x
89-
// CHECK: rocdl.workgroup.id.y {range = array<i32: 0, 24>} : i32
89+
// CHECK: rocdl.workgroup.id.y range <i32, 0, 24> : i32
9090
%bIdY = gpu.block_id y
91-
// CHECK: rocdl.workgroup.id.z {range = array<i32: 0, 28>} : i32
91+
// CHECK: rocdl.workgroup.id.z range <i32, 0, 28> : i32
9292
%bIdZ = gpu.block_id z
9393

9494
// "Usage" to make the ID calls not die

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ llvm.func @rocdl_special_regs() -> i32 {
2828
%12 = rocdl.grid.dim.z : i64
2929

3030
// CHECK: call range(i32 0, 64) i32 @llvm.amdgcn.workitem.id.x()
31-
%13 = rocdl.workitem.id.x {range = array<i32: 0, 64>} : i32
31+
%13 = rocdl.workitem.id.x range <i32, 0, 64> : i32
3232

33+
// CHECK: call range(i64 1, 65) i64 @__ockl_get_local_size(i32 0)
34+
%14 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64
3335
llvm.return %1 : i32
3436
}
3537

0 commit comments

Comments
 (0)