@@ -30,10 +30,23 @@ namespace mlir {
30
30
using namespace mlir ;
31
31
using namespace mlir ::amdgpu;
32
32
33
+ // / Convert an unsigned number `val` to i32.
34
+ static Value convertUnsignedToI32 (ConversionPatternRewriter &rewriter,
35
+ Location loc, Value val) {
36
+ IntegerType i32 = rewriter.getI32Type ();
37
+ // Force check that `val` is of int type.
38
+ auto valTy = cast<IntegerType>(val.getType ());
39
+ if (i32 == valTy)
40
+ return val;
41
+ return valTy.getWidth () > 32
42
+ ? Value (rewriter.create <LLVM::TruncOp>(loc, i32 , val))
43
+ : Value (rewriter.create <LLVM::ZExtOp>(loc, i32 , val));
44
+ }
45
+
33
46
/// Materialize `value` as an LLVM-dialect i32 constant at `loc`.
static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), value);
}
38
51
39
52
static Value createI1Constant (ConversionPatternRewriter &rewriter, Location loc,
@@ -42,6 +55,27 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
42
55
return rewriter.create <LLVM::ConstantOp>(loc, llvmI1, value);
43
56
}
44
57
58
+ // / Returns the linear index used to access an element in the memref.
59
+ static Value getLinearIndexI32 (ConversionPatternRewriter &rewriter,
60
+ Location loc, MemRefDescriptor &memRefDescriptor,
61
+ ValueRange indices, ArrayRef<int64_t > strides) {
62
+ IntegerType i32 = rewriter.getI32Type ();
63
+ Value index;
64
+ for (auto [i, increment, stride] : llvm::enumerate (indices, strides)) {
65
+ if (stride != 1 ) { // Skip if stride is 1.
66
+ Value strideValue =
67
+ ShapedType::isDynamic (stride)
68
+ ? convertUnsignedToI32 (rewriter, loc,
69
+ memRefDescriptor.stride (rewriter, loc, i))
70
+ : rewriter.create <LLVM::ConstantOp>(loc, i32 , stride);
71
+ increment = rewriter.create <LLVM::MulOp>(loc, increment, strideValue);
72
+ }
73
+ index =
74
+ index ? rewriter.create <LLVM::AddOp>(loc, index, increment) : increment;
75
+ }
76
+ return index ? index : createI32Constant (rewriter, loc, 0 );
77
+ }
78
+
45
79
namespace {
46
80
// Define commonly used chipsets versions for convenience.
47
81
constexpr Chipset kGfx908 = Chipset(9 , 0 , 8 );
@@ -88,17 +122,12 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
88
122
Type llvmWantedDataType = this ->typeConverter ->convertType (wantedDataType);
89
123
90
124
Type i32 = rewriter.getI32Type ();
91
- Type llvmI32 = this ->typeConverter ->convertType (i32 );
92
- Type llvmI16 = this ->typeConverter ->convertType (rewriter.getI16Type ());
125
+ Type i16 = rewriter.getI16Type ();
93
126
94
- auto toI32 = [&](Value val) -> Value {
95
- if (val.getType () == llvmI32)
96
- return val;
97
-
98
- return rewriter.create <LLVM::TruncOp>(loc, llvmI32, val);
99
- };
100
-
101
- int64_t elementByteWidth = memrefType.getElementTypeBitWidth () / 8 ;
127
+ // Get the type size in bytes.
128
+ DataLayout dataLayout = DataLayout::closest (gpuOp);
129
+ int64_t elementByteWidth =
130
+ dataLayout.getTypeSizeInBits (memrefType.getElementType ()) / 8 ;
102
131
Value byteWidthConst = createI32Constant (rewriter, loc, elementByteWidth);
103
132
104
133
// If we want to load a vector<NxT> with total size <= 32
@@ -114,7 +143,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
114
143
}
115
144
if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
116
145
uint32_t vecLen = dataVector.getNumElements ();
117
- uint32_t elemBits = dataVector.getElementTypeBitWidth ();
146
+ uint32_t elemBits =
147
+ dataLayout.getTypeSizeInBits (dataVector.getElementType ());
118
148
uint32_t totalBits = elemBits * vecLen;
119
149
bool usePackedFp16 =
120
150
isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2 ;
@@ -167,28 +197,36 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
167
197
168
198
MemRefDescriptor memrefDescriptor (memref);
169
199
170
- Value ptr = memrefDescriptor.alignedPtr (rewriter, loc);
200
+ Value ptr = memrefDescriptor.bufferPtr (
201
+ rewriter, loc, *this ->getTypeConverter (), memrefType);
171
202
// The stride value is always 0 for raw buffers. This also disables
172
203
// swizling.
173
204
Value stride = rewriter.create <LLVM::ConstantOp>(
174
- loc, llvmI16, rewriter.getI16IntegerAttr (0 ));
205
+ loc, i16 , rewriter.getI16IntegerAttr (0 ));
206
+ // Get the number of elements.
175
207
Value numRecords;
176
- if (memrefType.hasStaticShape () && memrefType.getLayout ().isIdentity ()) {
177
- numRecords = createI32Constant (
178
- rewriter, loc,
179
- static_cast <int32_t >(memrefType.getNumElements () * elementByteWidth));
208
+ if (memrefType.hasStaticShape () &&
209
+ !llvm::any_of (strides, ShapedType::isDynamic)) {
210
+ int64_t size = memrefType.getRank () == 0 ? 1 : 0 ;
211
+ ArrayRef<int64_t > shape = memrefType.getShape ();
212
+ for (uint32_t i = 0 , e = memrefType.getRank (); i < e; ++i)
213
+ size = std::max (shape[i] * strides[i], size);
214
+ size = size * elementByteWidth;
215
+ assert (size < std::numeric_limits<uint32_t >::max () &&
216
+ " the memref buffer is too large" );
217
+ numRecords = createI32Constant (rewriter, loc, static_cast <int32_t >(size));
180
218
} else {
181
219
Value maxIndex;
182
220
for (uint32_t i = 0 , e = memrefType.getRank (); i < e; ++i) {
183
- Value size = toI32 (memrefDescriptor.size (rewriter, loc, i));
184
- Value stride = toI32 (memrefDescriptor.stride (rewriter, loc, i));
185
- stride = rewriter.create <LLVM::MulOp>(loc, stride, byteWidthConst);
221
+ Value size = memrefDescriptor.size (rewriter, loc, i);
222
+ Value stride = memrefDescriptor.stride (rewriter, loc, i);
186
223
Value maxThisDim = rewriter.create <LLVM::MulOp>(loc, size, stride);
187
- maxIndex = maxIndex ? rewriter. create <LLVM::MaximumOp>(loc, maxIndex,
188
- maxThisDim)
189
- : maxThisDim;
224
+ maxIndex =
225
+ maxIndex ? rewriter. create <LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
226
+ : maxThisDim;
190
227
}
191
- numRecords = maxIndex;
228
+ numRecords = rewriter.create <LLVM::MulOp>(
229
+ loc, convertUnsignedToI32 (rewriter, loc, maxIndex), byteWidthConst);
192
230
}
193
231
194
232
// Flag word:
@@ -218,40 +256,23 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
218
256
args.push_back (resource);
219
257
220
258
// Indexing (voffset)
221
- Value voffset = createI32Constant (rewriter, loc, 0 );
222
- for (auto pair : llvm::enumerate (adaptor.getIndices ())) {
223
- size_t i = pair.index ();
224
- Value index = pair.value ();
225
- Value strideOp;
226
- if (ShapedType::isDynamic (strides[i])) {
227
- strideOp = rewriter.create <LLVM::MulOp>(
228
- loc, toI32 (memrefDescriptor.stride (rewriter, loc, i)),
229
- byteWidthConst);
230
- } else {
231
- strideOp =
232
- createI32Constant (rewriter, loc, strides[i] * elementByteWidth);
233
- }
234
- index = rewriter.create <LLVM::MulOp>(loc, index, strideOp);
235
- voffset = rewriter.create <LLVM::AddOp>(loc, voffset, index);
236
- }
237
- if (adaptor.getIndexOffset ()) {
238
- int32_t indexOffset = *gpuOp.getIndexOffset () * elementByteWidth;
239
- Value extraOffsetConst = createI32Constant (rewriter, loc, indexOffset);
259
+ Value voffset = getLinearIndexI32 (rewriter, loc, memrefDescriptor,
260
+ adaptor.getIndices (), strides);
261
+ if (std::optional<int32_t > indexOffset = adaptor.getIndexOffset ();
262
+ indexOffset && *indexOffset > 0 ) {
263
+ Value extraOffsetConst = createI32Constant (rewriter, loc, *indexOffset);
240
264
voffset =
241
265
voffset ? rewriter.create <LLVM::AddOp>(loc, voffset, extraOffsetConst)
242
266
: extraOffsetConst;
243
267
}
268
+ voffset = rewriter.create <LLVM::MulOp>(loc, voffset, byteWidthConst);
244
269
args.push_back (voffset);
245
270
271
+ // SGPR offset.
246
272
Value sgprOffset = adaptor.getSgprOffset ();
247
273
if (!sgprOffset)
248
274
sgprOffset = createI32Constant (rewriter, loc, 0 );
249
- if (ShapedType::isDynamic (offset))
250
- sgprOffset = rewriter.create <LLVM::AddOp>(
251
- loc, toI32 (memrefDescriptor.offset (rewriter, loc)), sgprOffset);
252
- else if (offset > 0 )
253
- sgprOffset = rewriter.create <LLVM::AddOp>(
254
- loc, sgprOffset, createI32Constant (rewriter, loc, offset));
275
+ sgprOffset = rewriter.create <LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
255
276
args.push_back (sgprOffset);
256
277
257
278
// bit 0: GLC = 0 (atomics drop value, less coherency)
0 commit comments