@@ -3983,6 +3983,41 @@ static bool paddingIsNotNeeded(PackOp op) {
3983
3983
op.getMixedTiles ());
3984
3984
}
3985
3985
3986
+ // / Returns true if the `srcShape` or `destShape` is different from the one in
3987
+ // / `packOp` and populates each with the inferred static shape.
3988
+ static bool inferStaticShape (PackOp packOp, SmallVectorImpl<int64_t > &srcShape,
3989
+ SmallVectorImpl<int64_t > &destShape) {
3990
+ bool changeNeeded = false ;
3991
+ srcShape.assign (packOp.getSourceType ().getShape ().begin (),
3992
+ packOp.getSourceType ().getShape ().end ());
3993
+ destShape.assign (packOp.getDestType ().getShape ().begin (),
3994
+ packOp.getDestType ().getShape ().end ());
3995
+ llvm::SmallSetVector<int64_t , 4 > innerDims;
3996
+ innerDims.insert (packOp.getInnerDimsPos ().begin (),
3997
+ packOp.getInnerDimsPos ().end ());
3998
+ auto outerDimsPerm = packOp.getOuterDimsPerm ();
3999
+ int srcRank = packOp.getSourceRank ();
4000
+ for (auto i : llvm::seq<int64_t >(0 , srcRank)) {
4001
+ if (innerDims.contains (i))
4002
+ continue ;
4003
+ int64_t srcPos = i;
4004
+ int64_t destPos = i;
4005
+ if (!outerDimsPerm.empty ())
4006
+ destPos = outerDimsPerm[srcPos];
4007
+ if (ShapedType::isDynamic (srcShape[srcPos]) ==
4008
+ ShapedType::isDynamic (destShape[destPos])) {
4009
+ continue ;
4010
+ }
4011
+ int64_t size = srcShape[srcPos];
4012
+ if (ShapedType::isDynamic (size))
4013
+ size = destShape[destPos];
4014
+ srcShape[srcPos] = size;
4015
+ destShape[destPos] = size;
4016
+ changeNeeded = true ;
4017
+ }
4018
+ return changeNeeded;
4019
+ }
4020
+
3986
4021
LogicalResult PackOp::canonicalize (PackOp packOp, PatternRewriter &rewriter) {
3987
4022
// Fold an unpack(pack(x)) to x.
3988
4023
if (auto unPackOp = packOp.getSource ().getDefiningOp <UnPackOp>()) {
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4003
4038
rewriter.finalizeOpModification (packOp);
4004
4039
return success ();
4005
4040
}
4041
+
4042
+ // Insert tensor.cast ops if static shape inference is available..
4043
+ SmallVector<int64_t > srcShape, destShape;
4044
+ if (inferStaticShape (packOp, srcShape, destShape)) {
4045
+ Location loc = packOp.getLoc ();
4046
+ Value source = packOp.getSource ();
4047
+ if (srcShape != packOp.getSourceType ().getShape ()) {
4048
+ auto newSrcType = packOp.getSourceType ().clone (srcShape);
4049
+ source =
4050
+ rewriter.create <tensor::CastOp>(loc, newSrcType, packOp.getSource ());
4051
+ }
4052
+ Value dest = packOp.getDest ();
4053
+ if (destShape != packOp.getDestType ().getShape ()) {
4054
+ auto newDestType = packOp.getDestType ().clone (destShape);
4055
+ dest =
4056
+ rewriter.create <tensor::CastOp>(loc, newDestType, packOp.getDest ());
4057
+ }
4058
+ Value newOp = rewriter.create <tensor::PackOp>(
4059
+ loc, source, dest, packOp.getInnerDimsPos (), packOp.getMixedTiles (),
4060
+ packOp.getPaddingValue (), packOp.getOuterDimsPerm ());
4061
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(
4062
+ packOp, packOp.getResult ().getType (), newOp);
4063
+ return success ();
4064
+ }
4065
+
4006
4066
return failure ();
4007
4067
}
4008
4068
0 commit comments