@@ -69,6 +69,30 @@ struct Helper final {
         rewriter.getIntegerAttr(idxType, static_cast<int64_t>(value)));
   }

+  Value calculateStaticSize(OpBuilder &rewriter, const Location loc,
+                            const MemRefType type) const {
+    if (type.getRank() == 0) {
+      return idxConstant(rewriter, loc, 0);
+    }
+
+    auto elementType = type.getElementType();
+    if (!elementType.isIntOrIndexOrFloat()) {
+      return nullptr;
+    }
+
+    int64_t numElements = 1;
+    for (auto dim : type.getShape()) {
+      if (dim == ShapedType::kDynamic) {
+        return nullptr;
+      }
+      numElements = numElements * dim;
+    }
+    auto elementSize = elementType.isIndex()
+                           ? idxType.getIntOrFloatBitWidth()
+                           : elementType.getIntOrFloatBitWidth();
+    return idxConstant(rewriter, loc, elementSize * numElements / 8);
+  }
+
   void destroyKernels(OpBuilder &rewriter, Location loc,
                       ArrayRef<Value> kernelPtrs) const {
     auto size = idxConstant(rewriter, loc, kernelPtrs.size());
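The new calculateStaticSize helper folds the allocation size to a compile-time constant whenever the memref shape is fully static, and returns nullptr for dynamic shapes or non-scalar element types so callers can fall back to a runtime computation. A minimal standalone sketch of the same arithmetic, in plain C++ rather than MLIR (the kDynamic sentinel value, function name, and example shapes below are illustrative assumptions, not part of the patch):

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Sentinel standing in for ShapedType::kDynamic (illustrative value only).
constexpr int64_t kDynamic = INT64_MIN;

// Returns the buffer size in bytes, or nullopt when any dimension is dynamic,
// mirroring calculateStaticSize() returning nullptr in the patch above.
std::optional<int64_t> staticSizeInBytes(const std::vector<int64_t> &shape,
                                         unsigned elementBitWidth) {
  if (shape.empty())
    return 0; // rank-0 case: the helper returns a 0 index constant
  int64_t numElements = 1;
  for (int64_t dim : shape) {
    if (dim == kDynamic)
      return std::nullopt; // dynamic dimension: no static size
    numElements *= dim;
  }
  return static_cast<int64_t>(elementBitWidth) * numElements / 8;
}

int main() {
  // e.g. a 2x3x4 buffer of 32-bit elements: 24 elements * 32 bits / 8 = 96 bytes
  std::cout << staticSizeInBytes({2, 3, 4}, 32).value() << "\n";      // 96
  std::cout << staticSizeInBytes({2, kDynamic}, 32).has_value() << "\n"; // 0
  return 0;
}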
@@ -102,82 +126,44 @@ struct ConvertAlloc final : ConvertOpPattern<gpu::AllocOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = allocOp.getLoc();
     MemRefType type = allocOp.getType();
-    auto shape = type.getShape();
-    auto dynamics = adaptor.getDynamicSizes();

-    if (shape.empty() || dynamics.empty()) {
-      int64_t staticSize;
-      if (shape.empty()) {
-        staticSize = 0;
-      } else {
-        staticSize = type.getElementType().getIntOrFloatBitWidth() / 8;
-        for (auto dim : shape) {
-          assert(dim != ShapedType::kDynamic);
-          staticSize *= dim;
-        }
-      }
-      auto size = helper.idxConstant(rewriter, loc, staticSize);
+    if (auto staticSize = helper.calculateStaticSize(rewriter, loc, type)) {
       auto ptr = funcCall(rewriter, GPU_OCL_MALLOC, helper.ptrType,
                           {helper.ptrType, helper.idxType}, loc,
-                          {getCtxPtr(rewriter), size})
+                          {getCtxPtr(rewriter), staticSize})
                      .getResult();
       Value replacement = MemRefDescriptor::fromStaticShape(
           rewriter, loc, helper.converter, type, ptr, ptr);
       rewriter.replaceOp(allocOp, replacement);
       return success();
     }

-    auto ndims = shape.size();
-    SmallVector<Value> newShape;
-    SmallVector<Value> newStrides(ndims);
-    auto staticSize = type.getElementType().getIntOrFloatBitWidth() / 8;
-    auto size = dynamics[0];
-
-    auto idxMul = [&](Value x, Value y) -> Value {
-      if (auto xConst = getConstantIntValue(x)) {
-        if (auto yConst = getConstantIntValue(y)) {
-          return helper.idxConstant(rewriter, loc,
-                                    xConst.value() * yConst.value());
-        }
-      }
-      return rewriter.create<LLVM::MulOp>(loc, x, y);
-    };
-
-    for (size_t i = 0, j = 0; i < ndims; i++) {
-      auto dim = shape[i];
-      if (dim == ShapedType::kDynamic) {
-        auto dynSize = dynamics[j++];
-        newShape.emplace_back(dynSize);
-        if (j != 1) {
-          size = idxMul(size, dynSize);
-        }
-      } else {
-        staticSize *= dim;
-        newShape.emplace_back(helper.idxConstant(rewriter, loc, dim));
-      }
+    auto dstType = helper.converter.convertType(type);
+    if (!dstType) {
+      allocOp.emitError() << "Failed to convert the MemRefType";
+      return failure();
     }

-    size = idxMul(size, helper.idxConstant(rewriter, loc, staticSize));
+    SmallVector<Value> shape;
+    SmallVector<Value> strides;
+    Value size;
+    getMemRefDescriptorSizes(loc, type, adaptor.getDynamicSizes(), rewriter,
+                             shape, strides, size);
+    assert(shape.size() == strides.size());
+
     auto ptr = funcCall(rewriter, GPU_OCL_MALLOC, helper.ptrType,
                         {helper.ptrType, helper.idxType}, loc,
                         {getCtxPtr(rewriter), size})
                    .getResult();

-    newStrides[ndims - 1] = helper.idxConstant(rewriter, loc, 1);
-    for (int i = static_cast<int>(ndims) - 2; i >= 0; i--) {
-      newStrides[i] = idxMul(newStrides[i + 1], newShape[i]);
-      ;
-    }
-
-    auto dsc = MemRefDescriptor::undef(rewriter, loc,
-                                       helper.converter.convertType(type));
+    auto dsc = MemRefDescriptor::undef(rewriter, loc, dstType);
     dsc.setAllocatedPtr(rewriter, loc, ptr);
     dsc.setAlignedPtr(rewriter, loc, ptr);
     dsc.setOffset(rewriter, loc, helper.idxConstant(rewriter, loc, 0));

-    for (unsigned i = 0, n = static_cast<unsigned>(ndims); i < n; i++) {
-      dsc.setSize(rewriter, loc, i, newShape[i]);
-      dsc.setStride(rewriter, loc, i, newStrides[i]);
+    for (unsigned i = 0, n = static_cast<unsigned>(shape.size()); i < n; i++) {
+      dsc.setSize(rewriter, loc, i, shape[i]);
+      dsc.setStride(rewriter, loc, i, strides[i]);
     }

     rewriter.replaceOp(allocOp, static_cast<Value>(dsc));
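For the dynamic-shape path, the hand-rolled shape/stride loop is replaced by getMemRefDescriptorSizes, which emits per-dimension size values, row-major strides, and the total allocation size as IR. A scalar sketch of that computation in plain C++ (the function name and example values are illustrative; the MLIR helper produces SSA values, not integers):

#include <cstdint>
#include <iostream>
#include <vector>

// Given concrete dimension sizes, compute row-major strides and the total
// byte size -- the scalar analogue of what the lowering materializes.
void descriptorSizes(const std::vector<int64_t> &shape, int64_t elementBytes,
                     std::vector<int64_t> &strides, int64_t &sizeInBytes) {
  strides.assign(shape.size(), 1);
  // Innermost dimension has stride 1; each outer stride is the product of
  // all inner dimension sizes.
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  int64_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= dim;
  sizeInBytes = numElements * elementBytes;
}

int main() {
  std::vector<int64_t> strides;
  int64_t size = 0;
  descriptorSizes({2, 3, 4}, 4, strides, size); // e.g. a 2x3x4 f32 buffer
  for (int64_t s : strides)
    std::cout << s << " ";          // 12 4 1
  std::cout << "\n" << size << "\n"; // 96
  return 0;
}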
@@ -209,23 +195,24 @@ struct ConvertMemcpy final : ConvertOpPattern<gpu::MemcpyOp> {
   matchAndRewrite(gpu::MemcpyOp gpuMemcpy, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = gpuMemcpy.getLoc();
+    MemRefDescriptor srcDsc(adaptor.getSrc());
+    MemRefDescriptor dstDsc(adaptor.getDst());
     auto srcType = gpuMemcpy.getSrc().getType();
-    auto elementSize = srcType.getElementType().getIntOrFloatBitWidth() / 8;
-    uint64_t numElements = 0;
-    for (auto dim : srcType.getShape()) {
-      if (dim == ShapedType::kDynamic) {
-        gpuMemcpy.emitOpError()
-            << "dynamic shapes are not currently not supported";
-        return failure();
+    Value size = helper.calculateStaticSize(rewriter, loc, srcType);
+
+    if (!size) {
+      auto numElements = helper.idxConstant(rewriter, loc, 1);
+      for (unsigned i = 0, n = srcType.getRank(); i < n; i++) {
+        numElements = rewriter.create<LLVM::MulOp>(
+            loc, numElements, srcDsc.size(rewriter, loc, i));
       }
-      numElements = numElements ? numElements * dim : dim;
+      size = rewriter.create<mlir::LLVM::MulOp>(
+          loc, numElements,
+          getSizeInBytes(loc, srcType.getElementType(), rewriter));
     }

-    MemRefDescriptor srcDsc(adaptor.getSrc());
-    MemRefDescriptor dstDsc(adaptor.getDst());
     auto srcPtr = srcDsc.alignedPtr(rewriter, loc);
     auto dstPtr = dstDsc.alignedPtr(rewriter, loc);
-    auto size = helper.idxConstant(rewriter, loc, elementSize * numElements);
     auto oclMemcpy = funcCall(
         rewriter, GPU_OCL_MEMCPY, helper.voidType,
         {helper.ptrType, helper.ptrType, helper.ptrType, helper.idxType}, loc,
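In ConvertMemcpy, dynamic shapes are no longer rejected: when calculateStaticSize cannot fold the byte count, it is built at runtime from the per-dimension sizes stored in the source memref descriptor, scaled by the element size in bytes. A scalar sketch of that fallback (plain C++; the function name is illustrative and the dimension values stand in for the srcDsc.size(...) reads):

#include <cstdint>
#include <iostream>
#include <vector>

// Multiply the per-dimension sizes read back from the descriptor, then scale
// by the element size in bytes -- mirroring the !size branch in the patch.
int64_t runtimeCopySize(const std::vector<int64_t> &descriptorDims,
                        int64_t elementBytes) {
  int64_t numElements = 1; // the lowering starts from an index constant 1
  for (int64_t dim : descriptorDims)
    numElements *= dim;
  return numElements * elementBytes;
}

int main() {
  // e.g. a 2-D dynamically shaped f32 buffer whose descriptor reports 128x256
  std::cout << runtimeCopySize({128, 256}, 4) << "\n"; // 131072 bytes
  return 0;
}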