@@ -165,7 +165,8 @@ Operation *bufferization::getOwnerOfValue(Value value) {
165
165
// / allocated.
166
166
FailureOr<Value> bufferization::allocateTensorForShapedValue (
167
167
OpBuilder &b, Location loc, Value shapedValue,
168
- const BufferizationOptions &options, bool copy) {
168
+ const BufferizationOptions &options, const BufferizationState &state,
169
+ bool copy) {
169
170
Value tensor;
170
171
if (llvm::isa<RankedTensorType>(shapedValue.getType ())) {
171
172
tensor = shapedValue;
@@ -210,7 +211,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
210
211
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
211
212
if (copy)
212
213
return allocTensorOp.getResult ();
213
- FailureOr<BaseMemRefType> copyBufferType = getBufferType (tensor, options);
214
+ FailureOr<BaseMemRefType> copyBufferType =
215
+ getBufferType (tensor, options, state);
214
216
if (failed (copyBufferType))
215
217
return failure ();
216
218
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace ();
@@ -222,7 +224,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
222
224
}
223
225
224
226
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts (
225
- RewriterBase &rewriter, const AnalysisState &state) {
227
+ RewriterBase &rewriter, const AnalysisState &analysisState,
228
+ const BufferizationState &bufferizationState) {
226
229
OpBuilder::InsertionGuard g (rewriter);
227
230
Operation *op = getOperation ();
228
231
SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -235,16 +238,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
235
238
Type operandType = opOperand.get ().getType ();
236
239
if (!llvm::isa<TensorType>(operandType))
237
240
continue ;
238
- if (state .isInPlace (opOperand))
241
+ if (analysisState .isInPlace (opOperand))
239
242
continue ;
240
243
if (llvm::isa<UnrankedTensorType>(operandType))
241
244
return op->emitError (" copying of unranked tensors is not implemented" );
242
245
243
- AliasingValueList aliasingValues = state.getAliasingValues (opOperand);
246
+ AliasingValueList aliasingValues =
247
+ analysisState.getAliasingValues (opOperand);
244
248
if (aliasingValues.getNumAliases () == 1 &&
245
249
isa<OpResult>(aliasingValues.getAliases ()[0 ].value ) &&
246
- !state.bufferizesToMemoryWrite (opOperand) &&
247
- state.getAliasingOpOperands (aliasingValues.getAliases ()[0 ].value )
250
+ !analysisState.bufferizesToMemoryWrite (opOperand) &&
251
+ analysisState
252
+ .getAliasingOpOperands (aliasingValues.getAliases ()[0 ].value )
248
253
.getNumAliases () == 1 &&
249
254
!isa<UnrankedTensorType>(
250
255
aliasingValues.getAliases ()[0 ].value .getType ())) {
@@ -256,12 +261,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
256
261
// cannot be copied at the moment).
257
262
Value value = aliasingValues.getAliases ()[0 ].value ;
258
263
outOfPlaceValues.push_back (value);
259
- if (!state .canOmitTensorCopy (opOperand))
264
+ if (!analysisState .canOmitTensorCopy (opOperand))
260
265
copiedOpValues.insert (value);
261
266
} else {
262
267
// In all other cases, make a copy of the OpOperand.
263
268
outOfPlaceOpOperands.push_back (&opOperand);
264
- if (!state .canOmitTensorCopy (opOperand))
269
+ if (!analysisState .canOmitTensorCopy (opOperand))
265
270
copiedOpOperands.insert (&opOperand);
266
271
}
267
272
}
@@ -270,8 +275,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
270
275
rewriter.setInsertionPoint (op);
271
276
for (OpOperand *opOperand : outOfPlaceOpOperands) {
272
277
FailureOr<Value> copy = allocateTensorForShapedValue (
273
- rewriter, op->getLoc (), opOperand->get (), state .getOptions (),
274
- copiedOpOperands.contains (opOperand));
278
+ rewriter, op->getLoc (), opOperand->get (), analysisState .getOptions (),
279
+ bufferizationState, copiedOpOperands.contains (opOperand));
275
280
if (failed (copy))
276
281
return failure ();
277
282
rewriter.modifyOpInPlace (op, [&]() { opOperand->set (*copy); });
@@ -281,8 +286,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
281
286
rewriter.setInsertionPointAfter (op);
282
287
for (Value value : outOfPlaceValues) {
283
288
FailureOr<Value> copy = allocateTensorForShapedValue (
284
- rewriter, op->getLoc (), value, state .getOptions (),
285
- copiedOpValues.count (value));
289
+ rewriter, op->getLoc (), value, analysisState .getOptions (),
290
+ bufferizationState, copiedOpValues.count (value));
286
291
if (failed (copy))
287
292
return failure ();
288
293
SmallVector<OpOperand *> uses = llvm::to_vector (
@@ -665,7 +670,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
665
670
}
666
671
667
672
FailureOr<Value> bufferization::getBuffer (RewriterBase &rewriter, Value value,
668
- const BufferizationOptions &options) {
673
+ const BufferizationOptions &options,
674
+ const BufferizationState &state) {
669
675
#ifndef NDEBUG
670
676
auto tensorType = llvm::dyn_cast<TensorType>(value.getType ());
671
677
assert (tensorType && " unexpected non-tensor type" );
@@ -678,7 +684,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
678
684
// Insert to_buffer op.
679
685
OpBuilder::InsertionGuard g (rewriter);
680
686
setInsertionPointAfter (rewriter, value);
681
- FailureOr<BaseMemRefType> memrefType = getBufferType (value, options);
687
+ FailureOr<BaseMemRefType> memrefType = getBufferType (value, options, state );
682
688
if (failed (memrefType))
683
689
return failure ();
684
690
ensureToBufferOpIsValid (value, *memrefType);
@@ -689,14 +695,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
689
695
690
696
// / Return the buffer type for a given Value (tensor) after bufferization.
691
697
FailureOr<BaseMemRefType>
692
- bufferization::getBufferType (Value value, const BufferizationOptions &options) {
698
+ bufferization::getBufferType (Value value, const BufferizationOptions &options,
699
+ const BufferizationState &state) {
693
700
SmallVector<Value> invocationStack;
694
- return getBufferType (value, options, invocationStack);
701
+ return getBufferType (value, options, state, invocationStack);
695
702
}
696
703
697
704
// / Return the buffer type for a given Value (tensor) after bufferization.
698
705
FailureOr<BaseMemRefType>
699
706
bufferization::getBufferType (Value value, const BufferizationOptions &options,
707
+ const BufferizationState &state,
700
708
SmallVector<Value> &invocationStack) {
701
709
assert (llvm::isa<TensorType>(value.getType ()) &&
702
710
" unexpected non-tensor type" );
@@ -708,7 +716,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
708
716
Operation *op = getOwnerOfValue (value);
709
717
auto bufferizableOp = options.dynCastBufferizableOp (op);
710
718
if (bufferizableOp)
711
- return bufferizableOp.getBufferType (value, options, invocationStack);
719
+ return bufferizableOp.getBufferType (value, options, state, invocationStack);
712
720
713
721
// Op is not bufferizable.
714
722
auto memSpace =
@@ -944,6 +952,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
944
952
945
953
FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType (
946
954
Value value, const BufferizationOptions &options,
955
+ const BufferizationState &bufferizationState,
947
956
SmallVector<Value> &invocationStack) {
948
957
assert (llvm::isa<TensorType>(value.getType ()) && " expected tensor type" );
949
958
@@ -954,14 +963,15 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
954
963
// Value is an OpResult.
955
964
Operation *op = getOwnerOfValue (value);
956
965
auto opResult = llvm::cast<OpResult>(value);
957
- AnalysisState state (options);
958
- AliasingOpOperandList aliases = state .getAliasingOpOperands (opResult);
966
+ AnalysisState analysisState (options);
967
+ AliasingOpOperandList aliases = analysisState .getAliasingOpOperands (opResult);
959
968
if (aliases.getNumAliases () > 0 &&
960
969
aliases.getAliases ()[0 ].relation == BufferRelation::Equivalent) {
961
970
// If the OpResult has an equivalent OpOperand, both OpResult and
962
971
// OpOperand bufferize to the exact same buffer type.
963
972
Value equivalentOperand = aliases.getAliases ().front ().opOperand ->get ();
964
- return getBufferType (equivalentOperand, options, invocationStack);
973
+ return getBufferType (equivalentOperand, options, bufferizationState,
974
+ invocationStack);
965
975
}
966
976
967
977
// If we do not know the memory space and there is no default memory space,
0 commit comments