@@ -345,12 +345,14 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
345
345
}
346
346
347
347
// 4. Expand from the padded result to the stripMinedShape.
348
- // Check if any dims are not factorable. A dim is factorable if the expansion
349
- // requires at most dynamnic dim
350
- RankedTensorType expandDestType = RankedTensorType::Builder (packedTensorType).setShape (stripMinedShape);
348
+ RankedTensorType expandDestType =
349
+ RankedTensorType::Builder (packedTensorType).setShape (stripMinedShape);
351
350
SmallVector<int64_t > transpPerm =
352
351
invertPermutationVector (packedToStripMinedShapePerm);
353
352
Operation *reshapeOp;
353
+ // Check if any dims are not factorable and thus need a `tensor.reshape`
354
+ // instead of a `tensor.expand_shape` op. A dim is factorable if the expansion
355
+ // requires at most dynamnic dim
354
356
if (llvm::any_of (packingMetadata.reassociations ,
355
357
[&](const auto &rAssoc) -> bool {
356
358
return llvm::count_if (rAssoc, [&](int64_t r) {
@@ -360,7 +362,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
360
362
SmallVector<OpFoldResult> sizes =
361
363
tensor::getMixedSizes (rewriter, loc, packOp.getDest ());
362
364
applyPermutationToVector (sizes, transpPerm);
363
- // Create a `tensor` of `index` types for the `shape` operand of `tensor.reshape`
365
+ // Create a `tensor` of `index` types for the `shape` operand of
366
+ // `tensor.reshape`
364
367
Value shapeInitTensor = rewriter.create <tensor::EmptyOp>(
365
368
loc,
366
369
RankedTensorType::get ({expandDestType.getRank ()},
0 commit comments