@@ -59,7 +59,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
59
59
MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType ());
60
60
61
61
if (chipset.majorVersion < 9 )
62
- return gpuOp.emitOpError (" Raw buffer ops require GCN or higher" );
62
+ return gpuOp.emitOpError (" raw buffer ops require GCN or higher" );
63
63
64
64
Value storeData = adaptor.getODSOperands (0 )[0 ];
65
65
if (storeData == memref) // no write component to this op
@@ -82,6 +82,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
82
82
83
83
Type i32 = rewriter.getI32Type ();
84
84
Type llvmI32 = this ->typeConverter ->convertType (i32 );
85
+ Type llvmI16 = this ->typeConverter ->convertType (rewriter.getI16Type ());
85
86
86
87
int64_t elementByteWidth = memrefType.getElementTypeBitWidth () / 8 ;
87
88
Value byteWidthConst = createI32Constant (rewriter, loc, elementByteWidth);
@@ -156,41 +157,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
156
157
if (failed (getStridesAndOffset (memrefType, strides, offset)))
157
158
return gpuOp.emitOpError (" Can't lower non-stride-offset memrefs" );
158
159
159
- // Resource descriptor
160
- // bits 0-47: base address
161
- // bits 48-61: stride (0 for raw buffers)
162
- // bit 62: texture cache coherency (always 0)
163
- // bit 63: enable swizzles (always off for raw buffers)
164
- // bits 64-95 (word 2): Number of records, units of stride
165
- // bits 96-127 (word 3): See below
166
-
167
- Type llvm4xI32 = this ->typeConverter ->convertType (VectorType::get (4 , i32 ));
168
160
MemRefDescriptor memrefDescriptor (memref);
169
- Type llvmI64 = this ->typeConverter ->convertType (rewriter.getI64Type ());
170
- Value c32I64 = rewriter.create <LLVM::ConstantOp>(
171
- loc, llvmI64, rewriter.getI64IntegerAttr (32 ));
172
-
173
- Value resource = rewriter.create <LLVM::UndefOp>(loc, llvm4xI32);
174
161
175
162
Value ptr = memrefDescriptor.alignedPtr (rewriter, loc);
176
- Value ptrAsInt = rewriter.create <LLVM::PtrToIntOp>(loc, llvmI64, ptr);
177
- Value lowHalf = rewriter.create <LLVM::TruncOp>(loc, llvmI32, ptrAsInt);
178
- resource = rewriter.create <LLVM::InsertElementOp>(
179
- loc, llvm4xI32, resource, lowHalf,
180
- this ->createIndexAttrConstant (rewriter, loc, this ->getIndexType (), 0 ));
181
-
182
- // Bits 48-63 are used both for the stride of the buffer and (on gfx10) for
183
- // enabling swizzling. Prevent the high bits of pointers from accidentally
184
- // setting those flags.
185
- Value highHalfShifted = rewriter.create <LLVM::TruncOp>(
186
- loc, llvmI32, rewriter.create <LLVM::LShrOp>(loc, ptrAsInt, c32I64));
187
- Value highHalfTruncated = rewriter.create <LLVM::AndOp>(
188
- loc, llvmI32, highHalfShifted,
189
- createI32Constant (rewriter, loc, 0x0000ffff ));
190
- resource = rewriter.create <LLVM::InsertElementOp>(
191
- loc, llvm4xI32, resource, highHalfTruncated,
192
- this ->createIndexAttrConstant (rewriter, loc, this ->getIndexType (), 1 ));
193
-
163
+ // The stride value is always 0 for raw buffers. This also disables
164
+ // swizzling.
165
+ Value stride = rewriter.createOrFold <LLVM::ConstantOp>(
166
+ loc, llvmI16, rewriter.getI16IntegerAttr (0 ));
194
167
Value numRecords;
195
168
if (memrefType.hasStaticShape ()) {
196
169
numRecords = createI32Constant (
@@ -209,11 +182,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
209
182
}
210
183
numRecords = rewriter.create <LLVM::TruncOp>(loc, llvmI32, maxIndex);
211
184
}
212
- resource = rewriter.create <LLVM::InsertElementOp>(
213
- loc, llvm4xI32, resource, numRecords,
214
- this ->createIndexAttrConstant (rewriter, loc, this ->getIndexType (), 2 ));
215
185
216
- // Final word:
186
+ // Flag word:
217
187
// bits 0-11: dst sel, ignored by these intrinsics
218
188
// bits 12-14: data format (ignored, must be nonzero, 7=float)
219
189
// bits 15-18: data format (ignored, must be nonzero, 4=32bit)
@@ -227,16 +197,16 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
227
197
// bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
228
198
// none, 3 = either swizzles or testing against offset field) RDNA only
229
199
// bits 30-31: Type (must be 0)
230
- uint32_t word3 = (7 << 12 ) | (4 << 15 );
200
+ uint32_t flags = (7 << 12 ) | (4 << 15 );
231
201
if (chipset.majorVersion >= 10 ) {
232
- word3 |= (1 << 24 );
202
+ flags |= (1 << 24 );
233
203
uint32_t oob = adaptor.getBoundsCheck () ? 3 : 2 ;
234
- word3 |= (oob << 28 );
204
+ flags |= (oob << 28 );
235
205
}
236
- Value word3Const = createI32Constant (rewriter, loc, word3 );
237
- resource = rewriter. create < LLVM::InsertElementOp>(
238
- loc, llvm4xI32, resource, word3Const,
239
- this -> createIndexAttrConstant (rewriter, loc, this -> getIndexType (), 3 ) );
206
+ Value flagsConst = createI32Constant (rewriter, loc, flags );
207
+ Type rsrcType = LLVM::LLVMPointerType::get (rewriter. getContext (), 8 );
208
+ Value resource = rewriter. createOrFold <ROCDL::MakeBufferRsrcOp>(
209
+ loc, rsrcType, ptr, stride, numRecords, flagsConst );
240
210
args.push_back (resource);
241
211
242
212
// Indexing (voffset)
@@ -708,16 +678,20 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
708
678
});
709
679
710
680
patterns.add <LDSBarrierOpLowering>(converter);
711
- patterns.add <
712
- RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
713
- RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
714
- RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>,
715
- RawBufferOpLowering<RawBufferAtomicFmaxOp, ROCDL::RawBufferAtomicFMaxOp>,
716
- RawBufferOpLowering<RawBufferAtomicSmaxOp, ROCDL::RawBufferAtomicSMaxOp>,
717
- RawBufferOpLowering<RawBufferAtomicUminOp, ROCDL::RawBufferAtomicUMinOp>,
718
- RawBufferOpLowering<RawBufferAtomicCmpswapOp,
719
- ROCDL::RawBufferAtomicCmpSwap>,
720
- MFMAOpLowering, WMMAOpLowering>(converter, chipset);
681
+ patterns
682
+ .add <RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
683
+ RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
684
+ RawBufferOpLowering<RawBufferAtomicFaddOp,
685
+ ROCDL::RawPtrBufferAtomicFaddOp>,
686
+ RawBufferOpLowering<RawBufferAtomicFmaxOp,
687
+ ROCDL::RawPtrBufferAtomicFmaxOp>,
688
+ RawBufferOpLowering<RawBufferAtomicSmaxOp,
689
+ ROCDL::RawPtrBufferAtomicSmaxOp>,
690
+ RawBufferOpLowering<RawBufferAtomicUminOp,
691
+ ROCDL::RawPtrBufferAtomicUminOp>,
692
+ RawBufferOpLowering<RawBufferAtomicCmpswapOp,
693
+ ROCDL::RawPtrBufferAtomicCmpSwap>,
694
+ MFMAOpLowering, WMMAOpLowering>(converter, chipset);
721
695
}
722
696
723
697
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass () {
0 commit comments