@@ -1087,43 +1087,44 @@ LogicalResult GenericOp::verify() { return success(); }
1087
1087
1088
1088
namespace {
1089
1089
1090
- // / Remove generic operations (on tensors) that are just copying
1090
+ // / Remove any linalg operation (on tensors) that are just copying
1091
1091
// / the values from inputs to the results. Requirements are
1092
1092
// / 1) All iterator types are parallel
1093
1093
// / 2) The body contains just a yield operation with the yielded values being
1094
1094
// / the arguments corresponding to the operands.
1095
- struct EraseIdentityGenericOp : public OpRewritePattern <GenericOp> {
1096
- using OpRewritePattern<GenericOp>::OpRewritePattern;
1095
+ template <typename OpTy>
1096
+ struct EraseIdentityLinalgOp : public OpRewritePattern <OpTy> {
1097
+ using OpRewritePattern<OpTy>::OpRewritePattern;
1097
1098
1098
- LogicalResult matchAndRewrite (GenericOp genericOp ,
1099
+ LogicalResult matchAndRewrite (OpTy linalgOp ,
1099
1100
PatternRewriter &rewriter) const override {
1100
1101
// Check all indexing maps are identity.
1101
- if (llvm::any_of (genericOp .getIndexingMapsArray (),
1102
+ if (llvm::any_of (linalgOp .getIndexingMapsArray (),
1102
1103
[](AffineMap map) { return !map.isIdentity (); }))
1103
1104
return failure ();
1104
1105
1105
1106
// Check that the body of the linalg operation is just a linalg.yield
1106
1107
// operation.
1107
- Block &body = genericOp. getRegion ().front ();
1108
+ Block &body = linalgOp-> getRegion (0 ).front ();
1108
1109
if (!llvm::hasSingleElement (body))
1109
1110
return failure ();
1110
1111
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator ());
1111
1112
if (!yieldOp)
1112
1113
return failure ();
1113
1114
1114
1115
// In the buffer case, we need to check exact buffer equality.
1115
- if (genericOp .hasPureBufferSemantics ()) {
1116
- if (genericOp .getNumDpsInputs () == 1 && genericOp .getNumDpsInits () == 1 &&
1117
- genericOp .getDpsInputOperand (0 )->get () ==
1118
- genericOp .getDpsInitOperand (0 )->get ()) {
1119
- rewriter.eraseOp (genericOp );
1116
+ if (linalgOp .hasPureBufferSemantics ()) {
1117
+ if (linalgOp .getNumDpsInputs () == 1 && linalgOp .getNumDpsInits () == 1 &&
1118
+ linalgOp .getDpsInputOperand (0 )->get () ==
1119
+ linalgOp .getDpsInitOperand (0 )->get ()) {
1120
+ rewriter.eraseOp (linalgOp );
1120
1121
return success ();
1121
1122
}
1122
1123
return failure ();
1123
1124
}
1124
1125
1125
1126
// Mixed semantics is not supported yet.
1126
- if (!genericOp .hasPureTensorSemantics ())
1127
+ if (!linalgOp .hasPureTensorSemantics ())
1127
1128
return failure ();
1128
1129
1129
1130
// Get the argument number of the returned values. That is the operand
@@ -1134,8 +1135,8 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
1134
1135
if (!yieldArg || yieldArg.getOwner () != &body)
1135
1136
return failure ();
1136
1137
unsigned argumentNumber = yieldArg.getArgNumber ();
1137
- Value returnedArg = genericOp ->getOperand (argumentNumber);
1138
- Type resultType = genericOp ->getResult (yieldVal.index ()).getType ();
1138
+ Value returnedArg = linalgOp ->getOperand (argumentNumber);
1139
+ Type resultType = linalgOp ->getResult (yieldVal.index ()).getType ();
1139
1140
// The input can have a different type than the result, e.g. a dynamic
1140
1141
// input dimension can be turned into a static output dimension.
1141
1142
Type returnType = returnedArg.getType ();
@@ -1145,21 +1146,21 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
1145
1146
if (sparse_tensor::getSparseTensorEncoding (returnType) ||
1146
1147
sparse_tensor::getSparseTensorEncoding (resultType))
1147
1148
returnedArg = rewriter.create <sparse_tensor::ConvertOp>(
1148
- genericOp .getLoc (), resultType, returnedArg);
1149
+ linalgOp .getLoc (), resultType, returnedArg);
1149
1150
else {
1150
1151
if (!tensor::CastOp::areCastCompatible (returnedArg.getType (),
1151
1152
resultType))
1152
1153
return failure ();
1153
1154
returnedArg = rewriter.create <tensor::CastOp>(
1154
- genericOp .getLoc (), resultType, returnedArg);
1155
+ linalgOp .getLoc (), resultType, returnedArg);
1155
1156
}
1156
1157
}
1157
1158
returnedArgs.push_back (returnedArg);
1158
1159
}
1159
1160
1160
- if (returnedArgs.size () != genericOp ->getNumResults ())
1161
+ if (returnedArgs.size () != linalgOp ->getNumResults ())
1161
1162
return failure ();
1162
- rewriter.replaceOp (genericOp , returnedArgs);
1163
+ rewriter.replaceOp (linalgOp , returnedArgs);
1163
1164
return success ();
1164
1165
}
1165
1166
};
@@ -1168,7 +1169,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
1168
1169
1169
1170
void GenericOp::getCanonicalizationPatterns (RewritePatternSet &results,
1170
1171
MLIRContext *context) {
1171
- results.add <EraseIdentityGenericOp >(context);
1172
+ results.add <EraseIdentityLinalgOp<GenericOp> >(context);
1172
1173
}
1173
1174
1174
1175
LogicalResult GenericOp::fold (FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
@@ -1907,6 +1908,11 @@ void BroadcastOp::getEffects(
1907
1908
getDpsInits ());
1908
1909
}
1909
1910
1911
+ void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
1912
+ MLIRContext *context) {
1913
+ results.add <EraseIdentityLinalgOp<BroadcastOp>>(context);
1914
+ }
1915
+
1910
1916
// ===----------------------------------------------------------------------===//
1911
1917
// YieldOp
1912
1918
// ===----------------------------------------------------------------------===//
0 commit comments