@@ -1139,37 +1139,14 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
1139
1139
return perm;
1140
1140
}
1141
1141
1142
- // A helper function to generate a dim-and-size pair for Ops like
1143
- // ExtractSliceOp that require both:
1144
- // * dims to specify the output shape, and
1145
- // * sizes for the sizes attribute (or similar).
1146
- // For dynamic sizes, if the corresponding size is a compile time constant:
1147
- // * the return size becomes the attribute encapsulating the known size, and
1148
- // * dim is updated from kDynamic to its actual known value.
1149
- static std::pair<int64_t , OpFoldResult>
1150
- getSimplifiedDimSizePair (OpFoldResult tileSizeOfr, Builder &b) {
1151
- int64_t tileSizeForShape =
1152
- getConstantIntValue (tileSizeOfr).value_or (ShapedType::kDynamic );
1153
-
1154
- OpFoldResult tileSizeOfrSimplified;
1155
- if (tileSizeForShape != ShapedType::kDynamic ) {
1156
- tileSizeOfrSimplified = b.getIndexAttr (tileSizeForShape);
1157
- } else {
1158
- tileSizeOfrSimplified = tileSizeOfr;
1159
- }
1160
-
1161
- return std::pair<int64_t , OpFoldResult>(tileSizeForShape,
1162
- tileSizeOfrSimplified);
1163
- }
1164
-
1165
1142
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite (
1166
1143
tensor::PackOp packOp, PatternRewriter &rewriter) const {
1167
1144
// TODO: support the case that outer dimensions are not all 1s. A
1168
1145
// tensor.expand_shape will be generated in this case.
1169
- if (llvm::any_of (packOp.getTiledOuterDims (),
1146
+ if (llvm::any_of (packOp.getAllOuterDims (),
1170
1147
[](int64_t dim) { return dim != 1 ; })) {
1171
1148
return rewriter.notifyMatchFailure (
1172
- packOp, " require the tiled outer dimensions of the result are all 1s" );
1149
+ packOp, " not all outer dimensions of the result are 1s" );
1173
1150
}
1174
1151
1175
1152
Attribute zeroIdxAttr = rewriter.getIndexAttr (0 );
@@ -1202,7 +1179,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
1202
1179
for (auto i : llvm::seq<unsigned >(0 , srcRank)) {
1203
1180
if (dimAndTileMapping.count (i)) {
1204
1181
auto [tileSize, tileSizeOfr] =
1205
- getSimplifiedDimSizePair (dimAndTileMapping[i], rewriter);
1182
+ getSimplifiedOfrAndStaticSizePair (dimAndTileMapping[i], rewriter);
1206
1183
extractSliceSizes.push_back (tileSizeOfr);
1207
1184
outputShapeForExtractSlice.push_back (tileSize);
1208
1185
}
@@ -1254,7 +1231,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
1254
1231
1255
1232
for (auto tileSize : packOp.getMixedTiles ()) {
1256
1233
auto [tileSizeStatic, tileSizeOfr] =
1257
- getSimplifiedDimSizePair (tileSize, rewriter);
1234
+ getSimplifiedOfrAndStaticSizePair (tileSize, rewriter);
1258
1235
writeSizes.push_back (tileSizeOfr);
1259
1236
writeShape.push_back (tileSizeStatic);
1260
1237
}
0 commit comments