13
13
#ifndef LINALG_IR_LINALGINTERFACES
14
14
#define LINALG_IR_LINALGINTERFACES
15
15
16
+ include "mlir/Interfaces/DestinationStyleOpInterface.td"
16
17
include "mlir/IR/OpBase.td"
17
18
18
19
// The 'LinalgContractionOpInterface' provides access to the
@@ -178,7 +179,8 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
178
179
}
179
180
180
181
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
181
- def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
182
+ def LinalgStructuredInterface
183
+ : OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
182
184
let cppNamespace = "::mlir::linalg";
183
185
let methods = [
184
186
//===------------------------------------------------------------------===//
@@ -321,13 +323,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
321
323
/*args=*/(ins),
322
324
/*methodBody=*/"",
323
325
/*defaultImplementation=*/[{
324
- // MLIR currently does not support dependent interfaces or interface
325
- // inheritance. By construction all ops with StructuredOpInterface must
326
- // implement DestinationStyleOpInterface.
327
- // TODO: reevaluate the need for a cast when a better mechanism exists.
328
- return getBlock()->getArguments().take_front(
329
- cast<DestinationStyleOpInterface>(*this->getOperation())
330
- .getNumDpsInputs());
326
+ return getBlock()->getArguments().take_front($_op.getNumDpsInputs());
331
327
}]
332
328
>,
333
329
InterfaceMethod<
@@ -339,13 +335,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
339
335
/*args=*/(ins),
340
336
/*methodBody=*/"",
341
337
/*defaultImplementation=*/[{
342
- // MLIR currently does not support dependent interfaces or interface
343
- // inheritance. By construction all ops with StructuredOpInterface must
344
- // implement DestinationStyleOpInterface.
345
- // TODO: reevaluate the need for a cast when a better mechanism exists.
346
- return getBlock()->getArguments().take_back(
347
- cast<DestinationStyleOpInterface>(*this->getOperation())
348
- .getNumDpsInits());
338
+ return getBlock()->getArguments().take_back($_op.getNumDpsInits());
349
339
}]
350
340
>,
351
341
InterfaceMethod<
@@ -418,13 +408,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
418
408
assert(result.getOwner() == this->getOperation());
419
409
auto indexingMaps =
420
410
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
421
- // MLIR currently does not support dependent interfaces or interface
422
- // inheritance. By construction all ops with StructuredOpInterface must
423
- // implement DestinationStyleOpInterface.
424
- // TODO: reevaluate the need for a cast when a better mechanism exists.
425
- return *(indexingMaps.begin() +
426
- cast<DestinationStyleOpInterface>(*this->getOperation())
427
- .getNumDpsInputs() +
411
+ return *(indexingMaps.begin() + $_op.getNumDpsInputs() +
428
412
result.getResultNumber());
429
413
}]
430
414
>,
@@ -439,14 +423,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
439
423
/*methodBody=*/"",
440
424
/*defaultImplementation=*/[{
441
425
assert(opOperand->getOwner() == this->getOperation());
442
- // MLIR currently does not support dependent interfaces or interface
443
- // inheritance. By construction all ops with StructuredOpInterface must
444
- // implement DestinationStyleOpInterface.
445
- // TODO: reevaluate the need for a cast when a better mechanism exists.
446
426
int64_t resultIndex =
447
- opOperand->getOperandNumber() -
448
- cast<DestinationStyleOpInterface>(*this->getOperation())
449
- .getNumDpsInputs();
427
+ opOperand->getOperandNumber() - $_op.getNumDpsInputs();
450
428
assert(resultIndex >= 0 &&
451
429
resultIndex < this->getOperation()->getNumResults());
452
430
Operation *yieldOp = getBlock()->getTerminator();
@@ -800,80 +778,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
800
778
801
779
/// Return the index in the indexingMaps vector that corresponds to this `opOperand`
802
780
int64_t getIndexingMapIndex(OpOperand *opOperand);
803
-
804
- //========================================================================//
805
- // Forwarding functions to access interface methods from the
806
- // DestinationStyleOpInterface.
807
- // MLIR currently does not support dependent interfaces or interface
808
- // inheritance. By construction all ops with StructuredOpInterface must
809
- // implement DestinationStyleOpInterface.
810
- // TODO: reevaluate the need for a cast when a better mechanism exists.
811
- //========================================================================//
812
-
813
- int64_t getNumDpsInputs() {
814
- return cast<DestinationStyleOpInterface>(*this->getOperation())
815
- .getNumDpsInputs();
816
- }
817
-
818
- int64_t getNumDpsInits() {
819
- return cast<DestinationStyleOpInterface>(*this->getOperation())
820
- .getNumDpsInits();
821
- }
822
-
823
- OpOperandVector getDpsInputOperands() {
824
- return cast<DestinationStyleOpInterface>(*this->getOperation())
825
- .getDpsInputOperands();
826
- }
827
-
828
- OpOperand *getDpsInputOperand(int64_t i) {
829
- return cast<DestinationStyleOpInterface>(*this->getOperation())
830
- .getDpsInputOperand(i);
831
- }
832
-
833
- void setDpsInitOperand(int64_t i, Value value) {
834
- return cast<DestinationStyleOpInterface>(*this->getOperation())
835
- .setDpsInitOperand(i, value);
836
- }
837
-
838
- OpOperandVector getDpsInitOperands() {
839
- return cast<DestinationStyleOpInterface>(*this->getOperation())
840
- .getDpsInitOperands();
841
- }
842
-
843
- OpOperand *getDpsInitOperand(int64_t i) {
844
- return cast<DestinationStyleOpInterface>(*this->getOperation())
845
- .getDpsInitOperand(i);
846
- }
847
-
848
- bool isDpsInput(OpOperand *opOperand) {
849
- return cast<DestinationStyleOpInterface>(*this->getOperation())
850
- .isDpsInput(opOperand);
851
- }
852
-
853
- bool isDpsInit(OpOperand *opOperand) {
854
- return cast<DestinationStyleOpInterface>(*this->getOperation())
855
- .isDpsInit(opOperand);
856
- }
857
-
858
- bool isScalar(OpOperand *opOperand) {
859
- return cast<DestinationStyleOpInterface>(*this->getOperation())
860
- .isScalar(opOperand);
861
- }
862
-
863
- OpResult getTiedOpResult(OpOperand *opOperand) {
864
- return cast<DestinationStyleOpInterface>(*this->getOperation())
865
- .getTiedOpResult(opOperand);
866
- }
867
-
868
- bool hasBufferSemantics() {
869
- return cast<DestinationStyleOpInterface>(*this->getOperation())
870
- .hasBufferSemantics();
871
- }
872
-
873
- bool hasTensorSemantics() {
874
- return cast<DestinationStyleOpInterface>(*this->getOperation())
875
- .hasTensorSemantics();
876
- }
877
781
}];
878
782
879
783
let verify = [{ return detail::verifyStructuredOpInterface($_op); }];
0 commit comments