@@ -60,11 +60,6 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
60
60
builder.getIndexAttr (value));
61
61
}
62
62
63
- Value ConvertToLLVMPattern::createIndexConstant (
64
- ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
65
- return createIndexAttrConstant (builder, loc, getIndexType (), value);
66
- }
67
-
68
63
Value ConvertToLLVMPattern::getStridedElementPtr (
69
64
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
70
65
ConversionPatternRewriter &rewriter) const {
@@ -79,13 +74,15 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
79
74
Value base =
80
75
memRefDescriptor.bufferPtr (rewriter, loc, *getTypeConverter (), type);
81
76
77
+ Type indexType = getIndexType ();
82
78
Value index;
83
79
for (int i = 0 , e = indices.size (); i < e; ++i) {
84
80
Value increment = indices[i];
85
81
if (strides[i] != 1 ) { // Skip if stride is 1.
86
- Value stride = ShapedType::isDynamic (strides[i])
87
- ? memRefDescriptor.stride (rewriter, loc, i)
88
- : createIndexConstant (rewriter, loc, strides[i]);
82
+ Value stride =
83
+ ShapedType::isDynamic (strides[i])
84
+ ? memRefDescriptor.stride (rewriter, loc, i)
85
+ : createIndexAttrConstant (rewriter, loc, indexType, strides[i]);
89
86
increment = rewriter.create <LLVM::MulOp>(loc, increment, stride);
90
87
}
91
88
index =
@@ -130,15 +127,17 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
130
127
131
128
sizes.reserve (memRefType.getRank ());
132
129
unsigned dynamicIndex = 0 ;
130
+ Type indexType = getIndexType ();
133
131
for (int64_t size : memRefType.getShape ()) {
134
- sizes.push_back (size == ShapedType::kDynamic
135
- ? dynamicSizes[dynamicIndex++]
136
- : createIndexConstant (rewriter, loc, size));
132
+ sizes.push_back (
133
+ size == ShapedType::kDynamic
134
+ ? dynamicSizes[dynamicIndex++]
135
+ : createIndexAttrConstant (rewriter, loc, indexType, size));
137
136
}
138
137
139
138
// Strides: iterate sizes in reverse order and multiply.
140
139
int64_t stride = 1 ;
141
- Value runningStride = createIndexConstant (rewriter, loc, 1 );
140
+ Value runningStride = createIndexAttrConstant (rewriter, loc, indexType , 1 );
142
141
strides.resize (memRefType.getRank ());
143
142
for (auto i = memRefType.getRank (); i-- > 0 ;) {
144
143
strides[i] = runningStride;
@@ -158,7 +157,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
158
157
runningStride =
159
158
rewriter.create <LLVM::MulOp>(loc, runningStride, sizes[i]);
160
159
else
161
- runningStride = createIndexConstant (rewriter, loc, stride);
160
+ runningStride = createIndexAttrConstant (rewriter, loc, indexType , stride);
162
161
}
163
162
if (sizeInBytes) {
164
163
// Buffer size in bytes.
@@ -195,22 +194,25 @@ Value ConvertToLLVMPattern::getNumElements(
195
194
static_cast <ssize_t >(dynamicSizes.size ()) &&
196
195
" dynamicSizes size doesn't match dynamic sizes count in memref shape" );
197
196
197
+ Type indexType = getIndexType ();
198
198
Value numElements = memRefType.getRank () == 0
199
- ? createIndexConstant (rewriter, loc, 1 )
199
+ ? createIndexAttrConstant (rewriter, loc, indexType , 1 )
200
200
: nullptr ;
201
201
unsigned dynamicIndex = 0 ;
202
202
203
203
// Compute the total number of memref elements.
204
204
for (int64_t staticSize : memRefType.getShape ()) {
205
205
if (numElements) {
206
- Value size = staticSize == ShapedType::kDynamic
207
- ? dynamicSizes[dynamicIndex++]
208
- : createIndexConstant (rewriter, loc, staticSize);
206
+ Value size =
207
+ staticSize == ShapedType::kDynamic
208
+ ? dynamicSizes[dynamicIndex++]
209
+ : createIndexAttrConstant (rewriter, loc, indexType, staticSize);
209
210
numElements = rewriter.create <LLVM::MulOp>(loc, numElements, size);
210
211
} else {
211
- numElements = staticSize == ShapedType::kDynamic
212
- ? dynamicSizes[dynamicIndex++]
213
- : createIndexConstant (rewriter, loc, staticSize);
212
+ numElements =
213
+ staticSize == ShapedType::kDynamic
214
+ ? dynamicSizes[dynamicIndex++]
215
+ : createIndexAttrConstant (rewriter, loc, indexType, staticSize);
214
216
}
215
217
}
216
218
return numElements;
@@ -231,8 +233,9 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
231
233
memRefDescriptor.setAlignedPtr (rewriter, loc, alignedPtr);
232
234
233
235
// Field 3: Offset in aligned pointer.
234
- memRefDescriptor.setOffset (rewriter, loc,
235
- createIndexConstant (rewriter, loc, 0 ));
236
+ Type indexType = getIndexType ();
237
+ memRefDescriptor.setOffset (
238
+ rewriter, loc, createIndexAttrConstant (rewriter, loc, indexType, 0 ));
236
239
237
240
// Fields 4: Sizes.
238
241
for (const auto &en : llvm::enumerate (sizes))
0 commit comments