12
12
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
13
13
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
14
14
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
15
- include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
16
15
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
17
16
include "mlir/Interfaces/DestinationStyleOpInterface.td"
18
17
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -387,31 +386,20 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
387
386
// ToTensorOp
388
387
//===----------------------------------------------------------------------===//
389
388
390
- class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
391
- "specified tensor and buffer types match",
392
- CPred<
393
- "::mlir::bufferization::detail::typesMatchAfterBufferization("
394
- "$_op, $" # tensor # ", $" # buffer #")"
395
- >
396
- >;
397
-
398
389
def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
399
390
BufferizableOpInterface,
400
391
SameOperandsAndResultShape,
401
392
SameOperandsAndResultElementType,
402
- Bufferization_TensorAndBufferMatch<"result ", "buffer" >
393
+ AllElementTypesMatch<["memref ", "result"] >
403
394
]> {
404
- let summary = "create a buffer-like type from a tensor-like type ";
395
+ let summary = "create a tensor from a `memref` ";
405
396
let description = [{
406
- An operation that creates a tensor from a buffer. The result value is a
407
- tensor-like type that must match the corresponding buffer-like operand as
408
- per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType
409
- and BaseMemRefType), this means that shapes and element types match between
410
- the tensor and the buffer.
397
+ An operation that creates a tensor from a `memref`. The result value is a
398
+ tensor whose shape and element type match the memref operand.
411
399
412
400
The opposite of this op is `to_buffer`. Together, these two ops are
413
401
useful for source/target materializations when doing type conversions
414
- involving tensors and buffers .
402
+ involving tensors and memrefs .
415
403
416
404
Example:
417
405
@@ -453,16 +441,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
453
441
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
454
442
}];
455
443
456
- let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface ,
444
+ let arguments = (ins Arg<AnyRankedOrUnrankedMemRef ,
457
445
"the reference to load from",
458
- [MemReadAt<0, FullEffect>]>:$buffer ,
446
+ [MemReadAt<0, FullEffect>]>:$memref ,
459
447
UnitAttr:$restrict, UnitAttr:$writable);
460
- let results = (outs Bufferization_TensorLikeTypeInterface :$result);
448
+ let results = (outs AnyTensor :$result);
461
449
462
450
let extraClassDeclaration = [{
463
451
/// The result of a to_tensor is always a tensor.
464
- ::mlir::bufferization::TensorLikeType getType() {
465
- return getResult().getType();
452
+ TensorType getType() {
453
+ Type resultType = getResult().getType();
454
+ if (::llvm::isa<TensorType>(resultType))
455
+ return ::llvm::cast<TensorType>(resultType);
456
+ return {};
466
457
}
467
458
468
459
//===------------------------------------------------------------------===//
@@ -481,15 +472,22 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
481
472
FailureOr<BaseMemRefType> getBufferType(
482
473
Value value, const BufferizationOptions &options,
483
474
const BufferizationState &state, SmallVector<Value> &invocationStack) {
484
- return ::llvm::cast<BaseMemRefType>(getBuffer ().getType());
475
+ return ::llvm::cast<BaseMemRefType>(getMemref ().getType());
485
476
}
486
477
}];
487
478
488
479
let assemblyFormat = [{
489
- $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
490
- `:` type($buffer ) `to` type($result)
480
+ $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
481
+ `:` type($memref ) `to` type($result)
491
482
}];
492
483
484
+ let builders = [
485
+ OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
486
+ auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
487
+ build($_builder, $_state, rtt, memref, restrict, writeable);
488
+ }]>
489
+ ];
490
+
493
491
let hasCanonicalizer = 1;
494
492
let hasFolder = 1;
495
493
}
@@ -504,9 +502,10 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
504
502
SameOperandsAndResultShape,
505
503
SameOperandsAndResultElementType,
506
504
Pure,
507
- Bufferization_TensorAndBufferMatch<"tensor", "buffer">
505
+ AllShapesMatch<["memref", "tensor"]>,
506
+ AllElementTypesMatch<["memref", "tensor"]>
508
507
]> {
509
- let summary = "cast a tensor-like type to buffer-like type ";
508
+ let summary = "cast a tensor to memref ";
510
509
let description = [{
511
510
An operation that returns the future buffer of a `tensor`.
512
511
@@ -524,8 +523,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
524
523
the returned buffer) will not be written to.
525
524
}];
526
525
527
- let arguments = (ins Bufferization_TensorLikeTypeInterface :$tensor, UnitAttr:$read_only);
528
- let results = (outs Bufferization_BufferLikeTypeInterface:$buffer );
526
+ let arguments = (ins AnyTensor :$tensor, UnitAttr:$read_only);
527
+ let results = (outs AnyRankedOrUnrankedMemRef:$memref );
529
528
530
529
let extraClassDeclaration = [{
531
530
//===------------------------------------------------------------------===//
@@ -560,7 +559,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
560
559
}];
561
560
562
561
let assemblyFormat = [{
563
- $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer )
562
+ $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref )
564
563
}];
565
564
566
565
let hasFolder = 1;
0 commit comments