Skip to content

Commit bfa501b

Browse files
committed
[mlir][AMDGPU] Move to new buffer resource intrinsics
The AMDGPU backend now has buffer resource intrinsics that take a ptr addrspase (8) instead of a vector<4xi32>, improving LLVM's ability to reason about their memory behavior. This commit moves MLIR to these new functions. Reviewed By: jsjodin Differential Revision: https://reviews.llvm.org/D157053
1 parent f71ba7d commit bfa501b

File tree

6 files changed

+330
-129
lines changed

6 files changed

+330
-129
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,34 @@ def LLVM_AnyFloat : Type<
5555
def LLVM_AnyPointer : Type<CPred<"::llvm::isa<::mlir::LLVM::LLVMPointerType>($_self)">,
5656
"LLVM pointer type", "::mlir::LLVM::LLVMPointerType">;
5757

58+
def LLVM_OpaquePointer : Type<
59+
And<[LLVM_AnyPointer.predicate,
60+
CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self).isOpaque()">]>,
61+
"LLVM opaque pointer", "::mlir::LLVM::LLVMPointerType">;
62+
5863
// Type constraint accepting LLVM pointer type with an additional constraint
5964
// on the element type.
6065
class LLVM_PointerTo<Type pointee> : Type<
6166
And<[LLVM_AnyPointer.predicate,
62-
Or<[CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self).isOpaque()">,
67+
Or<[LLVM_OpaquePointer.predicate,
6368
SubstLeaves<
6469
"$_self",
6570
"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self).getElementType()",
6671
pointee.predicate>]>]>,
6772
"LLVM pointer to " # pointee.summary, "::mlir::LLVM::LLVMPointerType">;
6873

74+
// Opaque pointer in a given address space.
75+
class LLVM_OpaquePointerInAddressSpace<int addressSpace> : Type<
76+
And<[LLVM_OpaquePointer.predicate,
77+
CPred<
78+
"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self).getAddressSpace() == "
79+
# addressSpace>]>,
80+
"Opaque LLVM pointer in address space " # addressSpace,
81+
"::mlir::LLVM::LLVMPointerType"> {
82+
let builderCall = "$_builder.getType<::mlir::LLVM::LLVMPointerType>("
83+
# addressSpace # ")";
84+
}
85+
6986
// Type constraints accepting LLVM pointer type to integer of a specific width.
7087
class LLVM_IntPtrBase<int width, int addressSpace = 0> : Type<
7188
And<[LLVM_PointerTo<I<width>>.predicate,

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ class ROCDL_IntrPure1Op<string mnemonic> :
7171
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
7272
"amdgcn_" # !subst(".", "_", mnemonic), [], [], [Pure], 1>;
7373

74+
class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
75+
list<int> overloadedOperands, list<Trait> traits, int numResults,
76+
int requiresAccessGroup = 0, int requiresAliasAnalysis = 0> :
77+
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
78+
"amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
79+
overloadedOperands, traits, numResults, requiresAccessGroup,
80+
requiresAliasAnalysis>;
81+
7482
//===----------------------------------------------------------------------===//
7583
// ROCDL special register op definitions
7684
//===----------------------------------------------------------------------===//
@@ -262,7 +270,96 @@ def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16">
262270
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8">;
263271
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4">;
264272

273+
//===---------------------------------------------------------------------===//
274+
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
275+
// raw buffer mode).
276+
//===---------------------------------------------------------------------===//
277+
278+
def ROCDLBufferRsrc : LLVM_OpaquePointerInAddressSpace<8>;
279+
280+
def ROCDL_MakeBufferRsrcOp :
281+
ROCDL_IntrOp<"make.buffer.rsrc", [], [0], [Pure], 1>,
282+
Arguments<(ins LLVM_AnyPointer:$base,
283+
I16:$stride,
284+
I32:$numRecords,
285+
I32:$flags)> {
286+
let results = (outs ROCDLBufferRsrc:$res);
287+
let assemblyFormat = "operands attr-dict `:` type($base) `to` type($res)";
288+
}
289+
290+
def ROCDL_RawPtrBufferLoadOp :
291+
ROCDL_IntrOp<"raw.ptr.buffer.load", [0], [], [], 1, 0, 1> {
292+
dag args = (ins Arg<ROCDLBufferRsrc, "", [MemRead]>:$rsrc,
293+
I32:$offset,
294+
I32:$soffset,
295+
I32:$aux);
296+
let arguments = !con(args, aliasAttrs);
297+
let assemblyFormat = "operands attr-dict `:` type($res)";
298+
let extraClassDefinition = [{
299+
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
300+
return {getRes()};
301+
}
302+
}];
303+
}
304+
305+
def ROCDL_RawPtrBufferStoreOp :
306+
ROCDL_IntrOp<"raw.ptr.buffer.store", [], [0], [], 0, 0, 1> {
307+
dag args = (ins LLVM_Type:$vdata,
308+
Arg<ROCDLBufferRsrc, "", [MemWrite]>:$rsrc,
309+
I32:$offset,
310+
I32:$soffset,
311+
I32:$aux);
312+
let arguments = !con(args, aliasAttrs);
313+
let assemblyFormat = "operands attr-dict `:` type($vdata)";
314+
let extraClassDefinition = [{
315+
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
316+
return {getRsrc()};
317+
}
318+
}];
319+
320+
}
321+
322+
def ROCDL_RawPtrBufferAtomicCmpSwap :
323+
ROCDL_IntrOp<"raw.ptr.buffer.atomic.cmpswap",
324+
[0], [], [AllTypesMatch<["res", "src", "cmp"]>], 1, 0, 1> {
325+
dag args = (ins LLVM_Type:$src,
326+
LLVM_Type:$cmp,
327+
Arg<ROCDLBufferRsrc, "", [MemRead, MemWrite]>:$rsrc,
328+
I32:$offset,
329+
I32:$soffset,
330+
I32:$aux);
331+
let arguments = !con(args, aliasAttrs);
332+
let assemblyFormat = "operands attr-dict `:` type($res)";
333+
let extraClassDefinition = [{
334+
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
335+
return {getRsrc()};
336+
}
337+
}];
338+
}
339+
340+
class ROCDL_RawPtrBufferAtomicNoRet<string op> :
341+
ROCDL_IntrOp<"raw.ptr.buffer.atomic." # op, [], [0], [], 0, 0, 1> {
342+
dag args = (ins LLVM_Type:$vdata,
343+
Arg<ROCDLBufferRsrc, "", [MemRead, MemWrite]>:$rsrc,
344+
I32:$offset,
345+
I32:$soffset,
346+
I32:$aux);
347+
let arguments = !con(args, aliasAttrs);
348+
let assemblyFormat = "operands attr-dict `:` type($vdata)";
349+
let extraClassDefinition = [{
350+
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
351+
return {getRsrc()};
352+
}
353+
}];
354+
}
355+
356+
def ROCDL_RawPtrBufferAtomicFmaxOp : ROCDL_RawPtrBufferAtomicNoRet<"fmax">;
357+
def ROCDL_RawPtrBufferAtomicSmaxOp : ROCDL_RawPtrBufferAtomicNoRet<"smax">;
358+
def ROCDL_RawPtrBufferAtomicUminOp : ROCDL_RawPtrBufferAtomicNoRet<"umin">;
359+
// Note: not supported on all architectures
360+
def ROCDL_RawPtrBufferAtomicFaddOp : ROCDL_RawPtrBufferAtomicNoRet<"fadd">;
265361

362+
/// LEGACY BUFFER OPERATIONS. DO NOT USE IN NEW CODE. KEPT FOR IR COMPATIBILITY.
266363
//===---------------------------------------------------------------------===//
267364
// Vector buffer load/store intrinsics
268365

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 28 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
5959
MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
6060

6161
if (chipset.majorVersion < 9)
62-
return gpuOp.emitOpError("Raw buffer ops require GCN or higher");
62+
return gpuOp.emitOpError("raw buffer ops require GCN or higher");
6363

6464
Value storeData = adaptor.getODSOperands(0)[0];
6565
if (storeData == memref) // no write component to this op
@@ -82,6 +82,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
8282

8383
Type i32 = rewriter.getI32Type();
8484
Type llvmI32 = this->typeConverter->convertType(i32);
85+
Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());
8586

8687
int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
8788
Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
@@ -156,41 +157,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
156157
if (failed(getStridesAndOffset(memrefType, strides, offset)))
157158
return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
158159

159-
// Resource descriptor
160-
// bits 0-47: base address
161-
// bits 48-61: stride (0 for raw buffers)
162-
// bit 62: texture cache coherency (always 0)
163-
// bit 63: enable swizzles (always off for raw buffers)
164-
// bits 64-95 (word 2): Number of records, units of stride
165-
// bits 96-127 (word 3): See below
166-
167-
Type llvm4xI32 = this->typeConverter->convertType(VectorType::get(4, i32));
168160
MemRefDescriptor memrefDescriptor(memref);
169-
Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
170-
Value c32I64 = rewriter.create<LLVM::ConstantOp>(
171-
loc, llvmI64, rewriter.getI64IntegerAttr(32));
172-
173-
Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);
174161

175162
Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
176-
Value ptrAsInt = rewriter.create<LLVM::PtrToIntOp>(loc, llvmI64, ptr);
177-
Value lowHalf = rewriter.create<LLVM::TruncOp>(loc, llvmI32, ptrAsInt);
178-
resource = rewriter.create<LLVM::InsertElementOp>(
179-
loc, llvm4xI32, resource, lowHalf,
180-
this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 0));
181-
182-
// Bits 48-63 are used both for the stride of the buffer and (on gfx10) for
183-
// enabling swizzling. Prevent the high bits of pointers from accidentally
184-
// setting those flags.
185-
Value highHalfShifted = rewriter.create<LLVM::TruncOp>(
186-
loc, llvmI32, rewriter.create<LLVM::LShrOp>(loc, ptrAsInt, c32I64));
187-
Value highHalfTruncated = rewriter.create<LLVM::AndOp>(
188-
loc, llvmI32, highHalfShifted,
189-
createI32Constant(rewriter, loc, 0x0000ffff));
190-
resource = rewriter.create<LLVM::InsertElementOp>(
191-
loc, llvm4xI32, resource, highHalfTruncated,
192-
this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 1));
193-
163+
// The stride value is always 0 for raw buffers. This also disables
164+
// swizling.
165+
Value stride = rewriter.createOrFold<LLVM::ConstantOp>(
166+
loc, llvmI16, rewriter.getI16IntegerAttr(0));
194167
Value numRecords;
195168
if (memrefType.hasStaticShape()) {
196169
numRecords = createI32Constant(
@@ -209,11 +182,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
209182
}
210183
numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
211184
}
212-
resource = rewriter.create<LLVM::InsertElementOp>(
213-
loc, llvm4xI32, resource, numRecords,
214-
this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 2));
215185

216-
// Final word:
186+
// Flag word:
217187
// bits 0-11: dst sel, ignored by these intrinsics
218188
// bits 12-14: data format (ignored, must be nonzero, 7=float)
219189
// bits 15-18: data format (ignored, must be nonzero, 4=32bit)
@@ -227,16 +197,16 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
227197
// bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
228198
// none, 3 = either swizzles or testing against offset field) RDNA only
229199
// bits 30-31: Type (must be 0)
230-
uint32_t word3 = (7 << 12) | (4 << 15);
200+
uint32_t flags = (7 << 12) | (4 << 15);
231201
if (chipset.majorVersion >= 10) {
232-
word3 |= (1 << 24);
202+
flags |= (1 << 24);
233203
uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
234-
word3 |= (oob << 28);
204+
flags |= (oob << 28);
235205
}
236-
Value word3Const = createI32Constant(rewriter, loc, word3);
237-
resource = rewriter.create<LLVM::InsertElementOp>(
238-
loc, llvm4xI32, resource, word3Const,
239-
this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 3));
206+
Value flagsConst = createI32Constant(rewriter, loc, flags);
207+
Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
208+
Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
209+
loc, rsrcType, ptr, stride, numRecords, flagsConst);
240210
args.push_back(resource);
241211

242212
// Indexing (voffset)
@@ -708,16 +678,20 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
708678
});
709679

710680
patterns.add<LDSBarrierOpLowering>(converter);
711-
patterns.add<
712-
RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
713-
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
714-
RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>,
715-
RawBufferOpLowering<RawBufferAtomicFmaxOp, ROCDL::RawBufferAtomicFMaxOp>,
716-
RawBufferOpLowering<RawBufferAtomicSmaxOp, ROCDL::RawBufferAtomicSMaxOp>,
717-
RawBufferOpLowering<RawBufferAtomicUminOp, ROCDL::RawBufferAtomicUMinOp>,
718-
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
719-
ROCDL::RawBufferAtomicCmpSwap>,
720-
MFMAOpLowering, WMMAOpLowering>(converter, chipset);
681+
patterns
682+
.add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
683+
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
684+
RawBufferOpLowering<RawBufferAtomicFaddOp,
685+
ROCDL::RawPtrBufferAtomicFaddOp>,
686+
RawBufferOpLowering<RawBufferAtomicFmaxOp,
687+
ROCDL::RawPtrBufferAtomicFmaxOp>,
688+
RawBufferOpLowering<RawBufferAtomicSmaxOp,
689+
ROCDL::RawPtrBufferAtomicSmaxOp>,
690+
RawBufferOpLowering<RawBufferAtomicUminOp,
691+
ROCDL::RawPtrBufferAtomicUminOp>,
692+
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
693+
ROCDL::RawPtrBufferAtomicCmpSwap>,
694+
MFMAOpLowering, WMMAOpLowering>(converter, chipset);
721695
}
722696

723697
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {

0 commit comments

Comments
 (0)