@@ -1028,9 +1028,8 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
1028
1028
// / This method assumes that all outer dims for this pack Op are 1.
1029
1029
// /
1030
1030
// / At most _one_ inner tile size can be _dynamic_, all other inner tiles are
1031
- // / required to have static sizes. The inner tile that's dynamic must be a
1032
- // / multiple of vector.vscale (to support scalable tile sizes). This condition
1033
- // / can be relaxed in the future.
1031
+ // / required to have static sizes. This restriction can be relaxed in the
1032
+ // / future.
1034
1033
static Value getPackOpSourceOrPaddedSource (OpBuilder &builder,
1035
1034
tensor::PackOp packOp) {
1036
1035
Value input = packOp.getSource ();
@@ -1049,8 +1048,8 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
1049
1048
DenseMap<int64_t , OpFoldResult> tileAndPosMapping =
1050
1049
packOp.getDimAndTileMapping ();
1051
1050
1052
- // The size of a scalable tile (if present).
1053
- Value scalableSize ;
1051
+ // The size of a dynamic tile (if present).
1052
+ Value dynamicTileSize ;
1054
1053
1055
1054
// Collect dims for the padded shape.
1056
1055
SmallVector<int64_t > paddedShape;
@@ -1080,16 +1079,15 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
1080
1079
// 2.2 Dynamic tile sizes
1081
1080
paddedShape.push_back (ShapedType::kDynamic );
1082
1081
1083
- // Get the value that holds the scalable size.
1084
- assert (!scalableSize && " Only one scalable size is supported ATM." );
1085
- scalableSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
1086
- assert (vector::getConstantVscaleMultiplier (scalableSize) &&
1087
- " This dynamic shape is not a multiple of vscale, this !" );
1082
+ // Get the value that holds the dynamic size.
1083
+ assert (!dynamicTileSize && " Only one dynamic tile is supported ATM." );
1084
+ dynamicTileSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
1088
1085
}
1089
1086
auto resultType =
1090
1087
RankedTensorType::get (paddedShape, inputType.getElementType ());
1091
1088
return tensor::createPadHighOp (resultType, input, packOp.getPaddingValue (),
1092
- /* nofold=*/ false , loc, builder, scalableSize);
1089
+ /* nofold=*/ false , loc, builder,
1090
+ dynamicTileSize);
1093
1091
}
1094
1092
1095
1093
// Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1152,14 +1150,6 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
1152
1150
1153
1151
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite (
1154
1152
tensor::PackOp packOp, PatternRewriter &rewriter) const {
1155
- if (llvm::any_of (packOp.getMixedTiles (), [](OpFoldResult tile) {
1156
- return tile.is <Value>() && !vector::getConstantVscaleMultiplier (
1157
- llvm::dyn_cast<Value>(tile));
1158
- })) {
1159
- return rewriter.notifyMatchFailure (
1160
- packOp, " require inner tile sizes to be either static or a constant "
1161
- " multiple of vector.vscale" );
1162
- }
1163
1153
if (llvm::count_if (packOp.getMixedTiles (),
1164
1154
[](OpFoldResult tile) { return tile.is <Value>(); }) > 1 ) {
1165
1155
return rewriter.notifyMatchFailure (
@@ -1221,22 +1211,20 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
1221
1211
SmallVector<int64_t > transpShape = readShape;
1222
1212
applyPermutationToVector<int64_t >(transpShape, perm);
1223
1213
1224
- // If there's a tile with a scalable size, retrieve its size. ATM only 1
1225
- // scalable tile is allowed.
1226
- Value scalableSize ;
1214
+ // If there's a tile with a dynamic size, retrieve its size. ATM only 1
1215
+ // dynamic tile is allowed.
1216
+ Value dynDimSize ;
1227
1217
for (auto tile : packOp.getMixedTiles ()) {
1228
1218
if (tile.is <Value>()) {
1229
- assert (!scalableSize && " Only one scalable size is supported ATM." );
1230
- scalableSize = cast<Value>(tile);
1231
- assert (vector::getConstantVscaleMultiplier (scalableSize) &&
1232
- " This dynamic shape is not a multiple of vscale!" );
1219
+ assert (!dynDimSize && " Only one scalable size is supported ATM." );
1220
+ dynDimSize = cast<Value>(tile);
1233
1221
}
1234
1222
}
1235
1223
1236
1224
Value empty =
1237
1225
ShapedType::isDynamicShape (cast<ShapedType>(input.getType ()).getShape ())
1238
1226
? rewriter.create <tensor::EmptyOp>(loc, transpShape, elemType,
1239
- scalableSize )
1227
+ dynDimSize )
1240
1228
: rewriter.create <tensor::EmptyOp>(loc, transpShape, elemType);
1241
1229
auto transposedOp =
1242
1230
rewriter.create <linalg::TransposeOp>(loc, tile, empty, perm);
0 commit comments