Skip to content

[mlir][Interface] Factor out common IndexingMapOpInterface behavior in a new generic interface #145313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/RawOstreamExtras.h"
Expand Down
180 changes: 29 additions & 151 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define LINALG_IR_LINALGINTERFACES

include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/IndexingMapOpInterface.td"
include "mlir/IR/OpBase.td"

// The 'LinalgContractionOpInterface' provides access to the
Expand Down Expand Up @@ -222,59 +223,11 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
];
}

def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
let description = [{
Interface for operations that connect an iteration domain to operands via
affine maps. Provides methods to access indexing maps between iteration
domain and operand index spaces.
}];
let cppNamespace = "::mlir::linalg";
let methods = [
InterfaceMethod<
/*desc=*/[{
Return the indexing maps attribute within the current operation.
}],
/*retTy=*/"ArrayAttr",
/*methodName=*/"getIndexingMaps"
>,
InterfaceMethod<
/*desc=*/[{
Return the indexing maps within the current operation.
}],
/*retTy=*/"SmallVector<AffineMap>",
/*methodName=*/"getIndexingMapsArray",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto range = $_op.getIndexingMaps()
.template getAsValueRange<AffineMapAttr>();
return {range.begin(), range.end()};
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the input or output indexing map for `opOperand`.
}],
/*retTy=*/"AffineMap",
/*methodName=*/"getMatchingIndexingMap",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
auto indexingMaps =
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
return *(indexingMaps.begin() + opOperand->getOperandNumber());
}]
>,
];
}

// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
def LinalgStructuredInterface
: OpInterface<"LinalgOp", [
DestinationStyleOpInterface,
IndexingMapOpInterface
]> {
: OpInterface<"LinalgOp",
[DestinationStyleOpInterface, IndexingMapOpInterface]
> {
let cppNamespace = "::mlir::linalg";
let methods = [
//===------------------------------------------------------------------===//
Expand Down Expand Up @@ -464,30 +417,6 @@ def LinalgStructuredInterface
return getBlock()->getArguments().take_back($_op.getNumDpsInits());
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the `opOperand` shape or an empty vector for scalars or vectors
not wrapped within a tensor or a memref.
}],
/*retTy=*/"ArrayRef<int64_t>",
/*methodName=*/"getShape",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
Type t = opOperand->get().getType();
// A VectorType is an elemental type, do not consider its rank for the operand.
if (isa<VectorType>(t))
return {};
if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
// Failsafe.
assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
"expected a ranked tensor or memref in LinalgInterface::getRank");
return shapedType.getShape();
}
return {};
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the block argument for an `opOperand`.
Expand Down Expand Up @@ -620,7 +549,12 @@ def LinalgStructuredInterface
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::any_of(getStaticShape(), ShapedType::isDynamic);
for (OpOperand &opOperand : this->getOperation()->getOpOperands()) {
if (auto shapedType = dyn_cast<ShapedType>(opOperand.get().getType())) {
if (ShapedType::isDynamicShape(shapedType.getShape())) return true;
}
}
return false;
}]
>,
InterfaceMethod<
Expand Down Expand Up @@ -738,53 +672,6 @@ def LinalgStructuredInterface
//===------------------------------------------------------------------===//
// Linalg generalization hooks.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
Hook to provide a custom AffineMap used to compute all the operand
subshapes given loop bounds. This is used to answer the question: "given
an iteration space over the codomain, what are the subshapes of the
operands involved in the computation".
The default behavior is to just concatenate all the indexing maps.
A custom AffineMap allows providing a map that can be used to
compute subshapes even in cases where the concatenation of indexing maps
(i.e. the data traversal order) is not a simple permutation of the loop
traversal order. It is then possible to define ops with skewed data
traversal order for which we can still easily compute hyperrectangular
loop bounds and subviews.
}],
/*retTy=*/"AffineMap",
/*methodName=*/"getLoopsToShapesMap",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto maps = $_op.getIndexingMapsArray();
return concatAffineMaps(maps, $_op.getContext());
}]
>,
InterfaceMethod<
/*desc=*/[{
Hook to provide a custom AffineMap used to construct the
hyperrectangular loop iteration space given all the operand subshapes.
This is used to answer the question:
"Given a list of operand ranges, what is the subportion of the iteration
space involved in the computation".
This is the inverse problem of `getLoopsToShapesMap`.
Return the empty AffineMap when such an AffineMap cannot be constructed.
The default behavior is based on a very simple inference procedure that
only works with permutation affine maps.
A more advanced Tensor-Comprehension like inference is possible but has
proven to be ambiguous in unfavorable case.
A safer and more robust alternative is to allow each op to define
its own AffineMap.
}],
/*retTy=*/"AffineMap",
/*methodName=*/"getShapesToLoopsMap",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return inversePermutation(getLoopsToShapesMap());
}]
>,
InterfaceMethod<
/*desc=*/[{
Checks if the given operands can be dropped, and the remaining
Expand All @@ -798,39 +685,30 @@ def LinalgStructuredInterface
return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
}]
>,
//===------------------------------------------------------------------===//
// IndexingMapOpInterface interface methods implementation.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
Like `getShape`, but only returns statically-known information, without
generating any new IR. For each shape dimension, returns >=0 if that
dimension is statically known, or ShapedType::kDynamic otherwise.
}],
/*retTy=*/"SmallVector<int64_t>",
/*methodName=*/"getStaticShape",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<int64_t> res;
for (OpOperand &opOperand : this->getOperation()->getOpOperands())
llvm::append_range(res, getShape(&opOperand));
return res;
}]
>,
InterfaceMethod<
/*desc=*/[{
Returns the statically-known loop ranges. Composes
`getShapesToLoopsMap()` with the result of `getStaticShape`.
Returns ShapedType::kDynamic for non-statically-known loop ranges.
This is expected to be called by a valid Linalg op
Return the `opOperand` shape or an empty vector for scalars or vectors
not wrapped within a tensor or a memref.
}],
/*retTy=*/"SmallVector<int64_t, 4>",
/*methodName=*/"getStaticLoopRanges",
/*args=*/(ins),
/*retTy=*/"ArrayRef<int64_t>",
/*methodName=*/"getShape",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<int64_t> viewSizes = getStaticShape();
AffineMap invertedMap = getShapesToLoopsMap();
assert(invertedMap && "expected a valid Linalg op to call the method");
return invertedMap.compose(viewSizes);
Type t = opOperand->get().getType();
// A VectorType is an elemental type, do not consider its rank for the operand.
if (isa<VectorType>(t))
return {};
if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
// Failsafe.
assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
"expected a ranked tensor or memref in LinalgInterface::getRank");
return shapedType.getShape();
}
return {};
}]
>,
//===------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/IndexingMapOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand All @@ -33,6 +34,7 @@ include "mlir/IR/EnumAttr.td"
// than the current set: {*, +}.
def Vector_ContractionOp :
Vector_Op<"contract", [
IndexingMapOpInterface,
Pure,
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
PredOpTrait<"third operand acc and result have same element type",
Expand Down Expand Up @@ -207,6 +209,16 @@ def Vector_ContractionOp :
.template getAsValueRange<IteratorTypeAttr, IteratorType>();
return {range.begin(), range.end()};
}

//===------------------------------------------------------------------===//
// IndexingMapOpInterface interface methods implementation.
//===------------------------------------------------------------------===//
ArrayRef<int64_t> getShape(OpOperand * opOperand) {
Type t = opOperand->get().getType();
if (auto vt = dyn_cast<VectorType>(t))
return vt.getShape();
return {};
}
}];

let hasCanonicalizer = 1;
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_mlir_interface(CopyOpInterface)
add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
Expand Down
27 changes: 27 additions & 0 deletions mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- IndexingMapOpInterface.h ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"

namespace mlir {
namespace detail {
/// Verify that `op` conforms to the invariants of StructuredOpInterface
LogicalResult verifyIndexingMapOpInterface(Operation *op);
} // namespace detail
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Interfaces/IndexingMapOpInterface.h.inc"

#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
Loading
Loading