Skip to content

Commit c4d9869

Browse files
committed
[MLIR] Remove ArithDialect dependency from Dialect/Utils
This commit moves the inferExpandShapeOutputShape utility from the Dialect/Utils/ReshapeOpsUtils.cpp to Arith/Utils/Utils.cpp in order to remove specific dialect dependencies from the DialectUtils. Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent 0000fe8 commit c4d9869

File tree

7 files changed

+87
-85
lines changed

7 files changed

+87
-85
lines changed

mlir/include/mlir/Dialect/Arith/Utils/Utils.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,29 @@
2424

2525
namespace mlir {
2626

27+
using ReassociationIndices = SmallVector<int64_t, 2>;
28+
29+
/// Infer the output shape for a {memref|tensor}.expand_shape when it is
30+
/// possible to do so.
31+
///
32+
/// Note: This should *only* be used to implement
33+
/// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
34+
/// If you need to infer the output shape you should use the static method of
35+
/// `ExpandShapeOp` instead of calling this.
36+
///
37+
/// `inputShape` is the shape of the tensor or memref being expanded as a
38+
/// sequence of SSA values or constants. `expandedType` is the output shape of
39+
/// the expand_shape operation. `reassociation` is the reassociation denoting
40+
/// the output dims each input dim is mapped to.
41+
///
42+
/// Returns the output shape in `outputShape` and `staticOutputShape`, following
43+
/// the conventions for the output_shape and static_output_shape inputs to the
44+
/// expand_shape ops.
45+
std::optional<SmallVector<OpFoldResult>>
46+
inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
47+
ArrayRef<ReassociationIndices> reassociation,
48+
ArrayRef<OpFoldResult> inputShape);
49+
2750
/// Matches a ConstantIndexOp.
2851
detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
2952

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,27 +30,6 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
3030
/// Attribute name for the ArrayAttr which encodes reassociation indices.
3131
constexpr StringRef getReassociationAttrName() { return "reassociation"; }
3232

33-
// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
34-
// to do so.
35-
//
36-
// Note: This should *only* be used to implement
37-
// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
38-
// If you need to infer the output shape you should use the static method of
39-
// `ExpandShapeOp` instead of calling this.
40-
//
41-
// `inputShape` is the shape of the tensor or memref being expanded as a
42-
// sequence of SSA values or constants. `expandedType` is the output shape of
43-
// the expand_shape operation. `reassociation` is the reassociation denoting
44-
// the output dims each input dim is mapped to.
45-
//
46-
// Returns the output shape in `outputShape` and `staticOutputShape`, following
47-
// the conventions for the output_shape and static_output_shape inputs to the
48-
// expand_shape ops.
49-
std::optional<SmallVector<OpFoldResult>>
50-
inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
51-
ArrayRef<ReassociationIndices> reassociation,
52-
ArrayRef<OpFoldResult> inputShape);
53-
5433
/// Compose reassociation maps that are used in pair of reshape ops where one
5534
/// is a producer and other is the consumer. Only valid to use this method when
5635
/// both the producer and consumer are collapsing dimensions or both are

mlir/lib/Dialect/Arith/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ add_mlir_dialect_library(MLIRArithUtils
88
MLIRArithDialect
99
MLIRComplexDialect
1010
MLIRDialect
11+
MLIRDialectUtils
1112
MLIRIR
1213
)

mlir/lib/Dialect/Arith/Utils/Utils.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,74 @@
1313
#include "mlir/Dialect/Arith/Utils/Utils.h"
1414
#include "mlir/Dialect/Arith/IR/Arith.h"
1515
#include "mlir/Dialect/Complex/IR/Complex.h"
16+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1617
#include "mlir/IR/ImplicitLocOpBuilder.h"
1718
#include "llvm/ADT/SmallBitVector.h"
1819
#include <numeric>
1920

2021
using namespace mlir;
2122

23+
std::optional<SmallVector<OpFoldResult>>
24+
mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
25+
ShapedType expandedType,
26+
ArrayRef<ReassociationIndices> reassociation,
27+
ArrayRef<OpFoldResult> inputShape) {
28+
29+
SmallVector<Value> outputShapeValues;
30+
SmallVector<int64_t> outputShapeInts;
31+
// For zero-rank inputs, all dims in result shape are unit extent.
32+
if (inputShape.empty()) {
33+
outputShapeInts.resize(expandedType.getRank(), 1);
34+
return getMixedValues(outputShapeInts, outputShapeValues, b);
35+
}
36+
37+
// Check for all static shapes.
38+
if (expandedType.hasStaticShape()) {
39+
ArrayRef<int64_t> staticShape = expandedType.getShape();
40+
outputShapeInts.assign(staticShape.begin(), staticShape.end());
41+
return getMixedValues(outputShapeInts, outputShapeValues, b);
42+
}
43+
44+
outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
45+
for (const auto &it : llvm::enumerate(reassociation)) {
46+
ReassociationIndices indexGroup = it.value();
47+
48+
int64_t indexGroupStaticSizesProductInt = 1;
49+
bool foundDynamicShape = false;
50+
for (int64_t index : indexGroup) {
51+
int64_t outputDimSize = expandedType.getDimSize(index);
52+
// Cannot infer expanded shape with multiple dynamic dims in the
53+
// same reassociation group!
54+
if (ShapedType::isDynamic(outputDimSize)) {
55+
if (foundDynamicShape)
56+
return std::nullopt;
57+
foundDynamicShape = true;
58+
} else {
59+
outputShapeInts[index] = outputDimSize;
60+
indexGroupStaticSizesProductInt *= outputDimSize;
61+
}
62+
}
63+
if (!foundDynamicShape)
64+
continue;
65+
66+
int64_t inputIndex = it.index();
67+
// Call get<Value>() under the assumption that we're not casting
68+
// dynamism.
69+
Value indexGroupSize = inputShape[inputIndex].get<Value>();
70+
Value indexGroupStaticSizesProduct =
71+
b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
72+
Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
73+
loc, indexGroupSize, indexGroupStaticSizesProduct);
74+
outputShapeValues.push_back(dynamicDimSize);
75+
}
76+
77+
if ((int64_t)outputShapeValues.size() !=
78+
llvm::count(outputShapeInts, ShapedType::kDynamic))
79+
return std::nullopt;
80+
81+
return getMixedValues(outputShapeInts, outputShapeValues, b);
82+
}
83+
2284
/// Matches a ConstantIndexOp.
2385
/// TODO: This should probably just be a general matcher that uses matchConstant
2486
/// and checks the operation for an index type.

mlir/lib/Dialect/Utils/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,5 @@ add_mlir_library(MLIRDialectUtils
88
MLIRDialectUtilsIncGen
99

1010
LINK_LIBS PUBLIC
11-
MLIRArithDialect
1211
MLIRIR
1312
)

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1010

11-
#include "mlir/Dialect/Arith/IR/Arith.h"
1211
#include "mlir/IR/AffineMap.h"
1312
#include "mlir/IR/Builders.h"
1413

@@ -17,67 +16,6 @@
1716

1817
using namespace mlir;
1918

20-
std::optional<SmallVector<OpFoldResult>>
21-
mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
22-
ShapedType expandedType,
23-
ArrayRef<ReassociationIndices> reassociation,
24-
ArrayRef<OpFoldResult> inputShape) {
25-
26-
SmallVector<Value> outputShapeValues;
27-
SmallVector<int64_t> outputShapeInts;
28-
// For zero-rank inputs, all dims in result shape are unit extent.
29-
if (inputShape.empty()) {
30-
outputShapeInts.resize(expandedType.getRank(), 1);
31-
return getMixedValues(outputShapeInts, outputShapeValues, b);
32-
}
33-
34-
// Check for all static shapes.
35-
if (expandedType.hasStaticShape()) {
36-
ArrayRef<int64_t> staticShape = expandedType.getShape();
37-
outputShapeInts.assign(staticShape.begin(), staticShape.end());
38-
return getMixedValues(outputShapeInts, outputShapeValues, b);
39-
}
40-
41-
outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
42-
for (const auto &it : llvm::enumerate(reassociation)) {
43-
ReassociationIndices indexGroup = it.value();
44-
45-
int64_t indexGroupStaticSizesProductInt = 1;
46-
bool foundDynamicShape = false;
47-
for (int64_t index : indexGroup) {
48-
int64_t outputDimSize = expandedType.getDimSize(index);
49-
// Cannot infer expanded shape with multiple dynamic dims in the
50-
// same reassociation group!
51-
if (ShapedType::isDynamic(outputDimSize)) {
52-
if (foundDynamicShape)
53-
return std::nullopt;
54-
foundDynamicShape = true;
55-
} else {
56-
outputShapeInts[index] = outputDimSize;
57-
indexGroupStaticSizesProductInt *= outputDimSize;
58-
}
59-
}
60-
if (!foundDynamicShape)
61-
continue;
62-
63-
int64_t inputIndex = it.index();
64-
// Call get<Value>() under the assumption that we're not casting
65-
// dynamism.
66-
Value indexGroupSize = inputShape[inputIndex].get<Value>();
67-
Value indexGroupStaticSizesProduct =
68-
b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
69-
Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
70-
loc, indexGroupSize, indexGroupStaticSizesProduct);
71-
outputShapeValues.push_back(dynamicDimSize);
72-
}
73-
74-
if ((int64_t)outputShapeValues.size() !=
75-
llvm::count(outputShapeInts, ShapedType::kDynamic))
76-
return std::nullopt;
77-
78-
return getMixedValues(outputShapeInts, outputShapeValues, b);
79-
}
80-
8119
std::optional<SmallVector<ReassociationIndices>>
8220
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
8321
ShapedType targetType) {

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3875,7 +3875,6 @@ cc_library(
38753875
includes = ["include"],
38763876
deps = [
38773877
":DialectUtilsIncGen",
3878-
":ArithDialect",
38793878
":IR",
38803879
":Support",
38813880
"//llvm:Support",
@@ -12635,6 +12634,7 @@ cc_library(
1263512634
deps = [
1263612635
":ArithDialect",
1263712636
":ComplexDialect",
12637+
":DialectUtils",
1263812638
":IR",
1263912639
"//llvm:Support",
1264012640
],

0 commit comments

Comments
 (0)