@@ -1177,12 +1177,15 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
1177
1177
SmallVector<OpFoldResult> readOffsets (srcRank, zeroIdxAttr);
1178
1178
SmallVector<OpFoldResult> readStrides (srcRank, oneIdxAttr);
1179
1179
SmallVector<OpFoldResult> readSizes;
1180
- SmallVector<int64_t > readShape;
1180
+ SmallVector<OpFoldResult> transShapeForEmpty;
1181
+ SmallVector<int64_t > readShapeForExtractSlice;
1181
1182
for (auto i : llvm::seq<unsigned >(0 , srcRank)) {
1182
1183
if (dimAndTileMapping.count (i)) {
1183
- readShape.push_back (getConstantIntValue (dimAndTileMapping[i])
1184
- .value_or (ShapedType::kDynamic ));
1184
+ readShapeForExtractSlice.push_back (
1185
+ getConstantIntValue (dimAndTileMapping[i])
1186
+ .value_or (ShapedType::kDynamic ));
1185
1187
readSizes.push_back (dimAndTileMapping[i]);
1188
+ transShapeForEmpty.push_back (dimAndTileMapping[i]);
1186
1189
continue ;
1187
1190
}
1188
1191
if (ShapedType::isDynamic (inputShape[i])) {
@@ -1191,12 +1194,14 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
1191
1194
} else {
1192
1195
readSizes.push_back (rewriter.getIndexAttr (inputShape[i]));
1193
1196
}
1194
- if (inputShape[i] != 1 )
1195
- readShape.push_back (inputShape[i]);
1197
+ if (inputShape[i] != 1 ) {
1198
+ readShapeForExtractSlice.push_back (inputShape[i]);
1199
+ transShapeForEmpty.push_back (rewriter.getIndexAttr (inputShape[i]));
1200
+ }
1196
1201
}
1197
1202
1198
1203
Type elemType = packOp.getSourceType ().getElementType ();
1199
- auto readType = RankedTensorType::get (readShape , elemType);
1204
+ auto readType = RankedTensorType::get (readShapeForExtractSlice , elemType);
1200
1205
1201
1206
Value tile = rewriter.create <tensor::ExtractSliceOp>(
1202
1207
loc, readType, input, readOffsets, readSizes, readStrides);
@@ -1208,8 +1213,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
1208
1213
LLVM_DEBUG (DBGS () << " Pack permutation: " << packOp << " \n " ;
1209
1214
llvm::interleaveComma (perm, DBGS () << " perm: " ); DBGSNL (););
1210
1215
1211
- SmallVector<int64_t > transpShape = readShape;
1212
- applyPermutationToVector<int64_t >(transpShape, perm);
1216
+ applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
1213
1217
1214
1218
// If there's a tile with a dynamic size, retrieve its size. ATM only 1
1215
1219
// dynamic tile is allowed.
@@ -1222,10 +1226,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
1222
1226
}
1223
1227
1224
1228
Value empty =
1225
- ShapedType::isDynamicShape (cast<ShapedType>(input.getType ()).getShape ())
1226
- ? rewriter.create <tensor::EmptyOp>(loc, transpShape, elemType,
1227
- dynDimSize)
1228
- : rewriter.create <tensor::EmptyOp>(loc, transpShape, elemType);
1229
+ rewriter.create <tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
1229
1230
auto transposedOp =
1230
1231
rewriter.create <linalg::TransposeOp>(loc, tile, empty, perm);
1231
1232
0 commit comments