@@ -139,21 +139,83 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
139
139
if (getOperation ()->getUses ().empty ())
140
140
return success ();
141
141
142
- FailureOr<Value> alloc = state.createAlloc (rewriter, getLoc (), getResult ());
142
+ Optional<bool > dealloc = llvm::None;
143
+ if (escape ().hasValue ())
144
+ dealloc = !*escape ();
145
+ FailureOr<Value> alloc =
146
+ state.createAlloc (rewriter, getLoc (), getResult (), dealloc);
143
147
if (failed (alloc))
144
148
return failure ();
149
+ if (copy ()) {
150
+ FailureOr<Value> copyValueBuffer = state.getBuffer (
151
+ rewriter, getOperation ()->getOpOperand (getNumOperands () - 1 ));
152
+ if (failed (copyValueBuffer))
153
+ return failure ();
154
+ if (failed (state.getOptions ().createMemCpy (rewriter, getLoc (),
155
+ *copyValueBuffer, *alloc)))
156
+ return failure ();
157
+ }
145
158
replaceOpWithBufferizedValues (rewriter, getOperation (), *alloc);
146
159
return success ();
147
160
}
148
161
162
+ bool AllocTensorOp::isMemoryWrite (OpResult opResult,
163
+ const AnalysisState &state) {
164
+ // AllocTensorOps do not write unless they have a `copy` value.
165
+ return static_cast <bool >(copy ());
166
+ }
167
+
168
+ bool AllocTensorOp::bufferizesToMemoryRead (OpOperand &opOperand,
169
+ const AnalysisState &state) {
170
+ assert (opOperand.getOperandNumber () == getNumOperands () - 1 &&
171
+ " expected copy operand" );
172
+ return true ;
173
+ }
174
+
175
+ bool AllocTensorOp::bufferizesToMemoryWrite (OpOperand &opOperand,
176
+ const AnalysisState &state) {
177
+ assert (opOperand.getOperandNumber () == getNumOperands () - 1 &&
178
+ " expected copy operand" );
179
+ return false ;
180
+ }
181
+
182
+ SmallVector<OpResult>
183
+ AllocTensorOp::getAliasingOpResult (OpOperand &opOperand,
184
+ const AnalysisState &state) {
185
+ // This is a new allocation. It does not alias with any other buffer.
186
+ return {};
187
+ }
188
+
149
189
LogicalResult AllocTensorOp::verify () {
150
- if (getType ().getNumDynamicDims () !=
151
- static_cast <int64_t >(dynamicSizes ().size ()))
190
+ if (copy () && !dynamicSizes ().empty ())
191
+ return emitError (" dynamic sizes not needed when copying a tensor" );
192
+ if (!copy () && getType ().getNumDynamicDims () !=
193
+ static_cast <int64_t >(dynamicSizes ().size ()))
152
194
return emitError (" expected " )
153
195
<< getType ().getNumDynamicDims () << " dynamic sizes" ;
196
+ if (copy () && copy ().getType () != getType ())
197
+ return emitError (" expected that `copy` and return type match" );
154
198
return success ();
155
199
}
156
200
201
+ void AllocTensorOp::build (OpBuilder &builder, OperationState &result,
202
+ RankedTensorType type, ValueRange dynamicSizes) {
203
+ build (builder, result, type, dynamicSizes, /* copy=*/ Value (),
204
+ /* escape=*/ BoolAttr ());
205
+ }
206
+
207
+ void AllocTensorOp::build (OpBuilder &builder, OperationState &result,
208
+ RankedTensorType type, ValueRange dynamicSizes,
209
+ Value copy) {
210
+ build (builder, result, type, dynamicSizes, copy, /* escape=*/ BoolAttr ());
211
+ }
212
+
213
+ void AllocTensorOp::build (OpBuilder &builder, OperationState &result,
214
+ RankedTensorType type, ValueRange dynamicSizes,
215
+ Value copy, bool escape) {
216
+ build (builder, result, type, dynamicSizes, copy, builder.getBoolAttr (escape));
217
+ }
218
+
157
219
namespace {
158
220
// / Change the type of the result of a `bufferization.alloc_tensor` by making
159
221
// / the result type statically sized along dimension that in the original
@@ -171,6 +233,8 @@ struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
171
233
172
234
LogicalResult matchAndRewrite (AllocTensorOp op,
173
235
PatternRewriter &rewriter) const override {
236
+ if (op.copy ())
237
+ return failure ();
174
238
SmallVector<int64_t > newShape = llvm::to_vector (op.getType ().getShape ());
175
239
SmallVector<Value> newDynamicSizes;
176
240
unsigned int dynValCounter = 0 ;
@@ -189,8 +253,9 @@ struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
189
253
newShape, op.getType ().getElementType (), op.getType ().getEncoding ());
190
254
if (newType == op.getType ())
191
255
return failure ();
192
- auto newOp =
193
- rewriter.create <AllocTensorOp>(op.getLoc (), newType, newDynamicSizes);
256
+ auto newOp = rewriter.create <AllocTensorOp>(
257
+ op.getLoc (), newType, newDynamicSizes, /* copy=*/ Value (),
258
+ /* escape=*/ op.escapeAttr ());
194
259
rewriter.replaceOpWithNewOp <tensor::CastOp>(op, op.getType (), newOp);
195
260
return success ();
196
261
}
@@ -207,8 +272,8 @@ struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
207
272
return failure ();
208
273
if (!allocTensorOp.getType ().isDynamicDim (*maybeConstantIndex))
209
274
return failure ();
210
- rewriter.replaceOp (dimOp,
211
- allocTensorOp.getDynamicSize (*maybeConstantIndex));
275
+ rewriter.replaceOp (
276
+ dimOp, allocTensorOp.getDynamicSize (rewriter, *maybeConstantIndex));
212
277
return success ();
213
278
}
214
279
};
@@ -224,14 +289,67 @@ LogicalResult AllocTensorOp::reifyResultShapes(
224
289
auto shapes = llvm::to_vector<4 >(llvm::map_range (
225
290
llvm::seq<int64_t >(0 , getType ().getRank ()), [&](int64_t dim) -> Value {
226
291
if (isDynamicDim (dim))
227
- return getDynamicSize (dim);
292
+ return getDynamicSize (builder, dim);
228
293
return builder.create <arith::ConstantIndexOp>(getLoc (),
229
294
getStaticSize (dim));
230
295
}));
231
296
reifiedReturnShapes.emplace_back (std::move (shapes));
232
297
return success ();
233
298
}
234
299
300
+ ParseResult AllocTensorOp::parse (OpAsmParser &parser, OperationState &result) {
301
+ SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
302
+ if (parser.parseLParen () || parser.parseOperandList (dynamicSizesOperands) ||
303
+ parser.parseRParen ())
304
+ return failure ();
305
+ ParseResult copyKeyword = parser.parseOptionalKeyword (" copy" );
306
+ OpAsmParser::UnresolvedOperand copyOperand;
307
+ if (copyKeyword.succeeded ())
308
+ if (parser.parseLParen () || parser.parseOperand (copyOperand) ||
309
+ parser.parseRParen ())
310
+ return failure ();
311
+ if (parser.parseOptionalAttrDict (result.attributes ) || parser.parseColon ())
312
+ return failure ();
313
+
314
+ TensorType type;
315
+ if (parser.parseCustomTypeWithFallback (type))
316
+ return failure ();
317
+ result.addTypes (type);
318
+
319
+ Type indexType = parser.getBuilder ().getIndexType ();
320
+ if (parser.resolveOperands (dynamicSizesOperands, indexType, result.operands ))
321
+ return failure ();
322
+ if (copyKeyword.succeeded ())
323
+ if (parser.resolveOperand (copyOperand, type, result.operands ))
324
+ return failure ();
325
+ result.addAttribute (AllocTensorOp::getOperandSegmentSizeAttr (),
326
+ parser.getBuilder ().getI32VectorAttr (
327
+ {static_cast <int32_t >(dynamicSizesOperands.size ()),
328
+ static_cast <int32_t >(copyKeyword.succeeded ())}));
329
+ return success ();
330
+ }
331
+
332
+ void AllocTensorOp::print (OpAsmPrinter &p) {
333
+ p << " (" << dynamicSizes () << " )" ;
334
+ if (copy ())
335
+ p << " copy(" << copy () << " )" ;
336
+ p.printOptionalAttrDict ((*this )->getAttrs (), /* elidedAttrs=*/ {
337
+ AllocTensorOp::getOperandSegmentSizeAttr ()});
338
+ p << " : " ;
339
+ auto type = result ().getType ();
340
+ if (auto validType = type.dyn_cast <::mlir::TensorType>())
341
+ p.printStrippedAttrOrType (validType);
342
+ else
343
+ p << type;
344
+ }
345
+
346
+ Value AllocTensorOp::getDynamicSize (OpBuilder &b, unsigned idx) {
347
+ assert (isDynamicDim (idx) && " expected dynamic dim" );
348
+ if (copy ())
349
+ return b.create <tensor::DimOp>(getLoc (), copy (), idx);
350
+ return getOperand (getIndexOfDynamicSize (idx));
351
+ }
352
+
235
353
// ===----------------------------------------------------------------------===//
236
354
// CloneOp
237
355
// ===----------------------------------------------------------------------===//
0 commit comments