@@ -131,29 +131,49 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
131
131
return success ();
132
132
}
133
133
134
- struct TransposeNoOp : public OpRewritePattern <tosa::TransposeOp> {
134
+ struct ConsolidateTransposeOptimization
135
+ : public OpRewritePattern<tosa::TransposeOp> {
135
136
using OpRewritePattern::OpRewritePattern;
136
137
137
- LogicalResult matchAndRewrite (tosa::TransposeOp op ,
138
+ LogicalResult matchAndRewrite (tosa::TransposeOp transposeOp ,
138
139
PatternRewriter &rewriter) const override {
139
- auto perm = op.getPerms ();
140
+ // Input is also TransposeOp - transpose(transpose(A)).
141
+ auto innerTranspose =
142
+ transposeOp.getInput1 ().getDefiningOp <tosa::TransposeOp>();
143
+ if (!innerTranspose)
144
+ return rewriter.notifyMatchFailure (transposeOp,
145
+ " input must be transpose operation" );
146
+
147
+ SmallVector<int64_t > transposePerms, innerTransposePerms;
148
+ if (transposeOp.getConstantPerms (transposePerms).failed ())
149
+ return rewriter.notifyMatchFailure (transposeOp,
150
+ " transpose perms must be constant" );
151
+ if (innerTranspose.getConstantPerms (innerTransposePerms).failed ())
152
+ return rewriter.notifyMatchFailure (
153
+ transposeOp, " inner transpose perms must be constant" );
154
+ if (transposePerms.size () != innerTransposePerms.size ())
155
+ return rewriter.notifyMatchFailure (
156
+ transposeOp,
157
+ " transpose and inner transpose perms sizes must be equal" );
158
+ if (transposePerms.empty ())
159
+ return rewriter.notifyMatchFailure (
160
+ transposeOp, " transpose perms sizes must be positive" );
140
161
141
- DenseIntElementsAttr permAttr;
142
- if (! matchPattern (perm, m_Constant (&permAttr))) {
143
- return failure ();
144
- }
162
+ // Consolidate transposes into one transpose.
163
+ SmallVector< int32_t > perms (transposePerms. size ());
164
+ for ( int i = 0 , s = transposePerms. size (); i < s; ++i)
165
+ perms[i] = innerTransposePerms[transposePerms[i]];
145
166
146
- SmallVector<int64_t > permValues = llvm::to_vector<6 >(
147
- llvm::map_range (permAttr.getValues <APInt>(),
148
- [](const APInt &val) { return val.getSExtValue (); }));
167
+ auto permsTy =
168
+ RankedTensorType::get (transposePerms.size (), rewriter.getI32Type ());
169
+ auto permsAttr = DenseIntElementsAttr::get (permsTy, perms);
170
+ Value permsValue =
171
+ rewriter.create <arith::ConstantOp>(transposeOp.getLoc (), permsAttr);
149
172
150
- for (int i = 0 , s = permValues.size (); i < s; i++) {
151
- if (i != permValues[i]) {
152
- return failure ();
153
- }
154
- }
173
+ rewriter.replaceOpWithNewOp <tosa::TransposeOp>(
174
+ transposeOp, transposeOp.getResult ().getType (),
175
+ innerTranspose.getInput1 (), permsValue);
155
176
156
- rewriter.replaceOp (op, op.getInput1 ());
157
177
return success ();
158
178
}
159
179
};
@@ -212,7 +232,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
212
232
213
233
void TransposeOp::getCanonicalizationPatterns (RewritePatternSet &results,
214
234
MLIRContext *context) {
215
- results.add <TransposeNoOp , TransposeIsReshape>(context);
235
+ results.add <ConsolidateTransposeOptimization , TransposeIsReshape>(context);
216
236
}
217
237
218
238
struct AddZeroOptimization : public OpRewritePattern <tosa::AddOp> {
@@ -997,26 +1017,27 @@ OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
997
1017
}
998
1018
999
1019
OpFoldResult TransposeOp::fold (ArrayRef<Attribute> operands) {
1000
- if (!operands[1 ])
1001
- return {};
1002
-
1003
1020
auto inputTy = getInput1 ().getType ().cast <ShapedType>();
1004
1021
auto resultTy = getType ().cast <ShapedType>();
1005
- if (inputTy.getElementType () != resultTy.getElementType ())
1006
- return {};
1007
1022
1008
1023
// Transposing splat values just means reshaping.
1009
1024
if (auto input = operands[0 ].dyn_cast_or_null <DenseElementsAttr>()) {
1010
- if (input.isSplat ())
1011
- return input.reshape (getType ().cast <ShapedType>());
1025
+ if (input.isSplat () && resultTy.hasStaticShape () &&
1026
+ inputTy.getElementType () == resultTy.getElementType ())
1027
+ return input.reshape (resultTy);
1012
1028
}
1013
1029
1014
- auto perms = llvm::to_vector< 6 >( llvm::map_range (
1015
- operands[ 1 ]. cast <DenseIntElementsAttr>(). getValues <APInt>(),
1016
- []( const APInt &val) { return val. getSExtValue (); })) ;
1030
+ // Transpose does not change the input type.
1031
+ if ( getInput1 (). getType () != getType ())
1032
+ return {} ;
1017
1033
1018
- if (llvm::equal (llvm::seq<int64_t >(0 , perms.size ()), perms) &&
1019
- getInput1 ().getType () == getType ())
1020
- return getInput1 ();
1021
- return {};
1034
+ // Transpose is not the identity transpose.
1035
+ SmallVector<int64_t > perms;
1036
+ if (getConstantPerms (perms).failed ())
1037
+ return {};
1038
+
1039
+ if (!llvm::equal (llvm::seq<int64_t >(0 , perms.size ()), perms))
1040
+ return {};
1041
+
1042
+ return getInput1 ();
1022
1043
}
0 commit comments