Skip to content

Commit 0d09249

Browse files
bjacoblialan
authored andcommitted
Revert "[mlir][bufferization] Support custom types (1/N) (llvm#142986)"
This reverts commit ee070d0.
1 parent d5a79ad commit 0d09249

27 files changed

+135
-388
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include <optional>
1818

1919
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
20-
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
2120

2221
namespace mlir {
2322
class OpBuilder;
@@ -616,7 +615,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
616615
/// IR, this function can be used.
617616
///
618617
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
619-
FailureOr<BufferLikeType> getBufferType(Value value,
618+
FailureOr<BaseMemRefType> getBufferType(Value value,
620619
const BufferizationOptions &options,
621620
const BufferizationState &state);
622621

@@ -630,7 +629,7 @@ FailureOr<BufferLikeType> getBufferType(Value value,
630629
/// IR, this function can be used.
631630
///
632631
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
633-
FailureOr<BufferLikeType> getBufferType(Value value,
632+
FailureOr<BaseMemRefType> getBufferType(Value value,
634633
const BufferizationOptions &options,
635634
const BufferizationState &state,
636635
SmallVector<Value> &invocationStack);
@@ -740,19 +739,6 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
740739
/// This is the default implementation of
741740
/// BufferizableOpInterface::hasTensorSemantics
742741
bool defaultHasTensorSemantics(Operation *op);
743-
744-
/// This is a helper function used when buffer type is guaranteed to be memref.
745-
/// It performs two actions: failure state checking and an explicit llvm::cast<>
746-
/// from the buffer-like type interface to a BaseMemRefType. This allows easier
747-
/// management of differences in C++ types at the API boundaries. Valid buffer
748-
/// type is casted to the memref type. Otherwise, the failure state is
749-
/// propagated i.e. asMemRefType(mlir::failure()) returns mlir::failure().
750-
FailureOr<BaseMemRefType> asMemRefType(FailureOr<BufferLikeType> bufferType);
751-
752-
/// This function is a free-standing helper that relies on
753-
/// bufferization::TensorLikeTypeInterface to verify the types in tensor and
754-
/// buffer worlds match.
755-
bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
756742
} // namespace detail
757743

758744
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
1313
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
1414
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
15-
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
1615
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
1716
include "mlir/Interfaces/DestinationStyleOpInterface.td"
1817
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -387,31 +386,20 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
387386
// ToTensorOp
388387
//===----------------------------------------------------------------------===//
389388

390-
class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
391-
"specified tensor and buffer types match",
392-
CPred<
393-
"::mlir::bufferization::detail::typesMatchAfterBufferization("
394-
"$_op, $" # tensor # ", $" # buffer #")"
395-
>
396-
>;
397-
398389
def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
399390
BufferizableOpInterface,
400391
SameOperandsAndResultShape,
401392
SameOperandsAndResultElementType,
402-
Bufferization_TensorAndBufferMatch<"result", "buffer">
393+
AllElementTypesMatch<["memref", "result"]>
403394
]> {
404-
let summary = "create a buffer-like type from a tensor-like type";
395+
let summary = "create a tensor from a `memref`";
405396
let description = [{
406-
An operation that creates a tensor from a buffer. The result value is a
407-
tensor-like type that must match the corresponding buffer-like operand as
408-
per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType
409-
and BaseMemRefType), this means that shapes and element types match between
410-
the tensor and the buffer.
397+
An operation that creates a tensor from a `memref`. The result value is a
398+
tensor whose shape and element type match the memref operand.
411399

412400
The opposite of this op is `to_buffer`. Together, these two ops are
413401
useful for source/target materializations when doing type conversions
414-
involving tensors and buffers.
402+
involving tensors and memrefs.
415403

416404
Example:
417405

@@ -453,16 +441,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
453441
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
454442
}];
455443

456-
let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
444+
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
457445
"the reference to load from",
458-
[MemReadAt<0, FullEffect>]>:$buffer,
446+
[MemReadAt<0, FullEffect>]>:$memref,
459447
UnitAttr:$restrict, UnitAttr:$writable);
460-
let results = (outs Bufferization_TensorLikeTypeInterface:$result);
448+
let results = (outs AnyTensor:$result);
461449

462450
let extraClassDeclaration = [{
463451
/// The result of a to_tensor is always a tensor.
464-
::mlir::bufferization::TensorLikeType getType() {
465-
return getResult().getType();
452+
TensorType getType() {
453+
Type resultType = getResult().getType();
454+
if (::llvm::isa<TensorType>(resultType))
455+
return ::llvm::cast<TensorType>(resultType);
456+
return {};
466457
}
467458

468459
//===------------------------------------------------------------------===//
@@ -481,15 +472,22 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
481472
FailureOr<BaseMemRefType> getBufferType(
482473
Value value, const BufferizationOptions &options,
483474
const BufferizationState &state, SmallVector<Value> &invocationStack) {
484-
return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
475+
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
485476
}
486477
}];
487478

488479
let assemblyFormat = [{
489-
$buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
490-
`:` type($buffer) `to` type($result)
480+
$memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
481+
`:` type($memref) `to` type($result)
491482
}];
492483

484+
let builders = [
485+
OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
486+
auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
487+
build($_builder, $_state, rtt, memref, restrict, writeable);
488+
}]>
489+
];
490+
493491
let hasCanonicalizer = 1;
494492
let hasFolder = 1;
495493
}
@@ -504,9 +502,10 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
504502
SameOperandsAndResultShape,
505503
SameOperandsAndResultElementType,
506504
Pure,
507-
Bufferization_TensorAndBufferMatch<"tensor", "buffer">
505+
AllShapesMatch<["memref", "tensor"]>,
506+
AllElementTypesMatch<["memref", "tensor"]>
508507
]> {
509-
let summary = "cast a tensor-like type to buffer-like type";
508+
let summary = "cast a tensor to memref";
510509
let description = [{
511510
An operation that returns the future buffer of a `tensor`.
512511

@@ -524,8 +523,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
524523
the returned buffer) will not be written to.
525524
}];
526525

527-
let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
528-
let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
526+
let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
527+
let results = (outs AnyRankedOrUnrankedMemRef:$memref);
529528

530529
let extraClassDeclaration = [{
531530
//===------------------------------------------------------------------===//
@@ -560,7 +559,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
560559
}];
561560

562561
let assemblyFormat = [{
563-
$tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
562+
$tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
564563
}];
565564

566565
let hasFolder = 1;

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,8 @@
1313
// Bufferization Type Interfaces
1414
//===----------------------------------------------------------------------===//
1515

16-
#include "mlir/IR/Diagnostics.h"
1716
#include "mlir/IR/Types.h"
1817

19-
namespace mlir::bufferization {
20-
struct BufferizationOptions;
21-
class BufferizationState;
22-
class BufferLikeType;
23-
} // namespace mlir::bufferization
24-
2518
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
2619

2720
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,30 +21,10 @@ def Bufferization_TensorLikeTypeInterface
2121
let description = [{
2222
Indicates that this type is a tensor type (similarly to a MLIR builtin
2323
tensor) for bufferization purposes.
24-
}];
2524

26-
let methods = [
27-
InterfaceMethod<[{
28-
Returns a BufferLike type for this TensorLike type.
29-
}],
30-
/*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
31-
/*methodName=*/"getBufferType",
32-
/*args=*/(ins
33-
"const ::mlir::bufferization::BufferizationOptions &":$options,
34-
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
35-
)
36-
>,
37-
InterfaceMethod<[{
38-
Returns whether a BufferLike type is compatible to this TensorLike type.
39-
The BufferLike type is assumed to be created by getBufferType().
40-
}],
41-
/*retTy=*/"::mlir::LogicalResult",
42-
/*methodName=*/"verifyCompatibleBufferType",
43-
/*args=*/(ins
44-
"::mlir::bufferization::BufferLikeType":$bufferType,
45-
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
46-
>
47-
];
25+
The interface currently has no methods as it is used by types to opt into
26+
being supported by the bufferization procedures.
27+
}];
4828
}
4929

5030
def Bufferization_BufferLikeTypeInterface

mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,12 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
6565
// The operand was already bufferized. Take its type directly.
6666
callerType = memrefType;
6767
} else {
68-
FailureOr<BufferLikeType> maybeCallerType =
68+
FailureOr<BaseMemRefType> maybeCallerType =
6969
bufferization::getBufferType(opOperand->get(), options, state,
7070
invocationStack);
7171
if (failed(maybeCallerType))
7272
return failure();
73-
assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
74-
callerType = cast<BaseMemRefType>(*maybeCallerType);
73+
callerType = *maybeCallerType;
7574
}
7675

7776
if (!bufferType) {

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ struct SelectOpInterface
164164
// buffers have different types, they differ only in their layout map. Cast
165165
// both of them to the most dynamic MemRef type.
166166
if (trueBuffer.getType() != falseBuffer.getType()) {
167-
auto targetType = bufferization::detail::asMemRefType(
168-
bufferization::getBufferType(selectOp.getResult(), options, state));
167+
auto targetType =
168+
bufferization::getBufferType(selectOp.getResult(), options, state);
169169
if (failed(targetType))
170170
return failure();
171171
if (trueBuffer.getType() != *targetType)
@@ -187,12 +187,10 @@ struct SelectOpInterface
187187
SmallVector<Value> &invocationStack) const {
188188
auto selectOp = cast<arith::SelectOp>(op);
189189
assert(value == selectOp.getResult() && "invalid value");
190-
auto trueType =
191-
bufferization::detail::asMemRefType(bufferization::getBufferType(
192-
selectOp.getTrueValue(), options, state, invocationStack));
193-
auto falseType =
194-
bufferization::detail::asMemRefType(bufferization::getBufferType(
195-
selectOp.getFalseValue(), options, state, invocationStack));
190+
auto trueType = bufferization::getBufferType(
191+
selectOp.getTrueValue(), options, state, invocationStack);
192+
auto falseType = bufferization::getBufferType(
193+
selectOp.getFalseValue(), options, state, invocationStack);
196194
if (failed(trueType) || failed(falseType))
197195
return failure();
198196
if (*trueType == *falseType)

0 commit comments

Comments
 (0)