@@ -4229,6 +4229,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
4229
4229
metadata.outerDimsPerm );
4230
4230
}
4231
4231
4232
+ // / Returns true if the `srcShape` or `destShape` is different from the one in
4233
+ // / `op` and populates each with the inferred static shape.
4234
+ static bool inferStaticShape (UnPackOp op, SmallVectorImpl<int64_t > &srcShape,
4235
+ SmallVectorImpl<int64_t > &destShape) {
4236
+ bool changeNeeded = false ;
4237
+ srcShape.assign (op.getSourceType ().getShape ().begin (),
4238
+ op.getSourceType ().getShape ().end ());
4239
+ destShape.assign (op.getDestType ().getShape ().begin (),
4240
+ op.getDestType ().getShape ().end ());
4241
+ llvm::SmallSetVector<int64_t , 4 > innerDims;
4242
+ innerDims.insert (op.getInnerDimsPos ().begin (), op.getInnerDimsPos ().end ());
4243
+ auto outerDimsPerm = op.getOuterDimsPerm ();
4244
+ int destRank = op.getDestRank ();
4245
+ for (auto i : llvm::seq<int64_t >(0 , destRank)) {
4246
+ if (innerDims.contains (i))
4247
+ continue ;
4248
+ int64_t srcPos = i;
4249
+ int64_t destPos = i;
4250
+ if (!outerDimsPerm.empty ())
4251
+ srcPos = outerDimsPerm[destPos];
4252
+ if (ShapedType::isDynamic (srcShape[srcPos]) ==
4253
+ ShapedType::isDynamic (destShape[destPos])) {
4254
+ continue ;
4255
+ }
4256
+ int64_t size = srcShape[srcPos];
4257
+ if (ShapedType::isDynamic (size))
4258
+ size = destShape[destPos];
4259
+ srcShape[srcPos] = size;
4260
+ destShape[destPos] = size;
4261
+ changeNeeded = true ;
4262
+ }
4263
+ return changeNeeded;
4264
+ }
4265
+
4232
4266
LogicalResult UnPackOp::canonicalize (UnPackOp unPackOp,
4233
4267
PatternRewriter &rewriter) {
4234
4268
// / pack(unpack(x)) -> x
@@ -4251,6 +4285,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4251
4285
[&]() { unPackOp.setDpsInitOperand (0 , newDest); });
4252
4286
return success ();
4253
4287
}
4288
+
4289
+ // Insert tensor.cast ops if static shape inference is available.
4290
+ SmallVector<int64_t > srcShape, destShape;
4291
+ if (inferStaticShape (unPackOp, srcShape, destShape)) {
4292
+ Location loc = unPackOp.getLoc ();
4293
+ Value source = unPackOp.getSource ();
4294
+ if (srcShape != unPackOp.getSourceType ().getShape ()) {
4295
+ auto newSrcType = unPackOp.getSourceType ().clone (srcShape);
4296
+ source = rewriter.create <tensor::CastOp>(loc, newSrcType,
4297
+ unPackOp.getSource ());
4298
+ }
4299
+ Value dest = unPackOp.getDest ();
4300
+ if (destShape != unPackOp.getDestType ().getShape ()) {
4301
+ auto newDestType = unPackOp.getDestType ().clone (destShape);
4302
+ dest =
4303
+ rewriter.create <tensor::CastOp>(loc, newDestType, unPackOp.getDest ());
4304
+ }
4305
+ Value newOp = rewriter.create <tensor::UnPackOp>(
4306
+ loc, source, dest, unPackOp.getInnerDimsPos (), unPackOp.getMixedTiles (),
4307
+ unPackOp.getOuterDimsPerm ());
4308
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(
4309
+ unPackOp, unPackOp.getResult ().getType (), newOp);
4310
+ return success ();
4311
+ }
4312
+
4254
4313
return failure ();
4255
4314
}
4256
4315
0 commit comments