@@ -91,6 +91,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
91
91
Type llvmI32 = this ->typeConverter ->convertType (i32 );
92
92
Type llvmI16 = this ->typeConverter ->convertType (rewriter.getI16Type ());
93
93
94
+ auto toI32 = [&](Value val) -> Value {
95
+ if (val.getType () == llvmI32)
96
+ return val;
97
+
98
+ return rewriter.create <LLVM::TruncOp>(loc, llvmI32, val);
99
+ };
100
+
94
101
int64_t elementByteWidth = memrefType.getElementTypeBitWidth () / 8 ;
95
102
Value byteWidthConst = createI32Constant (rewriter, loc, elementByteWidth);
96
103
@@ -166,22 +173,22 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
166
173
Value stride = rewriter.create <LLVM::ConstantOp>(
167
174
loc, llvmI16, rewriter.getI16IntegerAttr (0 ));
168
175
Value numRecords;
169
- if (memrefType.hasStaticShape ()) {
176
+ if (memrefType.hasStaticShape () && memrefType. getLayout (). isIdentity () ) {
170
177
numRecords = createI32Constant (
171
178
rewriter, loc,
172
179
static_cast <int32_t >(memrefType.getNumElements () * elementByteWidth));
173
180
} else {
174
181
Value maxIndex;
175
182
for (uint32_t i = 0 , e = memrefType.getRank (); i < e; ++i) {
176
- Value size = memrefDescriptor.size (rewriter, loc, i);
177
- Value stride = memrefDescriptor.stride (rewriter, loc, i);
183
+ Value size = toI32 ( memrefDescriptor.size (rewriter, loc, i) );
184
+ Value stride = toI32 ( memrefDescriptor.stride (rewriter, loc, i) );
178
185
stride = rewriter.create <LLVM::MulOp>(loc, stride, byteWidthConst);
179
186
Value maxThisDim = rewriter.create <LLVM::MulOp>(loc, size, stride);
180
187
maxIndex = maxIndex ? rewriter.create <LLVM::MaximumOp>(loc, maxIndex,
181
188
maxThisDim)
182
189
: maxThisDim;
183
190
}
184
- numRecords = rewriter. create <LLVM::TruncOp>(loc, llvmI32, maxIndex) ;
191
+ numRecords = maxIndex;
185
192
}
186
193
187
194
// Flag word:
@@ -218,7 +225,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
218
225
Value strideOp;
219
226
if (ShapedType::isDynamic (strides[i])) {
220
227
strideOp = rewriter.create <LLVM::MulOp>(
221
- loc, memrefDescriptor.stride (rewriter, loc, i), byteWidthConst);
228
+ loc, toI32 (memrefDescriptor.stride (rewriter, loc, i)),
229
+ byteWidthConst);
222
230
} else {
223
231
strideOp =
224
232
createI32Constant (rewriter, loc, strides[i] * elementByteWidth);
@@ -240,7 +248,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
240
248
sgprOffset = createI32Constant (rewriter, loc, 0 );
241
249
if (ShapedType::isDynamic (offset))
242
250
sgprOffset = rewriter.create <LLVM::AddOp>(
243
- loc, memrefDescriptor.offset (rewriter, loc), sgprOffset);
251
+ loc, toI32 ( memrefDescriptor.offset (rewriter, loc) ), sgprOffset);
244
252
else if (offset > 0 )
245
253
sgprOffset = rewriter.create <LLVM::AddOp>(
246
254
loc, sgprOffset, createI32Constant (rewriter, loc, offset));
0 commit comments