Skip to content

Commit c5624dc

Browse files
[mlir][Interfaces] ValueBoundsOpInterface: Handle all destination style ops (llvm#65736)
This commit provides a default implementation for all ops that implement the `DestinationStyleOpInterface`. Result values of such ops are tied to operand, and those have the same type.
1 parent ea98e1c commit c5624dc

File tree

6 files changed

+34
-45
lines changed

6 files changed

+34
-45
lines changed

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -270,24 +270,4 @@ class ValueBoundsConstraintSet {
270270

271271
#include "mlir/Interfaces/ValueBoundsOpInterface.h.inc"
272272

273-
namespace mlir {
274-
275-
/// Default implementation for destination style ops: Tied OpResults and
276-
/// OpOperands have the same type.
277-
template <typename ConcreteOp>
278-
struct DstValueBoundsOpInterfaceExternalModel
279-
: public ValueBoundsOpInterface::ExternalModel<
280-
DstValueBoundsOpInterfaceExternalModel<ConcreteOp>, ConcreteOp> {
281-
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
282-
ValueBoundsConstraintSet &cstr) const {
283-
auto dstOp = cast<DestinationStyleOpInterface>(op);
284-
assert(value.getDefiningOp() == dstOp);
285-
286-
Value tiedOperand = dstOp.getTiedOpOperand(cast<OpResult>(value))->get();
287-
cstr.bound(value)[dim] == cstr.getExpr(tiedOperand, dim);
288-
}
289-
};
290-
291-
} // namespace mlir
292-
293273
#endif // MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_

mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,6 @@ struct IndexOpInterface
4747
}
4848
};
4949

50-
/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
51-
/// the `ValueBoundsOpInterface` with each of them.
52-
template <typename... Ops> struct LinalgValueBoundsOpInterfaceHelper {
53-
static void registerOpInterface(MLIRContext *ctx) {
54-
(Ops::template attachInterface<DstValueBoundsOpInterfaceExternalModel<Ops>>(
55-
*ctx),
56-
...);
57-
}
58-
};
59-
6050
} // namespace
6151
} // namespace linalg
6252
} // namespace mlir
@@ -65,11 +55,7 @@ void mlir::linalg::registerValueBoundsOpInterfaceExternalModels(
6555
DialectRegistry &registry) {
6656
registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
6757
IndexOp::attachInterface<IndexOpInterface>(*ctx);
68-
69-
// Register all Linalg structured ops.
70-
LinalgValueBoundsOpInterfaceHelper<
71-
#define GET_OP_LIST
72-
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
73-
>::registerOpInterface(ctx);
58+
// Note: ValueBoundsOpInterface implementation is not required for ops that
59+
// implement `DestinationStyleOpInterface` (for querying shaped OpResults).
7460
});
7561
}

mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,9 @@ void mlir::tensor::registerValueBoundsOpInterfaceExternalModels(
120120
tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
121121
tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
122122
*ctx);
123-
tensor::InsertOp::attachInterface<
124-
DstValueBoundsOpInterfaceExternalModel<tensor::InsertOp>>(*ctx);
125-
tensor::InsertSliceOp::attachInterface<
126-
DstValueBoundsOpInterfaceExternalModel<tensor::InsertSliceOp>>(*ctx);
127123
tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
128124
tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
125+
// Note: ValueBoundsOpInterface implementation is not required for ops that
126+
// implement `DestinationStyleOpInterface` (for querying shaped OpResults).
129127
});
130128
}

mlir/lib/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ add_mlir_library(MLIRValueBoundsOpInterface
9393
${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
9494

9595
DEPENDS
96+
MLIRDestinationStyleOpInterface
9697
MLIRValueBoundsOpInterfaceIncGen
9798

9899
LINK_LIBS PUBLIC

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/IR/BuiltinTypes.h"
1212
#include "mlir/IR/Matchers.h"
13+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
1314
#include "llvm/ADT/APSInt.h"
1415
#include "llvm/Support/Debug.h"
1516

@@ -191,13 +192,23 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
191192
// the worklist.
192193
auto valueBoundsOp =
193194
dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
194-
if (!valueBoundsOp)
195+
if (valueBoundsOp) {
196+
if (dim == kIndexValue) {
197+
valueBoundsOp.populateBoundsForIndexValue(value, *this);
198+
} else {
199+
valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this);
200+
}
195201
continue;
196-
if (dim == kIndexValue) {
197-
valueBoundsOp.populateBoundsForIndexValue(value, *this);
198-
} else {
199-
valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this);
200202
}
203+
204+
// If the op does not implement `ValueBoundsOpInterface`, check if it
205+
// implements the `DestinationStyleOpInterface`. OpResults of such ops are
206+
// tied to OpOperands. Tied values have the same shape.
207+
auto dstOp = value.getDefiningOp<DestinationStyleOpInterface>();
208+
if (!dstOp || dim == kIndexValue)
209+
continue;
210+
Value tiedOperand = dstOp.getTiedOpOperand(cast<OpResult>(value))->get();
211+
bound(value)[dim] == getExpr(tiedOperand, dim);
201212
}
202213
}
203214

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
2+
// RUN: -split-input-file | FileCheck %s
3+
4+
// CHECK-LABEL: func @vector_transfer_write(
5+
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
6+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
7+
// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c0]]
8+
// CHECK: return %[[dim]]
9+
func.func @vector_transfer_write(%t: tensor<?xf32>, %v: vector<5xf32>, %pos: index) -> index {
10+
%0 = vector.transfer_write %v, %t[%pos] : vector<5xf32>, tensor<?xf32>
11+
%1 = "test.reify_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
12+
return %1 : index
13+
}

0 commit comments

Comments
 (0)