Skip to content

Commit b8bc3cd

Browse files
committed
PR FEEDBACK: remove defaultMemorySpace
1 parent 6283d23 commit b8bc3cd

File tree

8 files changed

+31
-31
lines changed

8 files changed

+31
-31
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ struct BufferizationOptions {
258258
using UnknownTypeConverterFn = std::function<BaseMemRefType(
259259
Value, Attribute memorySpace, const BufferizationOptions &)>;
260260
// Produce a MemorySpace attribute from a tensor type
261-
using GetMemorySpaceFn =
261+
using DefaultMemorySpaceFn =
262262
std::function<std::optional<Attribute>(TensorType t)>;
263263

264264
BufferizationOptions();
@@ -299,11 +299,6 @@ struct BufferizationOptions {
299299
/// bufferized or not.
300300
bool bufferizeFunctionBoundaries = false;
301301

302-
/// The default memory space that should be used when it cannot be inferred
303-
/// from the context. If case of std::nullopt, bufferization fails when the
304-
/// memory space cannot be inferred at any point.
305-
std::optional<Attribute> defaultMemorySpace = Attribute();
306-
307302
/// Certain ops have aliasing OpOperand/OpResult invariants (e.g., scf.for).
308303
/// If this flag is set to `false`, those invariants are no longer enforced
309304
/// with buffer copies.
@@ -355,14 +350,9 @@ struct BufferizationOptions {
355350
UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
356351

357352
// Use during type conversion to determine the memory space for memref based
358-
// on the originanl tensor type
359-
GetMemorySpaceFn getMemorySpaceFn = nullptr;
360-
361-
std::optional<Attribute> getMemorySpace(TensorType t) const {
362-
if (getMemorySpaceFn)
363-
return getMemorySpaceFn(t);
364-
return defaultMemorySpace;
365-
}
353+
// on the originanl tensor type if the memory space cannot be inferred.
354+
DefaultMemorySpaceFn defaultMemorySpaceFn =
355+
[](TensorType t) -> std::optional<Attribute> { return Attribute(); };
366356

367357
/// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
368358
/// Should be used only with `testAnalysisOnly = true`.

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ struct ConstantOpInterface
3333
return failure();
3434

3535
Attribute memorySpace;
36-
if (options.getMemorySpace(type))
37-
memorySpace = *options.getMemorySpace(type);
36+
if (auto memSpace = options.defaultMemorySpaceFn(type))
37+
memorySpace = *memSpace;
3838
else
3939
return constantOp->emitError("could not infer memory space");
4040

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
682682
return bufferizableOp.getBufferType(value, options, invocationStack);
683683

684684
// Op is not bufferizable.
685-
auto memSpace = options.getMemorySpace(value.getType().cast<TensorType>());
685+
auto memSpace =
686+
options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
686687
if (!memSpace.has_value())
687688
return op->emitError("could not infer memory space");
688689

@@ -936,7 +937,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
936937

937938
// If we do not know the memory space and there is no default memory space,
938939
// report a failure.
939-
auto memSpace = options.getMemorySpace(value.getType().cast<TensorType>());
940+
auto memSpace =
941+
options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
940942
if (!memSpace.has_value())
941943
return op->emitError("could not infer memory space");
942944

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
234234
if (failed(copyBufferType))
235235
return failure();
236236
memorySpace = copyBufferType->getMemorySpace();
237-
} else if (auto x = options.getMemorySpace(getType()); x.has_value()) {
237+
} else if (auto x = options.defaultMemorySpaceFn(getType()); x.has_value()) {
238238
memorySpace = *x;
239239
} else {
240240
return getOperation()->emitError("could not infer memory space");

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,12 @@ struct OneShotBufferizePass
210210
opt.dumpAliasSets = dumpAliasSets;
211211
opt.setFunctionBoundaryTypeConversion(
212212
parseLayoutMapOption(functionBoundaryTypeConversion));
213-
if (mustInferMemorySpace)
214-
opt.defaultMemorySpace = std::nullopt;
213+
if (mustInferMemorySpace) {
214+
opt.defaultMemorySpaceFn =
215+
[](TensorType t) -> std::optional<Attribute> {
216+
return std::nullopt;
217+
};
218+
}
215219
opt.printConflicts = printConflicts;
216220
opt.testAnalysisOnly = testAnalysisOnly;
217221
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
6666
assert(tensorType && "expected TensorType");
6767

6868
BaseMemRefType memrefType = options.functionArgTypeConverterFn(
69-
tensorType, *options.getMemorySpace(tensorType), funcOp, options);
69+
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
7070

7171
auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
7272
index, BufferizationDialect::kBufferLayoutAttrName);
@@ -443,7 +443,8 @@ struct FuncOpInterface
443443
// Note: If `inferFunctionResultLayout = true`, cast are later folded
444444
// away.
445445
BaseMemRefType resultType = options.functionArgTypeConverterFn(
446-
tensorType, *options.getMemorySpace(tensorType), funcOp, options);
446+
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
447+
options);
447448
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
448449
loc, resultType, returnVal);
449450
returnValues.push_back(toMemrefOp);

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ struct FromElementsOpInterface
476476
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
477477

478478
// TODO: Implement memory space for this op.
479-
if (options.getMemorySpace(tensorType) != Attribute())
479+
if (options.defaultMemorySpaceFn(tensorType) != Attribute())
480480
return op->emitError("memory space not implemented yet");
481481

482482
// Allocate a buffer for the result.
@@ -591,7 +591,7 @@ struct GenerateOpInterface
591591
auto type = generateOp.getResult().getType();
592592

593593
// TODO: Implement memory space for this op.
594-
if (options.getMemorySpace(type) != Attribute())
594+
if (options.defaultMemorySpaceFn(type) != Attribute())
595595
return op->emitError("memory space not implemented yet");
596596

597597
// Allocate memory.
@@ -1009,10 +1009,6 @@ struct SplatOpInterface
10091009
OpBuilder::InsertionGuard g(rewriter);
10101010
auto splatOp = cast<tensor::SplatOp>(op);
10111011

1012-
// TODO: Implement memory space for this op.
1013-
if (options.defaultMemorySpace != Attribute())
1014-
return op->emitError("memory space not implemented yet");
1015-
10161012
// Allocate memory.
10171013
Location loc = op->getLoc();
10181014
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
@@ -1023,6 +1019,11 @@ struct SplatOpInterface
10231019

10241020
// Create linalg::MapOp.
10251021
auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1022+
1023+
// TODO: Implement memory space for this op.
1024+
if (options.defaultMemorySpaceFn(tensorType) != Attribute())
1025+
return op->emitError("memory space not implemented yet");
1026+
10261027
auto linalgOp =
10271028
rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
10281029
/*init=*/*tensorAlloc);

mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ struct TestTensorCopyInsertionPass
4444
bufferization::OneShotBufferizationOptions options;
4545
options.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
4646
options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
47-
if (mustInferMemorySpace)
48-
options.defaultMemorySpace = std::nullopt;
47+
if (mustInferMemorySpace) {
48+
options.defaultMemorySpaceFn =
49+
[](TensorType t) -> std::optional<Attribute> { return std::nullopt; };
50+
}
4951
if (failed(bufferization::insertTensorCopies(getOperation(), options)))
5052
signalPassFailure();
5153
}

0 commit comments

Comments
 (0)