Skip to content

Commit b4d08df

Browse files
committed
[mlir] Remove incorrect builders for ExpandShapeOp
ExpandShapeOp builder cannot infer the result type since it doesn't know how the dimension needs to be split. Remove this builder so that it doesn't get used accidently. Also remove one potential path using it in generic fusion. Differential Revision: https://reviews.llvm.org/D122019
1 parent d898c95 commit b4d08df

File tree

5 files changed

+114
-102
lines changed

5 files changed

+114
-102
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,41 +1193,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
11931193
[NoSideEffect, ViewLikeOpInterface])>,
11941194
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
11951195
Results<(outs AnyStridedMemRef:$result)>{
1196-
let builders = [
1197-
// Builders for a contracting reshape whose result type is computed from
1198-
// `src` and `reassociation`.
1199-
OpBuilder<(ins "Value":$src,
1200-
"ArrayRef<ReassociationIndices>":$reassociation,
1201-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
1202-
OpBuilder<(ins "Value":$src,
1203-
"ArrayRef<ReassociationExprs>":$reassociation,
1204-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1205-
[{
1206-
auto reassociationMaps =
1207-
convertReassociationMapsToIndices($_builder, reassociation);
1208-
build($_builder, $_state, src, reassociationMaps, attrs);
1209-
}]>,
1210-
1211-
// Builders for a reshape whose result type is passed explicitly. This may
1212-
// be either a contracting or expanding reshape.
1213-
OpBuilder<(ins "Type":$resultType, "Value":$src,
1214-
"ArrayRef<ReassociationIndices>":$reassociation,
1215-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1216-
[{
1217-
build($_builder, $_state, resultType, src, attrs);
1218-
$_state.addAttribute("reassociation",
1219-
getReassociationIndicesAttribute($_builder, reassociation));
1220-
}]>,
1221-
OpBuilder<(ins "Type":$resultType, "Value":$src,
1222-
"ArrayRef<ReassociationExprs>":$reassociation,
1223-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1224-
[{
1225-
auto reassociationMaps =
1226-
convertReassociationMapsToIndices($_builder, reassociation);
1227-
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1228-
}]>
1229-
];
1230-
1196+
12311197
code commonExtraClassDeclaration = [{
12321198
SmallVector<AffineMap, 4> getReassociationMaps();
12331199
SmallVector<ReassociationExprs, 4> getReassociationExprs();
@@ -1288,6 +1254,25 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
12881254
memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
12891255
```
12901256
}];
1257+
let builders = [
1258+
// Builders using ReassociationIndices.
1259+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1260+
"ArrayRef<ReassociationIndices>":$reassociation,
1261+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1262+
[{
1263+
build($_builder, $_state, resultType, src, attrs);
1264+
$_state.addAttribute("reassociation",
1265+
getReassociationIndicesAttribute($_builder, reassociation));
1266+
}]>,
1267+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1268+
"ArrayRef<ReassociationExprs>":$reassociation,
1269+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1270+
[{
1271+
auto reassociationMaps =
1272+
convertReassociationMapsToIndices($_builder, reassociation);
1273+
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1274+
}]>
1275+
];
12911276
let extraClassDeclaration = commonExtraClassDeclaration;
12921277
let hasVerifier = 1;
12931278
}
@@ -1326,6 +1311,39 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
13261311
memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
13271312
```
13281313
}];
1314+
let builders = [
1315+
// Builders for a contracting reshape whose result type is computed from
1316+
// `src` and `reassociation`.
1317+
OpBuilder<(ins "Value":$src,
1318+
"ArrayRef<ReassociationIndices>":$reassociation,
1319+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
1320+
OpBuilder<(ins "Value":$src,
1321+
"ArrayRef<ReassociationExprs>":$reassociation,
1322+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1323+
[{
1324+
auto reassociationMaps =
1325+
convertReassociationMapsToIndices($_builder, reassociation);
1326+
build($_builder, $_state, src, reassociationMaps, attrs);
1327+
}]>,
1328+
1329+
// Builders for a reshape whose result type is passed explicitly.
1330+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1331+
"ArrayRef<ReassociationIndices>":$reassociation,
1332+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1333+
[{
1334+
build($_builder, $_state, resultType, src, attrs);
1335+
$_state.addAttribute("reassociation",
1336+
getReassociationIndicesAttribute($_builder, reassociation));
1337+
}]>,
1338+
OpBuilder<(ins "Type":$resultType, "Value":$src,
1339+
"ArrayRef<ReassociationExprs>":$reassociation,
1340+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1341+
[{
1342+
auto reassociationMaps =
1343+
convertReassociationMapsToIndices($_builder, reassociation);
1344+
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1345+
}]>
1346+
];
13291347
let extraClassDeclaration = commonExtraClassDeclaration;
13301348
let hasVerifier = 1;
13311349
}

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 55 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -678,41 +678,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
678678
Tensor_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
679679
Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
680680
Results<(outs AnyTensor:$result)> {
681-
let builders = [
682-
// Builders for a contracting reshape whose result type is computed from
683-
// `src` and `reassociation`.
684-
OpBuilder<(ins "Value":$src,
685-
"ArrayRef<ReassociationIndices>":$reassociation,
686-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
687-
OpBuilder<(ins "Value":$src,
688-
"ArrayRef<ReassociationExprs>":$reassociation,
689-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
690-
[{
691-
auto reassociationMaps =
692-
convertReassociationMapsToIndices($_builder, reassociation);
693-
build($_builder, $_state, src, reassociationMaps, attrs);
694-
}]>,
695-
696-
// Builders for a reshape whose result type is passed explicitly. This may
697-
// be either a contracting or expanding reshape.
698-
OpBuilder<(ins "Type":$resultType, "Value":$src,
699-
"ArrayRef<ReassociationIndices>":$reassociation,
700-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
701-
[{
702-
build($_builder, $_state, resultType, src, attrs);
703-
$_state.addAttribute("reassociation",
704-
getReassociationIndicesAttribute($_builder, reassociation));
705-
}]>,
706-
OpBuilder<(ins "Type":$resultType, "Value":$src,
707-
"ArrayRef<ReassociationExprs>":$reassociation,
708-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
709-
[{
710-
auto reassociationMaps =
711-
convertReassociationMapsToIndices($_builder, reassociation);
712-
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
713-
}]>
714-
];
715-
681+
716682
code commonExtraClassDeclaration = [{
717683
static StringRef getReassociationAttrName() { return "reassociation"; }
718684
SmallVector<AffineMap, 4> getReassociationMaps();
@@ -768,6 +734,26 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
768734
: tensor<?x?xf32> into tensor<?x?x?xf32>
769735
```
770736
}];
737+
let builders = [
738+
// Builders using ReassociationIndices.
739+
OpBuilder<(ins "Type":$resultType, "Value":$src,
740+
"ArrayRef<ReassociationIndices>":$reassociation,
741+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
742+
[{
743+
build($_builder, $_state, resultType, src, attrs);
744+
$_state.addAttribute("reassociation",
745+
getReassociationIndicesAttribute($_builder, reassociation));
746+
}]>,
747+
OpBuilder<(ins "Type":$resultType, "Value":$src,
748+
"ArrayRef<ReassociationExprs>":$reassociation,
749+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
750+
[{
751+
auto reassociationMaps =
752+
convertReassociationMapsToIndices($_builder, reassociation);
753+
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
754+
}]>
755+
];
756+
771757
let extraClassDeclaration = commonExtraClassDeclaration;
772758
let hasVerifier = 1;
773759
}
@@ -797,6 +783,40 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
797783
: tensor<?x?x?xf32> into tensor<?x?xf32>
798784
```
799785
}];
786+
let builders = [
787+
// Builders for a contracting reshape whose result type is computed from
788+
// `src` and `reassociation`.
789+
OpBuilder<(ins "Value":$src,
790+
"ArrayRef<ReassociationIndices>":$reassociation,
791+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
792+
OpBuilder<(ins "Value":$src,
793+
"ArrayRef<ReassociationExprs>":$reassociation,
794+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
795+
[{
796+
auto reassociationMaps =
797+
convertReassociationMapsToIndices($_builder, reassociation);
798+
build($_builder, $_state, src, reassociationMaps, attrs);
799+
}]>,
800+
801+
// Builders for a reshape whose result type is passed explicitly.
802+
OpBuilder<(ins "Type":$resultType, "Value":$src,
803+
"ArrayRef<ReassociationIndices>":$reassociation,
804+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
805+
[{
806+
build($_builder, $_state, resultType, src, attrs);
807+
$_state.addAttribute("reassociation",
808+
getReassociationIndicesAttribute($_builder, reassociation));
809+
}]>,
810+
OpBuilder<(ins "Type":$resultType, "Value":$src,
811+
"ArrayRef<ReassociationExprs>":$reassociation,
812+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
813+
[{
814+
auto reassociationMaps =
815+
convertReassociationMapsToIndices($_builder, reassociation);
816+
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
817+
}]>
818+
];
819+
800820
let extraClassDeclaration = commonExtraClassDeclaration;
801821
let hasVerifier = 1;
802822
}

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2223,21 +2223,19 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
22232223

22242224
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
22252225
RewritePatternSet &patterns) {
2226-
patterns
2227-
.add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
2228-
FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
2229-
FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
2230-
FoldConsumerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>(
2231-
patterns.getContext());
2226+
patterns.add<
2227+
FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
2228+
FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
2229+
FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
2230+
patterns.getContext());
22322231
}
22332232

22342233
void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
22352234
RewritePatternSet &patterns) {
22362235
patterns
22372236
.add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
22382237
FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
2239-
FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
2240-
FoldConsumerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>(
2238+
FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
22412239
patterns.getContext());
22422240
}
22432241

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,18 +1669,6 @@ computeReshapeCollapsedType(MemRefType type,
16691669
AffineMapAttr::get(layout)));
16701670
}
16711671

1672-
void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1673-
ArrayRef<ReassociationIndices> reassociation,
1674-
ArrayRef<NamedAttribute> attrs) {
1675-
auto memRefType = src.getType().cast<MemRefType>();
1676-
auto resultType = computeReshapeCollapsedType(
1677-
memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1678-
b.getContext(), reassociation)));
1679-
build(b, result, resultType, src, attrs);
1680-
result.addAttribute(getReassociationAttrName(),
1681-
getReassociationIndicesAttribute(b, reassociation));
1682-
}
1683-
16841672
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
16851673
ArrayRef<ReassociationIndices> reassociation,
16861674
ArrayRef<NamedAttribute> attrs) {

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -817,18 +817,6 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
817817
getReassociationIndicesAttribute(b, reassociation));
818818
}
819819

820-
void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
821-
ArrayRef<ReassociationIndices> reassociation,
822-
ArrayRef<NamedAttribute> attrs) {
823-
auto resultType = computeTensorReshapeCollapsedType(
824-
src.getType().cast<RankedTensorType>(),
825-
getSymbolLessAffineMaps(
826-
convertReassociationIndicesToExprs(b.getContext(), reassociation)));
827-
build(b, result, resultType, src, attrs);
828-
result.addAttribute(getReassociationAttrName(),
829-
getReassociationIndicesAttribute(b, reassociation));
830-
}
831-
832820
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
833821
TensorReshapeOp, ExpandShapeOp>::value>
834822
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,

0 commit comments

Comments
 (0)