Skip to content

[mlir][linalg][NFC] Make LinalgOp inherit from DestinationStyleOpInterface #66995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 7 additions & 103 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef LINALG_IR_LINALGINTERFACES
#define LINALG_IR_LINALGINTERFACES

include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/IR/OpBase.td"

// The 'LinalgContractionOpInterface' provides access to the
Expand Down Expand Up @@ -178,7 +179,8 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
}

// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
def LinalgStructuredInterface
: OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
let cppNamespace = "::mlir::linalg";
let methods = [
//===------------------------------------------------------------------===//
Expand Down Expand Up @@ -321,13 +323,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// MLIR currently does not support dependent interfaces or interface
// inheritance. By construction all ops with StructuredOpInterface must
// implement DestinationStyleOpInterface.
// TODO: reevaluate the need for a cast when a better mechanism exists.
return getBlock()->getArguments().take_front(
cast<DestinationStyleOpInterface>(*this->getOperation())
.getNumDpsInputs());
return getBlock()->getArguments().take_front($_op.getNumDpsInputs());
}]
>,
InterfaceMethod<
Expand All @@ -339,13 +335,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// MLIR currently does not support dependent interfaces or interface
// inheritance. By construction all ops with StructuredOpInterface must
// implement DestinationStyleOpInterface.
// TODO: reevaluate the need for a cast when a better mechanism exists.
return getBlock()->getArguments().take_back(
cast<DestinationStyleOpInterface>(*this->getOperation())
.getNumDpsInits());
return getBlock()->getArguments().take_back($_op.getNumDpsInits());
}]
>,
InterfaceMethod<
Expand Down Expand Up @@ -418,13 +408,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
assert(result.getOwner() == this->getOperation());
auto indexingMaps =
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
// MLIR currently does not support dependent interfaces or interface
// inheritance. By construction all ops with StructuredOpInterface must
// implement DestinationStyleOpInterface.
// TODO: reevaluate the need for a cast when a better mechanism exists.
return *(indexingMaps.begin() +
cast<DestinationStyleOpInterface>(*this->getOperation())
.getNumDpsInputs() +
return *(indexingMaps.begin() + $_op.getNumDpsInputs() +
result.getResultNumber());
}]
>,
Expand All @@ -439,14 +423,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
// MLIR currently does not support dependent interfaces or interface
// inheritance. By construction all ops with StructuredOpInterface must
// implement DestinationStyleOpInterface.
// TODO: reevaluate the need for a cast when a better mechanism exists.
int64_t resultIndex =
opOperand->getOperandNumber() -
cast<DestinationStyleOpInterface>(*this->getOperation())
.getNumDpsInputs();
opOperand->getOperandNumber() - $_op.getNumDpsInputs();
assert(resultIndex >= 0 &&
resultIndex < this->getOperation()->getNumResults());
Operation *yieldOp = getBlock()->getTerminator();
Expand Down Expand Up @@ -800,80 +778,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {

/// Return the index in the indexingMaps vector that corresponds to this `opOperand`
int64_t getIndexingMapIndex(OpOperand *opOperand);

//========================================================================//
// Forwarding functions to access interface methods from the
// DestinationStyleOpInterface.
// MLIR currently does not support dependent interfaces or interface
// inheritance. By construction all ops with StructuredOpInterface must
// implement DestinationStyleOpInterface.
// TODO: reevaluate the need for a cast when a better mechanism exists.
//========================================================================//

int64_t getNumDpsInputs() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.getNumDpsInputs();
}

int64_t getNumDpsInits() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.getNumDpsInits();
}

OpOperandVector getDpsInputOperands() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.getDpsInputOperands();
}

OpOperand *getDpsInputOperand(int64_t i) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.getDpsInputOperand(i);
}

void setDpsInitOperand(int64_t i, Value value) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.setDpsInitOperand(i, value);
}

OpOperandVector getDpsInitOperands() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.getDpsInitOperands();
}

OpOperand *getDpsInitOperand(int64_t i) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.getDpsInitOperand(i);
}

bool isDpsInput(OpOperand *opOperand) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.isDpsInput(opOperand);
}

bool isDpsInit(OpOperand *opOperand) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.isDpsInit(opOperand);
}

bool isScalar(OpOperand *opOperand) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.isScalar(opOperand);
}

OpResult getTiedOpResult(OpOperand *opOperand) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.getTiedOpResult(opOperand);
}

bool hasBufferSemantics() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.hasBufferSemantics();
}

bool hasTensorSemantics() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.hasTensorSemantics();
}
}];

let verify = [{ return detail::verifyStructuredOpInterface($_op); }];
Expand Down