Commit 67da598

Browse files
committed
Fold unpack -> pack sequences in data layout propagation
1 parent 1101b76 commit 67da598

2 files changed: +69 -35 lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 36 additions & 0 deletions
@@ -298,20 +298,57 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   return std::make_tuple(packedOperand, indexingMap);
 }
 
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+  int numDpsOuts = genericOp.getNumDpsInits();
+  Block *block = genericOp.getBody();
+  int numBlockArgs = block->getNumArguments();
+  for (int i = 0; i < numDpsOuts; ++i) {
+    int matchingInitArgIndex = numBlockArgs - numDpsOuts + i;
+    if (!block->getArgument(matchingInitArgIndex).use_empty())
+      return false;
+  }
+  return true;
+}
+
 /// Pack a genericOp and return it.
 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                Value dest, AffineMap packedOutIndexingMap,
                                const PackInfo &packInfo) {
   Location loc = genericOp.getLoc();
   SmallVector<Value> inputOperands;
+  SmallVector<Value> inputOperandsFromUnpackedSource;
   SmallVector<AffineMap> indexingMaps;
+
+  // Note: canUnpackPackFold also needs to guarantee that the generic body
+  // has no gather semantics. Since such scenarios have already been
+  // rejected by both BubbleUpPackOpThroughGenericOp and
+  // PushDownUnPackOpThroughGenericOp, we can safely assume the fold is
+  // valid as long as no init is used in the body.
+  bool canUnpackPackFold = isGenericOutsNotUsed(genericOp);
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
     auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
         rewriter, loc, packInfo, genericOp, inputOperand);
+
+    if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
+      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
+    } else {
+      inputOperandsFromUnpackedSource.push_back(packedOperand);
+    }
+
     inputOperands.push_back(packedOperand);
     indexingMaps.push_back(packedIndexingMap);
   }
 
+  // If the unpack -> pack pair can be folded:
+  // 1) use the unpack op's source directly as the input operand, and
+  // 2) replace the generic op's init tensor with the pack op's dest
+  //    (a fresh tensor.empty) as the generic out.
+  if (canUnpackPackFold) {
+    inputOperands = inputOperandsFromUnpackedSource;
+    if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
+      dest = destPack.getDest();
+  }
+
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
   SmallVector<utils::IteratorType> iterTypes =
       genericOp.getIteratorTypesArray();
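
To make the effect concrete, here is a before/after sketch in MLIR of the rewrite this enables, mirroring the new fold_unpack_pack_after_bubble_up test added below. The elementwise add stands in for an arbitrary body without gather semantics; the shapes, tile sizes, and the indexing maps shown for the packed generic are illustrative assumptions, not output copied from the pass.

// Before: unpack -> elementwise generic -> pack, with matching
// inner_dims_pos / inner_tiles on the unpack/pack pair.
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @example(%src: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
  %empty = tensor.empty() : tensor<32x64xf32>
  %u = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [4, 8]
      into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
  %g = linalg.generic {indexing_maps = [#map, #map],
                       iterator_types = ["parallel", "parallel"]}
      ins(%u : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    %a = arith.addf %in, %in : f32
    linalg.yield %a : f32
  } -> tensor<32x64xf32>
  %dest = tensor.empty() : tensor<8x8x4x8xf32>
  %p = linalg.pack %g inner_dims_pos = [0, 1] inner_tiles = [4, 8]
      into %dest : tensor<32x64xf32> -> tensor<8x8x4x8xf32>
  return %p : tensor<8x8x4x8xf32>
}

// After: the generic consumes the packed source directly, its init is a
// fresh tensor.empty in the packed layout, and the unpack/pack pair is gone.
#packed = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @example(%src: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
  %e = tensor.empty() : tensor<8x8x4x8xf32>
  %g = linalg.generic {indexing_maps = [#packed, #packed],
                       iterator_types = ["parallel", "parallel",
                                         "parallel", "parallel"]}
      ins(%src : tensor<8x8x4x8xf32>) outs(%e : tensor<8x8x4x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %a = arith.addf %in, %in : f32
    linalg.yield %a : f32
  } -> tensor<8x8x4x8xf32>
  return %g : tensor<8x8x4x8xf32>
}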

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 33 additions & 35 deletions
@@ -524,22 +524,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 // CHECK-LABEL: func.func @unpack_element_type_change
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[ARG0_PACK]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG1]]
@@ -564,19 +553,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
-// CHECK-SAME: outs(%[[DEST]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[FINAL_RES]]
@@ -810,12 +791,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 }
 
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
 // CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: ins(%[[ARG0]]
 // CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 
@@ -943,14 +921,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
 // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK: %[[PACK_ARG0:.+]] = linalg.pack
-// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME: into %[[PACK_EMPTY]]
 // CHECK: %[[POOL:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
 // CHECK-SAME: outs(%[[INIT]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
 // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1395,27 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
 // CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
+  %empty = tensor.empty() : tensor<32x64xf32>
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %2 = arith.addf %in, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<32x64xf32>
+  %empty1 = tensor.empty() : tensor<8x8x4x8xf32>
+  %pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>
+  return %pack : tensor<8x8x4x8xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x8x4x8xf32>)
+// CHECK: return %[[GENERIC]] : tensor<8x8x4x8xf32>
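
For reference, tests in this file are driven by the file's RUN header, which is not shown in this diff; an invocation along the following lines exercises the propagation patterns and the new fold. The exact pass flag is an assumption here, so check the file's actual header before relying on it.

// Assumed RUN line; the real header of data-layout-propagation.mlir may differ.
// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file | FileCheck %s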
