Skip to content

Commit 9f02751

Browse files
committed
Add additional static value utils
Signed-off-by: Max Dawkins <[email protected]>
1 parent 7116ece commit 9f02751

File tree

4 files changed

+28
-16
lines changed

4 files changed

+28
-16
lines changed

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
9292

9393
/// Return true if `ofr` is constant integer equal to `value`.
9494
bool isConstantIntValue(OpFoldResult ofr, int64_t value);
95+
/// Return true if all of `ofrs` are constant integers equal to `value`.
96+
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
97+
/// Return true if all of `ofrs` are constant integers equal to the
98+
/// corresponding value in `values`.
99+
bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
100+
ArrayRef<int64_t> values);
95101

96102
/// Return true if ofr1 and ofr2 are the same integer constant attribute
97103
/// values or the same SSA value. Ignore integer bitwitdh and type mismatch

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -640,8 +640,6 @@ struct InsertOpInterface
640640
template <typename InsertOpTy>
641641
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
642642
OpOperand &opOperand) {
643-
RankedTensorType destType = insertSliceOp.getDestType();
644-
645643
// The source is always read.
646644
if (opOperand == insertSliceOp.getSourceMutable())
647645
return true;
@@ -652,16 +650,12 @@ static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
652650
// Dest is not read if it is entirely overwritten. E.g.:
653651
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
654652
bool allOffsetsZero =
655-
llvm::all_of(insertSliceOp.getMixedOffsets(),
656-
[](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
657-
bool sizesMatchDestSizes = llvm::all_of(
658-
llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
659-
return getConstantIntValue(it.value()) ==
660-
destType.getDimSize(it.index());
661-
});
653+
llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
654+
RankedTensorType destType = insertSliceOp.getDestType();
655+
bool sizesMatchDestSizes =
656+
areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
662657
bool allStridesOne =
663-
llvm::all_of(insertSliceOp.getMixedStrides(),
664-
[](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
658+
areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
665659
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
666660
}
667661

mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@ namespace mlir {
1616
namespace tensor {
1717
namespace {
1818

19-
static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
20-
return llvm::all_of(
21-
ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
22-
}
23-
2419
/// Returns the number of shape sizes that is either dynamic or greater than 1.
2520
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
2621
return llvm::count_if(

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/IR/Matchers.h"
1111
#include "mlir/Support/LLVM.h"
1212
#include "llvm/ADT/APSInt.h"
13+
#include "llvm/ADT/STLExtras.h"
1314
#include "llvm/Support/MathExtras.h"
1415

1516
namespace mlir {
@@ -137,6 +138,22 @@ bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
137138
return val && *val == value;
138139
}
139140

141+
/// Return true if all of `ofrs` are constant integers equal to `value`.
142+
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
143+
return llvm::all_of(
144+
ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
145+
}
146+
147+
/// Return true if all of `ofrs` are constant integers equal to the
148+
/// corresponding value in `values`.
149+
bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
150+
ArrayRef<int64_t> values) {
151+
if (ofrs.size() != values.size())
152+
return false;
153+
std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
154+
return constOfrs && llvm::equal(constOfrs.value(), values);
155+
}
156+
140157
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
141158
/// or the same SSA value.
142159
/// Ignore integer bitwidth and type mismatch that come from the fact there is

0 commit comments

Comments
 (0)