12
12
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
13
13
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
14
14
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
15
+ include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
15
16
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
16
17
include "mlir/Interfaces/DestinationStyleOpInterface.td"
17
18
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -386,20 +387,31 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
386
387
// ToTensorOp
387
388
//===----------------------------------------------------------------------===//
388
389
390
+ class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
391
+ "specified tensor and buffer types match",
392
+ CPred<
393
+ "::mlir::bufferization::detail::typesMatchAfterBufferization("
394
+ "$_op, $" # tensor # ", $" # buffer #")"
395
+ >
396
+ >;
397
+
389
398
def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
390
399
BufferizableOpInterface,
391
400
SameOperandsAndResultShape,
392
401
SameOperandsAndResultElementType,
393
- AllElementTypesMatch<["memref ", "result"] >
402
+ Bufferization_TensorAndBufferMatch<"result ", "buffer" >
394
403
]> {
395
- let summary = "create a tensor from a `memref` ";
404
+ let summary = "create a buffer-like type from a tensor-like type ";
396
405
let description = [{
397
- An operation that creates a tensor from a `memref`. The result value is a
398
- tensor whose shape and element type match the memref operand.
406
+ An operation that creates a tensor from a buffer. The result value is a
407
+ tensor-like type that must match the corresponding buffer-like operand as
408
+ per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType
409
+ and BaseMemRefType), this means that shapes and element types match between
410
+ the tensor and the buffer.
399
411
400
412
The opposite of this op is `to_buffer`. Together, these two ops are
401
413
useful for source/target materializations when doing type conversions
402
- involving tensors and memrefs .
414
+ involving tensors and buffers .
403
415
404
416
Example:
405
417
@@ -441,19 +453,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
441
453
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
442
454
}];
443
455
444
- let arguments = (ins Arg<AnyRankedOrUnrankedMemRef ,
456
+ let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface ,
445
457
"the reference to load from",
446
- [MemReadAt<0, FullEffect>]>:$memref ,
458
+ [MemReadAt<0, FullEffect>]>:$buffer ,
447
459
UnitAttr:$restrict, UnitAttr:$writable);
448
- let results = (outs AnyTensor :$result);
460
+ let results = (outs Bufferization_TensorLikeTypeInterface :$result);
449
461
450
462
let extraClassDeclaration = [{
451
463
/// The result of a to_tensor is always a tensor.
452
- TensorType getType() {
453
- Type resultType = getResult().getType();
454
- if (::llvm::isa<TensorType>(resultType))
455
- return ::llvm::cast<TensorType>(resultType);
456
- return {};
464
+ ::mlir::bufferization::TensorLikeType getType() {
465
+ return getResult().getType();
457
466
}
458
467
459
468
//===------------------------------------------------------------------===//
@@ -472,22 +481,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
472
481
FailureOr<BaseMemRefType> getBufferType(
473
482
Value value, const BufferizationOptions &options,
474
483
const BufferizationState &state, SmallVector<Value> &invocationStack) {
475
- return ::llvm::cast<BaseMemRefType>(getMemref ().getType());
484
+ return ::llvm::cast<BaseMemRefType>(getBuffer ().getType());
476
485
}
477
486
}];
478
487
479
488
let assemblyFormat = [{
480
- $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
481
- `:` type($memref ) `to` type($result)
489
+ $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
490
+ `:` type($buffer ) `to` type($result)
482
491
}];
483
492
484
- let builders = [
485
- OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
486
- auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
487
- build($_builder, $_state, rtt, memref, restrict, writeable);
488
- }]>
489
- ];
490
-
491
493
let hasCanonicalizer = 1;
492
494
let hasFolder = 1;
493
495
}
@@ -502,10 +504,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
502
504
SameOperandsAndResultShape,
503
505
SameOperandsAndResultElementType,
504
506
Pure,
505
- AllShapesMatch<["memref", "tensor"]>,
506
- AllElementTypesMatch<["memref", "tensor"]>
507
+ Bufferization_TensorAndBufferMatch<"tensor", "buffer">
507
508
]> {
508
- let summary = "cast a tensor to memref ";
509
+ let summary = "cast a tensor-like type to buffer-like type ";
509
510
let description = [{
510
511
An operation that returns the future buffer of a `tensor`.
511
512
@@ -523,8 +524,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
523
524
the returned buffer) will not be written to.
524
525
}];
525
526
526
- let arguments = (ins AnyTensor :$tensor, UnitAttr:$read_only);
527
- let results = (outs AnyRankedOrUnrankedMemRef:$memref );
527
+ let arguments = (ins Bufferization_TensorLikeTypeInterface :$tensor, UnitAttr:$read_only);
528
+ let results = (outs Bufferization_BufferLikeTypeInterface:$buffer );
528
529
529
530
let extraClassDeclaration = [{
530
531
//===------------------------------------------------------------------===//
@@ -559,7 +560,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
559
560
}];
560
561
561
562
let assemblyFormat = [{
562
- $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref )
563
+ $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer )
563
564
}];
564
565
565
566
let hasFolder = 1;
0 commit comments