Skip to content

Commit 315ddc5

Browse files
[mlir][linalg][NFC] Make LinalgOp inherit from DestinationStyleOpInterface (#66995)
Dependent interfaces have been added a while ago and these TODOs can be addressed now.
1 parent 13e9a56 commit 315ddc5

File tree

1 file changed

+7
-103
lines changed

1 file changed

+7
-103
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 7 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef LINALG_IR_LINALGINTERFACES
1414
#define LINALG_IR_LINALGINTERFACES
1515

16+
include "mlir/Interfaces/DestinationStyleOpInterface.td"
1617
include "mlir/IR/OpBase.td"
1718

1819
// The 'LinalgContractionOpInterface' provides access to the
@@ -178,7 +179,8 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
178179
}
179180

180181
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
181-
def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
182+
def LinalgStructuredInterface
183+
: OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
182184
let cppNamespace = "::mlir::linalg";
183185
let methods = [
184186
//===------------------------------------------------------------------===//
@@ -321,13 +323,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
321323
/*args=*/(ins),
322324
/*methodBody=*/"",
323325
/*defaultImplementation=*/[{
324-
// MLIR currently does not support dependent interfaces or interface
325-
// inheritance. By construction all ops with StructuredOpInterface must
326-
// implement DestinationStyleOpInterface.
327-
// TODO: reevaluate the need for a cast when a better mechanism exists.
328-
return getBlock()->getArguments().take_front(
329-
cast<DestinationStyleOpInterface>(*this->getOperation())
330-
.getNumDpsInputs());
326+
return getBlock()->getArguments().take_front($_op.getNumDpsInputs());
331327
}]
332328
>,
333329
InterfaceMethod<
@@ -339,13 +335,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
339335
/*args=*/(ins),
340336
/*methodBody=*/"",
341337
/*defaultImplementation=*/[{
342-
// MLIR currently does not support dependent interfaces or interface
343-
// inheritance. By construction all ops with StructuredOpInterface must
344-
// implement DestinationStyleOpInterface.
345-
// TODO: reevaluate the need for a cast when a better mechanism exists.
346-
return getBlock()->getArguments().take_back(
347-
cast<DestinationStyleOpInterface>(*this->getOperation())
348-
.getNumDpsInits());
338+
return getBlock()->getArguments().take_back($_op.getNumDpsInits());
349339
}]
350340
>,
351341
InterfaceMethod<
@@ -418,13 +408,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
418408
assert(result.getOwner() == this->getOperation());
419409
auto indexingMaps =
420410
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
421-
// MLIR currently does not support dependent interfaces or interface
422-
// inheritance. By construction all ops with StructuredOpInterface must
423-
// implement DestinationStyleOpInterface.
424-
// TODO: reevaluate the need for a cast when a better mechanism exists.
425-
return *(indexingMaps.begin() +
426-
cast<DestinationStyleOpInterface>(*this->getOperation())
427-
.getNumDpsInputs() +
411+
return *(indexingMaps.begin() + $_op.getNumDpsInputs() +
428412
result.getResultNumber());
429413
}]
430414
>,
@@ -439,14 +423,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
439423
/*methodBody=*/"",
440424
/*defaultImplementation=*/[{
441425
assert(opOperand->getOwner() == this->getOperation());
442-
// MLIR currently does not support dependent interfaces or interface
443-
// inheritance. By construction all ops with StructuredOpInterface must
444-
// implement DestinationStyleOpInterface.
445-
// TODO: reevaluate the need for a cast when a better mechanism exists.
446426
int64_t resultIndex =
447-
opOperand->getOperandNumber() -
448-
cast<DestinationStyleOpInterface>(*this->getOperation())
449-
.getNumDpsInputs();
427+
opOperand->getOperandNumber() - $_op.getNumDpsInputs();
450428
assert(resultIndex >= 0 &&
451429
resultIndex < this->getOperation()->getNumResults());
452430
Operation *yieldOp = getBlock()->getTerminator();
@@ -800,80 +778,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
800778

801779
/// Return the index in the indexingMaps vector that corresponds to this `opOperand`
802780
int64_t getIndexingMapIndex(OpOperand *opOperand);
803-
804-
//========================================================================//
805-
// Forwarding functions to access interface methods from the
806-
// DestinationStyleOpInterface.
807-
// MLIR currently does not support dependent interfaces or interface
808-
// inheritance. By construction all ops with StructuredOpInterface must
809-
// implement DestinationStyleOpInterface.
810-
// TODO: reevaluate the need for a cast when a better mechanism exists.
811-
//========================================================================//
812-
813-
int64_t getNumDpsInputs() {
814-
return cast<DestinationStyleOpInterface>(*this->getOperation())
815-
.getNumDpsInputs();
816-
}
817-
818-
int64_t getNumDpsInits() {
819-
return cast<DestinationStyleOpInterface>(*this->getOperation())
820-
.getNumDpsInits();
821-
}
822-
823-
OpOperandVector getDpsInputOperands() {
824-
return cast<DestinationStyleOpInterface>(*this->getOperation())
825-
.getDpsInputOperands();
826-
}
827-
828-
OpOperand *getDpsInputOperand(int64_t i) {
829-
return cast<DestinationStyleOpInterface>(*this->getOperation())
830-
.getDpsInputOperand(i);
831-
}
832-
833-
void setDpsInitOperand(int64_t i, Value value) {
834-
return cast<DestinationStyleOpInterface>(*this->getOperation())
835-
.setDpsInitOperand(i, value);
836-
}
837-
838-
OpOperandVector getDpsInitOperands() {
839-
return cast<DestinationStyleOpInterface>(*this->getOperation())
840-
.getDpsInitOperands();
841-
}
842-
843-
OpOperand *getDpsInitOperand(int64_t i) {
844-
return cast<DestinationStyleOpInterface>(*this->getOperation())
845-
.getDpsInitOperand(i);
846-
}
847-
848-
bool isDpsInput(OpOperand *opOperand) {
849-
return cast<DestinationStyleOpInterface>(*this->getOperation())
850-
.isDpsInput(opOperand);
851-
}
852-
853-
bool isDpsInit(OpOperand *opOperand) {
854-
return cast<DestinationStyleOpInterface>(*this->getOperation())
855-
.isDpsInit(opOperand);
856-
}
857-
858-
bool isScalar(OpOperand *opOperand) {
859-
return cast<DestinationStyleOpInterface>(*this->getOperation())
860-
.isScalar(opOperand);
861-
}
862-
863-
OpResult getTiedOpResult(OpOperand *opOperand) {
864-
return cast<DestinationStyleOpInterface>(*this->getOperation())
865-
.getTiedOpResult(opOperand);
866-
}
867-
868-
bool hasBufferSemantics() {
869-
return cast<DestinationStyleOpInterface>(*this->getOperation())
870-
.hasBufferSemantics();
871-
}
872-
873-
bool hasTensorSemantics() {
874-
return cast<DestinationStyleOpInterface>(*this->getOperation())
875-
.hasTensorSemantics();
876-
}
877781
}];
878782

879783
let verify = [{ return detail::verifyStructuredOpInterface($_op); }];

0 commit comments

Comments
 (0)